from typing import Any, Dict, List

import torch
from transformers import pipeline


class EndpointHandler:
    """Inference handler that serves chat-style text generation.

    Incoming prompts are normalised into chat conversations, rendered with the
    tokenizer's chat template and passed to a ``text-generation`` pipeline.
    """

    # System message prepended to every conversation built from a raw string.
    _SYSTEM_PROMPT = "You are a helpful AI assistant"

    # Parameters sent by the hf client library that are dropped before the
    # remaining parameters are forwarded to the pipeline call.
    _EXCLUDED_PARAMETERS = frozenset(
        {"stop", "watermark", "details", "decoder_input_details"}
    )

    # Generation budget used when the payload carries no "parameters" entry.
    _DEFAULT_MAX_NEW_TOKENS = 32

    def __init__(self, path=""):
        """Create the text-generation pipeline.

        Args:
            path: Accepted for interface compatibility. It is currently NOT
                used: the model is always loaded from the hard-coded model id
                below.
        """
        self.pipe = pipeline(
            "text-generation",
            model="HuggingFaceH4/zephyr-7b-beta",
            torch_dtype=torch.bfloat16,
            device_map="auto",
        )

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """Generate text for the prompt(s) contained in ``data``.

        Args:
            data: The request payload. ``data["inputs"]`` holds the prompt(s)
                and may be one of:

                * a string (a single user prompt),
                * a list of strings (one user prompt per element),
                * a list of message dicts (a single conversation),
                * a list of lists of message dicts (several conversations).

                If the ``"inputs"`` key is absent, the payload itself is used
                as the inputs. ``data["parameters"]`` is an optional dict of
                keyword arguments forwarded to the pipeline call. The payload
                is not modified.

        Returns:
            A single-element list ``[{"generated_text": outputs}]`` where
            ``outputs`` is whatever the pipeline returned for the prompts.

        Raises:
            ValueError: If the inputs are empty or have an unsupported type.
        """
        # Read without popping so the caller's payload is left untouched.
        inputs = data.get("inputs", data)
        parameters = data.get("parameters")

        conversations = self._to_conversations(inputs)

        tokenizer = self.pipe.tokenizer
        prompts = [
            tokenizer.apply_chat_template(
                conversation, tokenize=False, add_generation_prompt=True
            )
            for conversation in conversations
        ]

        if parameters is None:
            generation_kwargs = {"max_new_tokens": self._DEFAULT_MAX_NEW_TOKENS}
        else:
            # Make endpoint compatible with hf client library.
            generation_kwargs = {
                name: value
                for name, value in parameters.items()
                if name not in self._EXCLUDED_PARAMETERS
            }

        outputs = self.pipe(prompts, **generation_kwargs)
        return [{"generated_text": outputs}]

    def _to_conversations(self, inputs):
        """Normalise ``inputs`` into a list of chat conversations."""
        if isinstance(inputs, str):
            if not inputs:
                raise ValueError("`inputs` must not be empty.")
            return [self._wrap_prompt(inputs)]

        if not isinstance(inputs, (list, tuple)):
            raise ValueError(
                "`inputs` must be a string, a list of strings, a list of "
                "message dicts, or a list of lists of message dicts; got "
                f"{type(inputs).__name__}."
            )
        if not inputs:
            raise ValueError("`inputs` must not be empty.")

        # The first element decides how the whole sequence is interpreted.
        first = inputs[0]
        if isinstance(first, dict):
            # A single conversation given as a flat list of messages.
            return [inputs]
        if isinstance(first, (list, tuple)):
            # Already a batch of conversations.
            return inputs
        # Otherwise every element is treated as a raw user prompt.
        return [self._wrap_prompt(prompt) for prompt in inputs]

    def _wrap_prompt(self, prompt):
        """Build a system + user conversation around a raw prompt."""
        return [
            {
                "role": "system",
                "content": self._SYSTEM_PROMPT,
            },
            {"role": "user", "content": prompt},
        ]