ysharma HF staff commited on
Commit
fb7a592
·
1 Parent(s): f98f623

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +76 -11
app.py CHANGED
@@ -34,7 +34,15 @@ examples=[
34
  ]
35
 
36
 
37
- def predict(message, chatbot):
 
 
 
 
 
 
 
 
38
 
39
  input_prompt = f"[INST] <<SYS>>\n{system_message}\n<</SYS>>\n\n "
40
  for interaction in chatbot:
@@ -44,12 +52,13 @@ def predict(message, chatbot):
44
 
45
  data = {
46
  "inputs": input_prompt,
47
- "parameters": {"max_new_tokens":256,
48
- "do_sample":True,
49
- "top_p":0.6,
50
- "temperature":0.9,}
51
- }
52
-
 
53
  response = requests.post(api_url, headers=headers, data=json.dumps(data), auth=('hf', hf_token), stream=True)
54
 
55
  partial_message = ""
@@ -84,8 +93,16 @@ def predict(message, chatbot):
84
  continue
85
 
86
 
87
- def predict_batch(message, chatbot):
 
88
 
 
 
 
 
 
 
 
89
  input_prompt = f"[INST]<<SYS>>\n{system_message}\n<</SYS>>\n\n "
90
  for interaction in chatbot:
91
  input_prompt = input_prompt + str(interaction[0]) + " [/INST] " + str(interaction[1]) + " </s><s> [INST] "
@@ -94,7 +111,13 @@ def predict_batch(message, chatbot):
94
 
95
  data = {
96
  "inputs": input_prompt,
97
- "parameters": {"max_new_tokens":256}
 
 
 
 
 
 
98
  }
99
 
100
  response = requests.post(api_url_nostream, headers=headers, data=json.dumps(data), auth=('hf', hf_token))
@@ -114,13 +137,55 @@ def predict_batch(message, chatbot):
114
  print(f"Request failed with status code {response.status_code}")
115
 
116
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
117
  # Gradio Demo
118
  with gr.Blocks() as demo:
119
 
120
  with gr.Tab("Streaming"):
121
- gr.ChatInterface(predict, title=title, description=description, css=css, examples=examples, cache_examples=True)
122
 
123
  with gr.Tab("Batch"):
124
- gr.ChatInterface(predict_batch, title=title, description=description, css=css, examples=examples, cache_examples=True)
125
 
126
  demo.queue(concurrency_count=75, max_size=100).launch(debug=True)
 
34
  ]
35
 
36
 
37
+ # Stream text
38
+ def predict(message, chatbot, system_prompt="", temperature=0.9, max_new_tokens=256, top_p=0.6, repetition_penalty=1.0,):
39
+
40
+ if system_prompt != "":
41
+ system_message = system_prompt
42
+ temperature = float(temperature)
43
+ if temperature < 1e-2:
44
+ temperature = 1e-2
45
+ top_p = float(top_p)
46
 
47
  input_prompt = f"[INST] <<SYS>>\n{system_message}\n<</SYS>>\n\n "
48
  for interaction in chatbot:
 
52
 
53
  data = {
54
  "inputs": input_prompt,
55
+ "parameters": {
56
+ "max_new_tokens":max_new_tokens,
57
+ "temperature"=temperature,
58
+ "top_p"=top_p,
59
+ "repetition_penalty"=repetition_penalty,
60
+ "do_sample":True,
61
+ },
62
  response = requests.post(api_url, headers=headers, data=json.dumps(data), auth=('hf', hf_token), stream=True)
63
 
64
  partial_message = ""
 
93
  continue
94
 
95
 
96
+ # No Stream
97
+ def predict_batch(message, chatbot, system_prompt="", temperature=0.9, max_new_tokens=256, top_p=0.6, repetition_penalty=1.0,):
98
 
99
+ if system_prompt != "":
100
+ system_message = system_prompt
101
+ temperature = float(temperature)
102
+ if temperature < 1e-2:
103
+ temperature = 1e-2
104
+ top_p = float(top_p)
105
+
106
  input_prompt = f"[INST]<<SYS>>\n{system_message}\n<</SYS>>\n\n "
107
  for interaction in chatbot:
108
  input_prompt = input_prompt + str(interaction[0]) + " [/INST] " + str(interaction[1]) + " </s><s> [INST] "
 
111
 
112
  data = {
113
  "inputs": input_prompt,
114
+ "parameters": {
115
+ "max_new_tokens":max_new_tokens,
116
+ "temperature"=temperature,
117
+ "top_p"=top_p,
118
+ "repetition_penalty"=repetition_penalty,
119
+ "do_sample":True,
120
+ },
121
  }
122
 
123
  response = requests.post(api_url_nostream, headers=headers, data=json.dumps(data), auth=('hf', hf_token))
 
137
  print(f"Request failed with status code {response.status_code}")
138
 
139
 
140
+
141
+ additional_inputs=[
142
+ gr.Textbox("", label="Optional system prompt"),
143
+ gr.Slider(
144
+ label="Temperature",
145
+ value=0.9,
146
+ minimum=0.0,
147
+ maximum=1.0,
148
+ step=0.05,
149
+ interactive=True,
150
+ info="Higher values produce more diverse outputs",
151
+ ),
152
+ gr.Slider(
153
+ label="Max new tokens",
154
+ value=256,
155
+ minimum=0,
156
+ maximum=4096,
157
+ step=64,
158
+ interactive=True,
159
+ info="The maximum numbers of new tokens",
160
+ ),
161
+ gr.Slider(
162
+ label="Top-p (nucleus sampling)",
163
+ value=0.6,
164
+ minimum=0.0,
165
+ maximum=1,
166
+ step=0.05,
167
+ interactive=True,
168
+ info="Higher values sample more low-probability tokens",
169
+ ),
170
+ gr.Slider(
171
+ label="Repetition penalty",
172
+ value=1.2,
173
+ minimum=1.0,
174
+ maximum=2.0,
175
+ step=0.05,
176
+ interactive=True,
177
+ info="Penalize repeated tokens",
178
+ )
179
+ ]
180
+
181
+
182
  # Gradio Demo
183
  with gr.Blocks() as demo:
184
 
185
  with gr.Tab("Streaming"):
186
+ gr.ChatInterface(predict, title=title, description=description, css=css, examples=examples, cache_examples=True, additional_inputs=additional_inputs,)
187
 
188
  with gr.Tab("Batch"):
189
+ gr.ChatInterface(predict_batch, title=title, description=description, css=css, examples=examples, cache_examples=True, additional_inputs=additional_inputs,)
190
 
191
  demo.queue(concurrency_count=75, max_size=100).launch(debug=True)