ysharma HF staff commited on
Commit
028e7df
·
1 Parent(s): efe4a11

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +24 -46
app.py CHANGED
@@ -39,8 +39,8 @@ examples=[
39
 
40
 
41
  # Stream text
42
- def predict(message, chatbot, system_prompt="", temperature=0.9, max_new_tokens=256, top_p=0.6, repetition_penalty=1.0,):
43
-
44
  if system_prompt != "":
45
  input_prompt = f"<s>[INST] <<SYS>>\n{system_prompt}\n<</SYS>>\n\n "
46
  else:
@@ -55,53 +55,24 @@ def predict(message, chatbot, system_prompt="", temperature=0.9, max_new_tokens=
55
  input_prompt = input_prompt + str(interaction[0]) + " [/INST] " + str(interaction[1]) + " </s><s>[INST] "
56
 
57
  input_prompt = input_prompt + str(message) + " [/INST] "
58
-
59
- data = {
60
- "inputs": input_prompt,
61
- "parameters": {
62
- "max_new_tokens":max_new_tokens,
63
- "temperature":temperature,
64
- "top_p":top_p,
65
- "repetition_penalty":repetition_penalty,
66
- "do_sample":True,
67
- },
68
- }
69
- response = requests.post(api_url, headers=headers, data=json.dumps(data), auth=('hf', hf_token), stream=True)
70
-
71
- partial_message = ""
72
- for line in response.iter_lines():
73
- if line: # filter out keep-alive new lines
74
- # Decode from bytes to string
75
- decoded_line = line.decode('utf-8')
76
-
77
- # Remove 'data:' prefix
78
- if decoded_line.startswith('data:'):
79
- json_line = decoded_line[5:] # Exclude the first 5 characters ('data:')
80
- else:
81
- gr.Warning(f"This line does not start with 'data:': {decoded_line}")
82
- continue
83
-
84
- # Load as JSON
85
- try:
86
- json_obj = json.loads(json_line)
87
- if 'token' in json_obj:
88
- partial_message = partial_message + json_obj['token']['text']
89
- yield partial_message
90
- elif 'error' in json_obj:
91
- yield json_obj['error'] + '. Please refresh and try again with an appropriate smaller input prompt.'
92
- else:
93
- gr.Warning(f"The key 'token' does not exist in this JSON object: {json_obj}")
94
-
95
- except json.JSONDecodeError:
96
- gr.Warning(f"This line is not valid JSON: {json_line}")
97
- continue
98
- except KeyError as e:
99
- gr.Warning(f"KeyError: {e} occurred for JSON object: {json_obj}")
100
- continue
101
 
 
 
 
 
 
 
 
 
 
 
 
 
102
 
103
  # No Stream
104
  def predict_batch(message, chatbot, system_prompt="", temperature=0.9, max_new_tokens=256, top_p=0.6, repetition_penalty=1.0,):
 
 
105
  if system_prompt != "":
106
  input_prompt = f"<s>[INST] <<SYS>>\n{system_prompt}\n<</SYS>>\n\n "
107
  else:
@@ -129,11 +100,17 @@ def predict_batch(message, chatbot, system_prompt="", temperature=0.9, max_new_t
129
  },
130
  }
131
 
132
- response = requests.post(api_url_nostream, headers=headers, json=data )
 
 
 
 
133
 
134
  if response.status_code == 200: # check if the request was successful
135
  try:
136
  json_obj = response.json()
 
 
137
  if 'generated_text' in json_obj[0] and len(json_obj[0]['generated_text']) > 0:
138
  return json_obj[0]['generated_text']
139
  elif 'error' in json_obj[0]:
@@ -146,6 +123,7 @@ def predict_batch(message, chatbot, system_prompt="", temperature=0.9, max_new_t
146
  print(f"Request failed with status code {response.status_code}")
147
 
148
 
 
149
  def vote(data: gr.LikeData):
150
  if data.liked:
151
  print("You upvoted this response: " + data.value)
 
39
 
40
 
41
  # Stream text
42
+ async def predict(message, chatbot, system_prompt="", temperature=0.9, max_new_tokens=256, top_p=0.6, repetition_penalty=1.0,):
43
+
44
  if system_prompt != "":
45
  input_prompt = f"<s>[INST] <<SYS>>\n{system_prompt}\n<</SYS>>\n\n "
46
  else:
 
55
  input_prompt = input_prompt + str(interaction[0]) + " [/INST] " + str(interaction[1]) + " </s><s>[INST] "
56
 
57
  input_prompt = input_prompt + str(message) + " [/INST] "
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
58
 
59
+ partial_message = ""
60
+ async for token in await client.text_generation(prompt=input_prompt,
61
+ max_new_tokens=max_new_tokens,
62
+ stream=True,
63
+ best_of=1,
64
+ temperature=temperature,
65
+ top_p=top_p,
66
+ do_sample=True,
67
+ repetition_penalty=repetition_penalty):
68
+ partial_message = partial_message + token
69
+ yield partial_message
70
+
71
 
72
  # No Stream
73
  def predict_batch(message, chatbot, system_prompt="", temperature=0.9, max_new_tokens=256, top_p=0.6, repetition_penalty=1.0,):
74
+ print(f"message - {message}")
75
+ print(f"chatbot - {chatbot}")
76
  if system_prompt != "":
77
  input_prompt = f"<s>[INST] <<SYS>>\n{system_prompt}\n<</SYS>>\n\n "
78
  else:
 
100
  },
101
  }
102
 
103
+ response = requests.post(api_url_nostream, headers=headers, json=data ) #auth=('hf', hf_token)) data=json.dumps(data),
104
+ print(f"response - {response}")
105
+ print(f"response.status_code - {response.status_code}")
106
+ print(f"response.text - {response.text}")
107
+ print(f"type(response.text) - {type(response.text)}")
108
 
109
  if response.status_code == 200: # check if the request was successful
110
  try:
111
  json_obj = response.json()
112
+ print(f"type(response.json) - {type(json_obj)}")
113
+ print(f"response.json - {json_obj}")
114
  if 'generated_text' in json_obj[0] and len(json_obj[0]['generated_text']) > 0:
115
  return json_obj[0]['generated_text']
116
  elif 'error' in json_obj[0]:
 
123
  print(f"Request failed with status code {response.status_code}")
124
 
125
 
126
+
127
  def vote(data: gr.LikeData):
128
  if data.liked:
129
  print("You upvoted this response: " + data.value)