jonathanjordan21 commited on
Commit
e944c71
·
verified ·
1 Parent(s): aeabb44

Create custom_llm.py

Browse files
Files changed (1) hide show
  1. custom_llm.py +71 -0
custom_llm.py ADDED
@@ -0,0 +1,71 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Any, List, Mapping, Optional
2
+ from langchain_core.language_models.llms import LLM
3
+ from langchain_core.callbacks.manager import CallbackManagerForLLMRun
4
+
5
+ from typing import Literal
6
+
7
+ import requests
8
+
9
+
10
+ class CustomLLM(LLM):
11
+ repo_id : str
12
+ api_token : str
13
+ model_type: Literal["text2text-generation", "text-generation"]
14
+ max_new_tokens: int = None
15
+ temperature: float = 0.001
16
+ timeout: float = None
17
+ top_p: float = None
18
+ top_k : int = None
19
+ repetition_penalty : float = None
20
+ stop : List[str] = []
21
+
22
+
23
+ @property
24
+ def _llm_type(self) -> str:
25
+ return "custom"
26
+
27
+ def _call(
28
+ self,
29
+ prompt: str,
30
+ stop: Optional[List[str]] = None,
31
+ run_manager: Optional[CallbackManagerForLLMRun] = None,
32
+ **kwargs: Any,
33
+ ) -> str:
34
+
35
+ headers = {"Authorization": f"Bearer {self.api_token}"}
36
+ API_URL = f"https://api-inference.huggingface.co/models/{self.repo_id}"
37
+
38
+ parameters_dict = {
39
+ 'max_new_tokens': self.max_new_tokens,
40
+ 'temperature': self.temperature,
41
+ 'timeout': self.timeout,
42
+ 'top_p': self.top_p,
43
+ 'top_k': self.top_k,
44
+ 'repetition_penalty': self.repetition_penalty,
45
+ 'stop':self.stop
46
+ }
47
+
48
+ if self.model_type == 'text-generation':
49
+ parameters_dict["return_full_text"]=False
50
+
51
+ data = {"inputs": prompt, "parameters":parameters_dict, "options":{"wait_for_model":True}}
52
+ data = requests.post(API_URL, headers=headers, json=data).json()
53
+ try:
54
+ return data[0]['generated_text']
55
+ except:
56
+ return data
57
+
58
+ @property
59
+ def _identifying_params(self) -> Mapping[str, Any]:
60
+ """Get the identifying parameters."""
61
+ return {
62
+ 'repo_id': self.repo_id,
63
+ 'model_type':self.model_type,
64
+ 'stop_sequences':self.stop,
65
+ 'max_new_tokens': self.max_new_tokens,
66
+ 'temperature': self.temperature,
67
+ 'timeout': self.timeout,
68
+ 'top_p': self.top_p,
69
+ 'top_k': self.top_k,
70
+ 'repetition_penalty': self.repetition_penalty
71
+ }