lukestanley commited on
Commit
64c61f3
·
1 Parent(s): 41ac6cc

Add Anthropic Opus support

Browse files
Files changed (1) hide show
  1. utils.py +55 -4
utils.py CHANGED
@@ -24,9 +24,9 @@ from huggingface_hub import hf_hub_download
24
 
25
  URL = "http://localhost:5834/v1/chat/completions"
26
  in_memory_llm = None
27
- worker_options = ["runpod", "http", "in_memory", "mistral"]
28
 
29
- LLM_WORKER = env.get("LLM_WORKER", "mistral")
30
  if LLM_WORKER not in worker_options:
31
  raise ValueError(f"Invalid worker: {LLM_WORKER}")
32
  N_GPU_LAYERS = int(env.get("N_GPU_LAYERS", -1)) # Default to -1, use all layers if available
@@ -250,11 +250,62 @@ def llm_stream_mistral_api(prompt: str, pydantic_model_class=None, attempts=0) -
250
  return json.loads(output)
251
 
252
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
253
 
254
  def query_ai_prompt(prompt, replacements, model_class):
255
  prompt = replace_text(prompt, replacements)
256
- if LLM_WORKER == "mistral":
257
- result = llm_stream_mistral_api(prompt, model_class)
258
  if LLM_WORKER == "mistral":
259
  result = llm_stream_mistral_api(prompt, model_class)
260
  if LLM_WORKER == "runpod":
 
24
 
25
  URL = "http://localhost:5834/v1/chat/completions"
26
  in_memory_llm = None
27
+ worker_options = ["runpod", "http", "in_memory", "mistral", "anthropic"]
28
 
29
+ LLM_WORKER = env.get("LLM_WORKER", "anthropic")
30
  if LLM_WORKER not in worker_options:
31
  raise ValueError(f"Invalid worker: {LLM_WORKER}")
32
  N_GPU_LAYERS = int(env.get("N_GPU_LAYERS", -1)) # Default to -1, use all layers if available
 
250
  return json.loads(output)
251
 
252
 
253
+ def send_anthropic_request(prompt: str):
254
+ api_key = env.get("ANTHROPIC_API_KEY")
255
+ if not api_key:
256
+ print("API key not found. Please set the ANTHROPIC_API_KEY environment variable.")
257
+ return
258
+
259
+ headers = {
260
+ 'x-api-key': api_key,
261
+ 'anthropic-version': '2023-06-01',
262
+ 'Content-Type': 'application/json',
263
+ }
264
+
265
+ data = {
266
+ "model": "claude-3-opus-20240229",
267
+ "max_tokens": 1024,
268
+ "messages": [{"role": "user", "content": prompt}]
269
+ }
270
+
271
+ response = requests.post('https://api.anthropic.com/v1/messages', headers=headers, data=json.dumps(data))
272
+ if response.status_code != 200:
273
+ print(f"Unexpected Anthropic API status code: {response.status_code} with body: {response.text}")
274
+ raise ValueError(f"Unexpected Anthropic API status code: {response.status_code} with body: {response.text}")
275
+ j = response.json()
276
+
277
+ text = j['content'][0]["text"]
278
+ print(text)
279
+ return text
280
+
281
+ def llm_anthropic_api(prompt: str, pydantic_model_class=None, attempts=0) -> Union[str, Dict[str, Any]]:
282
+ # With no streaming or rate limits, we use the Anthropic API, we have string input and output from send_anthropic_request,
283
+ # but we need to convert it to JSON for the pydantic model class like the other APIs.
284
+ output = send_anthropic_request(prompt)
285
+ if pydantic_model_class:
286
+ try:
287
+ parsed_result = pydantic_model_class.model_validate_json(output)
288
+ print(parsed_result)
289
+ # This will raise an exception if the model is invalid.
290
+ return json.loads(output)
291
+ except Exception as e:
292
+ print(f"Error validating pydantic model: {e}")
293
+ # Let's retry by calling ourselves again if attempts < 3
294
+ if attempts == 0:
295
+ # We modify the prompt to remind it to output JSON in the required format
296
+ prompt = f"{prompt} You must output the JSON in the required format only, with no remarks or prefacing remarks - JUST JSON!"
297
+ if attempts < 3:
298
+ attempts += 1
299
+ print(f"Retrying Anthropic API call, attempt {attempts}")
300
+ return llm_anthropic_api(prompt, pydantic_model_class, attempts)
301
+ else:
302
+ print("No pydantic model class provided, returning without class validation")
303
+ return json.loads(output)
304
 
305
  def query_ai_prompt(prompt, replacements, model_class):
306
  prompt = replace_text(prompt, replacements)
307
+ if LLM_WORKER == "anthropic":
308
+ result = llm_anthropic_api(prompt, model_class)
309
  if LLM_WORKER == "mistral":
310
  result = llm_stream_mistral_api(prompt, model_class)
311
  if LLM_WORKER == "runpod":