cetusian committed
Commit 7841db2 · verified · 1 Parent(s): 0a8cafa

Update app.py

Files changed (1)
  1. app.py +78 -55
app.py CHANGED
@@ -1,62 +1,85 @@
 import gradio as gr
-from transformers import AutoTokenizer, AutoModelForCausalLM
-import openai
-import torch
-
-# Load Llama model (GPU-optimized)
-device = "cuda" if torch.cuda.is_available() else "cpu"
-llama_tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-chat-hf")
-llama_model = AutoModelForCausalLM.from_pretrained(
-    "meta-llama/Llama-2-7b-chat-hf",
-    device_map="auto"
-)
-
-# OpenAI GPT Model API Key (Replace with your API key)
-openai.api_key = "YOUR_OPENAI_API_KEY"
-
-# Function to query Llama
-def query_llama(prompt):
-    inputs = llama_tokenizer(prompt, return_tensors="pt", truncation=True, max_length=128).to(device)
-    outputs = llama_model.generate(inputs.input_ids, max_length=150)
-    response = llama_tokenizer.decode(outputs[0], skip_special_tokens=True)
-    return response
-
-# Function to query GPT
-def query_gpt(prompt):
-    response = openai.Completion.create(
-        engine="text-davinci-003",
-        prompt=prompt,
-        max_tokens=150
-    )
-    return response['choices'][0]['text'].strip()
 
-# Function to compare models
-def compare_models(prompt, models):
     responses = {}
-    if "Llama" in models:
-        responses["Llama"] = query_llama(prompt)
-    if "GPT" in models:
-        responses["GPT"] = query_gpt(prompt)
     return responses
 
-# Gradio Interface
-def gradio_app():
-    with gr.Blocks() as app:
-        gr.Markdown("# AI Model Comparison Tool 🚀")
-        with gr.Row():
-            prompt_input = gr.Textbox(label="Enter your prompt", placeholder="Ask something...")
-        with gr.Row():
-            model_selector = gr.CheckboxGroup(
-                ["Llama", "GPT"],
-                label="Select Models to Compare",
-                value=["Llama", "GPT"]
-            )
-        with gr.Row():
-            output_boxes = gr.JSON(label="Model Responses")
-        with gr.Row():
-            compare_button = gr.Button("Compare Models")
-        compare_button.click(compare_models, inputs=[prompt_input, model_selector], outputs=[output_boxes])
-    return app
 
 if __name__ == "__main__":
-    gradio_app().launch()
+import os
 import gradio as gr
+from huggingface_hub import login
+from huggingface_hub import InferenceClient
+import spaces
+
+# Authenticate with Hugging Face API
+api_key = os.getenv("LLAMA")
+login(api_key)
+
+# Initialize clients for different models
+llama_client = InferenceClient("meta-llama/Llama-3.1-70B-Instruct")
+gpt_client = InferenceClient("openai/gpt-4")  # Example: Replace with your OpenAI GPT model
+
+# Define the response function
+@spaces.GPU
+def respond(
+    message,
+    history: list[dict],
+    system_message,
+    max_tokens,
+    temperature,
+    top_p,
+    selected_models,
+):
+    # Prepare input messages: system prompt, prior turns, then the new user message
+    messages = [{"role": "system", "content": system_message}] + history
+    messages.append({"role": "user", "content": message})
 
+    # Collect responses from selected models
     responses = {}
+
+    if "Llama" in selected_models:
+        llama_response = ""
+        for token in llama_client.chat_completion(
+            messages, max_tokens=max_tokens, stream=True, temperature=temperature, top_p=top_p
+        ):
+            delta = token.choices[0].delta.content or ""  # final stream chunk may carry no content
+            llama_response += delta
+        responses["Llama"] = llama_response
+
+    if "GPT" in selected_models:
+        gpt_response = ""
+        for token in gpt_client.chat_completion(
+            messages, max_tokens=max_tokens, stream=True, temperature=temperature, top_p=top_p
+        ):
+            delta = token.choices[0].delta.content or ""  # final stream chunk may carry no content
+            gpt_response += delta
+        responses["GPT"] = gpt_response
+
+    # gr.ChatInterface expects a single assistant message, not a dict,
+    # so merge the per-model answers into one string
+    return "\n\n".join(f"**{model}:** {text}" for model, text in responses.items())
 
+# Build the Gradio app
+def create_demo():
+    # gr.Blocks has no .add() method; compose the UI inside a Blocks context instead
+    with gr.Blocks() as demo:
+        gr.Markdown("# AI Model Comparison Tool 🌟")
+        gr.ChatInterface(
+            respond,
+            type="messages",
+            additional_inputs=[
+                gr.Textbox(
+                    value="You are a helpful assistant providing answers for technical and customer support queries.",
+                    label="System message"
+                ),
+                gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
+                gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
+                gr.Slider(
+                    minimum=0.1,
+                    maximum=1.0,
+                    value=0.95,
+                    step=0.05,
+                    label="Top-p (nucleus sampling)"
+                ),
+                gr.CheckboxGroup(
+                    ["Llama", "GPT"],
+                    label="Select models to compare",
+                    value=["Llama"]
+                ),
+            ],
+        )
+    return demo
 
 if __name__ == "__main__":
+    demo = create_demo()
+    demo.launch()