Spaces:
Running
Running
from typing import Any, List, Mapping, Optional | |
from langchain_core.language_models.llms import LLM | |
from langchain_core.callbacks.manager import CallbackManagerForLLMRun | |
from typing import Literal | |
import requests | |
class CustomLLM(LLM):
    """LangChain LLM wrapper around the Hugging Face hosted Inference API.

    Each call POSTs the prompt to
    ``https://api-inference.huggingface.co/models/<repo_id>`` and returns the
    ``generated_text`` field of the first result in the JSON response.

    Attributes:
        repo_id: Model repository id appended to the Inference API URL.
        api_token: Token sent as ``Authorization: Bearer <token>``.
        model_type: Task of the model. For ``"text-generation"`` the request
            additionally sets ``return_full_text=False``.
        max_new_tokens: Generation parameter, forwarded when not ``None``.
        temperature: Generation parameter, forwarded when not ``None``.
        timeout: Seconds to wait for the HTTP response; ``None`` means wait
            indefinitely.
        top_p: Generation parameter, forwarded when not ``None``.
        top_k: Generation parameter, forwarded when not ``None``.
        repetition_penalty: Generation parameter, forwarded when not ``None``.
        stop: Default stop sequences, used when a call does not supply any.
    """

    repo_id: str
    api_token: str
    model_type: Literal["text2text-generation", "text-generation"]
    max_new_tokens: Optional[int] = None
    temperature: float = 0.001
    timeout: Optional[float] = None
    top_p: Optional[float] = None
    top_k: Optional[int] = None
    repetition_penalty: Optional[float] = None
    # NOTE(review): assumes the base model copies field defaults per
    # instance, so this list literal is not shared between instances.
    stop: List[str] = []

    @property
    def _llm_type(self) -> str:
        """Return the identifier of this LLM implementation."""
        return "custom"

    def _call(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> str:
        """Send ``prompt`` to the Inference API and return the generated text.

        Args:
            prompt: Text sent as the ``"inputs"`` field of the request.
            stop: Stop sequences for this call. When ``None`` the instance's
                ``stop`` attribute is used instead.
            run_manager: Accepted for interface compatibility; not used.
            **kwargs: Accepted for interface compatibility; not used.

        Returns:
            The ``generated_text`` of the first item in the response.

        Raises:
            ValueError: If the response does not have the expected
                ``[{"generated_text": ...}]`` shape (for example an error
                payload returned by the API).
        """
        url = f"https://api-inference.huggingface.co/models/{self.repo_id}"
        headers = {"Authorization": f"Bearer {self.api_token}"}

        # A per-call ``stop`` takes precedence over the instance default.
        stop_sequences = self.stop if stop is None else stop
        candidates = {
            "max_new_tokens": self.max_new_tokens,
            "temperature": self.temperature,
            "top_p": self.top_p,
            "top_k": self.top_k,
            "repetition_penalty": self.repetition_penalty,
            # An empty list carries no information, so treat it as unset.
            "stop": stop_sequences or None,
        }
        # Only send parameters that were actually configured, so the payload
        # does not carry JSON nulls.
        parameters = {
            key: value for key, value in candidates.items() if value is not None
        }
        if self.model_type == "text-generation":
            parameters["return_full_text"] = False

        payload = {
            "inputs": prompt,
            "parameters": parameters,
            "options": {"wait_for_model": True},
        }
        # ``timeout`` bounds the HTTP request rather than being sent as a
        # generation parameter; without it a stalled connection blocks forever.
        response = requests.post(
            url, headers=headers, json=payload, timeout=self.timeout
        )
        result = response.json()
        try:
            return result[0]["generated_text"]
        except (KeyError, IndexError, TypeError) as err:
            # Surface the API payload instead of returning a non-string value
            # from a method declared to return ``str``.
            raise ValueError(
                f"Unexpected response from the Inference API: {result!r}"
            ) from err

    @property
    def _identifying_params(self) -> Mapping[str, Any]:
        """Get the identifying parameters."""
        return {
            "repo_id": self.repo_id,
            "model_type": self.model_type,
            "stop_sequences": self.stop,
            "max_new_tokens": self.max_new_tokens,
            "temperature": self.temperature,
            "timeout": self.timeout,
            "top_p": self.top_p,
            "top_k": self.top_k,
            "repetition_penalty": self.repetition_penalty,
        }