cetusian committed · verified
Commit 0a8cafa · 1 Parent(s): d82f204

Create app.py

Files changed (1):
  1. app.py  +62 -0
app.py ADDED
@@ -0,0 +1,62 @@
+ import os
+
+ import gradio as gr
+ import openai
+ import torch
+ from transformers import AutoTokenizer, AutoModelForCausalLM
+
+ # Load the Llama chat model. meta-llama/Llama-2-7b-chat-hf is a gated repo,
+ # so the environment needs an accepted license and a Hugging Face access token.
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+ llama_tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-chat-hf")
+ llama_model = AutoModelForCausalLM.from_pretrained(
+     "meta-llama/Llama-2-7b-chat-hf",
+     device_map="auto",
+     torch_dtype=torch.float16 if device == "cuda" else torch.float32,  # halves GPU memory use
+ )
+
+ # OpenAI API key: read from the environment rather than hard-coding a secret.
+ openai.api_key = os.environ.get("OPENAI_API_KEY", "YOUR_OPENAI_API_KEY")
+
+ # Query the local Llama model and return only the generated continuation.
+ def query_llama(prompt):
+     inputs = llama_tokenizer(prompt, return_tensors="pt", truncation=True, max_length=128).to(device)
+     # max_new_tokens (not max_length) so the generation budget is not consumed by the prompt.
+     outputs = llama_model.generate(**inputs, max_new_tokens=150)
+     generated = outputs[0][inputs.input_ids.shape[-1]:]  # strip the echoed prompt tokens
+     return llama_tokenizer.decode(generated, skip_special_tokens=True)
+
+ # Query the OpenAI API (legacy pre-1.0 SDK interface; text-davinci-003 has been
+ # retired, so a chat model is used here instead).
+ def query_gpt(prompt):
+     response = openai.ChatCompletion.create(
+         model="gpt-3.5-turbo",
+         messages=[{"role": "user", "content": prompt}],
+         max_tokens=150,
+     )
+     return response["choices"][0]["message"]["content"].strip()
+
+ # Run the prompt through each selected model and collect the responses.
+ def compare_models(prompt, models):
+     responses = {}
+     if "Llama" in models:
+         responses["Llama"] = query_llama(prompt)
+     if "GPT" in models:
+         responses["GPT"] = query_gpt(prompt)
+     return responses
+
+ # Gradio interface
+ def gradio_app():
+     with gr.Blocks() as app:
+         gr.Markdown("# AI Model Comparison Tool 🚀")
+         with gr.Row():
+             prompt_input = gr.Textbox(label="Enter your prompt", placeholder="Ask something...")
+         with gr.Row():
+             model_selector = gr.CheckboxGroup(
+                 ["Llama", "GPT"],
+                 label="Select Models to Compare",
+                 value=["Llama", "GPT"],
+             )
+         with gr.Row():
+             output_boxes = gr.JSON(label="Model Responses")
+         with gr.Row():
+             compare_button = gr.Button("Compare Models")
+         compare_button.click(compare_models, inputs=[prompt_input, model_selector], outputs=[output_boxes])
+     return app
+
+ if __name__ == "__main__":
+     gradio_app().launch()