Spaces:
Runtime error
Runtime error
import pprint as pp | |
import logging | |
import time | |
import torch | |
from transformers import GenerationConfig, pipeline | |
from utils import compare_model_size | |
# Configure root logging at import time: INFO level, timestamped records.
logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(message)s",
    level=logging.INFO,
)
class BatchAggregator:
    """
    Aggregate (summarize) a batch of texts into a single output using a
    ``text2text-generation`` pipeline.

    On construction the pipeline is built, ``torch.compile`` is attempted on the
    model, and a generation config is chosen: loaded from the hub for models
    listed in ``CONFIGURED_MODELS``, otherwise derived from ``GENERIC_CONFIG``
    plus model-family tweaks.
    """

    # Models whose hub repo is expected to provide a tuned generation config.
    CONFIGURED_MODELS = [
        "pszemraj/bart-large-mnli-dolly_hhrlhf-v1"
    ]  # TODO: Add models here
    # Prompt prefix used by infer_aggregate when no instruction is given.
    DEFAULT_INSTRUCTION = "Write a comprehensive yet concise summary that pulls together the main points of the following text:"
    # Shared template for generation settings. It is never handed to a model
    # directly; a per-instance deep copy is used so later updates cannot leak
    # into this class attribute (and thus into other instances).
    GENERIC_CONFIG = GenerationConfig(
        num_beams=8,
        early_stopping=True,
        do_sample=False,
        min_new_tokens=32,
        max_new_tokens=256,
        repetition_penalty=1.1,
        length_penalty=1.4,
        no_repeat_ngram_size=4,
        encoder_no_repeat_ngram_size=5,
    )

    def __init__(
        self, model_name: str = "pszemraj/bart-large-mnli-dolly_hhrlhf-v1", **kwargs
    ):
        """
        Create the aggregator and load ``model_name``.

        :param model_name: hub id / path of the model to load.
        :param kwargs: accepted for interface compatibility; currently unused.
        """
        self.device = None
        self.is_compiled = False
        self.logger = logging.getLogger(__name__)
        self.init_model(model_name)

    def init_model(self, model_name: str) -> None:
        """
        Initialize (or replace) the model.

        Any previously loaded pipeline is released before the CUDA cache is
        emptied, then a new pipeline is created, compiled if possible, and its
        generation config is set up, including tokenizer-specific token ids.

        :param model_name: The name of the model to use.
        :raises Exception: re-raised from pipeline creation if it fails.
        """
        import gc  # local: only needed when swapping an existing model out

        # Free up memory: drop the old pipeline *before* emptying the CUDA
        # cache, otherwise its weights are still referenced and nothing is freed.
        if getattr(self, "aggregator", None) is not None:
            self.aggregator = None
            gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        self.logger.info(f"Setting model to {model_name}")
        self.model_name = model_name
        self.aggregator = self._create_pipeline(model_name)
        self._configure_model()
        # update the generation config with the specific tokenizer;
        # T5-family models get hard-coded ids, others use the tokenizer's.
        is_t5 = "t5" in model_name.lower()
        tokenizer = self.aggregator.tokenizer
        tokenizer_params = {
            "decoder_start_token_id": 0 if is_t5 else tokenizer.eos_token_id,
            "eos_token_id": 1 if is_t5 else tokenizer.eos_token_id,
            "pad_token_id": 0 if is_t5 else tokenizer.pad_token_id,
        }
        self.update_generation_config(**tokenizer_params)

    def _create_pipeline(
        self, model_name: str = "pszemraj/bart-large-mnli-dolly_hhrlhf-v1"
    ) -> pipeline:
        """
        _create_pipeline creates a pipeline for the model.

        Uses CUDA device 0 when available, otherwise CPU (device -1), and sets
        ``self.device`` accordingly. Weights are loaded as float32.

        :param str model_name: model name to use, default: "pszemraj/bart-large-mnli-dolly_hhrlhf-v1"
        :return pipeline: the pipeline for the model
        :raises Exception: if the pipeline cannot be created (logged, then re-raised)
        """
        self.device = 0 if torch.cuda.is_available() else -1
        try:
            self.logger.info(
                f"Creating pipeline with model {model_name} on device {self.device}"
            )
            return pipeline(
                "text2text-generation",
                model_name,
                device=self.device,
                torch_dtype=torch.float32,
            )
        except Exception as e:
            self.logger.error(f"Failed to create pipeline: {e}")
            raise

    def _configure_model(self):
        """
        Configure the model for generation.

        Tries to wrap the model with ``torch.compile`` (best effort; failure is
        only logged), then installs a generation config: loaded from the hub
        for models in ``CONFIGURED_MODELS``, otherwise the generic defaults.
        """
        # Reset first so a re-init whose compile fails doesn't keep a stale True.
        self.is_compiled = False
        try:
            self.aggregator.model = torch.compile(self.aggregator.model)
            self.is_compiled = True
        except Exception as e:
            self.logger.warning(f"Could not compile model with Torch 2.0: {e}")
        if self.model_name not in self.CONFIGURED_MODELS:
            self.logger.info("Setting generation config to general defaults")
            self._set_default_generation_config()
        else:
            try:
                self.logger.info("Loading generation config from hub")
                self.aggregator.model.generation_config = (
                    GenerationConfig.from_pretrained(self.model_name)
                )
            except Exception as e:
                self.logger.warning(
                    f"Could not load generation config, using defaults: {e}"
                )
                self._set_default_generation_config()
        self.logger.info(self.aggregator.model.generation_config.to_json_string())

    def _set_default_generation_config(self):
        """
        Set the default generation configuration for the model.

        Starts from a private copy of ``GENERIC_CONFIG``; BART models get extra
        overrides, and larger BART checkpoints are reduced to 4 beams.
        """
        import copy  # local: only needed to clone the class-level template

        # Deep copy: the updates below must not mutate the shared class attribute.
        self.aggregator.model.generation_config = copy.deepcopy(self.GENERIC_CONFIG)
        name = self.model_name.lower()
        if "bart" in name:
            self.logger.info("Using BART model, updating generation config")
            upd = {
                "num_beams": 8,
                "repetition_penalty": 1.3,
                "length_penalty": 1.0,
                "_from_model_config": False,
                "max_new_tokens": 256,
                "min_new_tokens": 32,
                "no_repeat_ngram_size": 3,
                "encoder_no_repeat_ngram_size": 6,
            }  # TODO: clean up
            self.update_generation_config(**upd)
            # Each substring must be tested separately: the former
            # `"large" or "xl" in ...` was always truthy. compare_model_size
            # (from utils) is presumably a "bigger than ~500M params" check — verify.
            if (
                "large" in name
                or "xl" in name
                or compare_model_size(self.model_name, 500)
            ):
                upd = {"num_beams": 4}
                self.update_generation_config(**upd)

    def update_generation_config(self, **kwargs):
        """
        Update the generation configuration with the specified parameters.

        Keys that ``GenerationConfig.update`` does not recognise are not
        applied; they are reported with a warning instead of being dropped
        silently.

        Args:
            **kwargs: The parameters to update in the generation configuration.
        """
        self.logger.info(f"Updating generation config with {pp.pformat(kwargs)}")
        unused = self.aggregator.model.generation_config.update(**kwargs)
        if unused:
            self.logger.warning(
                f"Ignored unknown generation config keys: {pp.pformat(unused)}"
            )

    def update_loglevel(self, level: str = "INFO"):
        """
        Update the log level of this instance's logger.

        Args:
            level (str): The log level to set. Defaults to "INFO".
        """
        self.logger.setLevel(level)

    def infer_aggregate(
        self,
        text_list: list,
        instruction: str = DEFAULT_INSTRUCTION,
        **kwargs,
    ) -> str:
        """
        Generate a summary of the specified texts.

        The texts are joined with newlines and prefixed by ``instruction``.

        Args:
            text_list (list): The texts to summarize.
            instruction (str): The instruction for the summary. Defaults to
                ``BatchAggregator.DEFAULT_INSTRUCTION``.
            **kwargs: Generation-config parameters applied before inference.
                They are written into the model's generation config and
                therefore persist for later calls.
        Returns:
            The generated summary.
        """
        joined_text = "\n".join(text_list)
        prompt = f"{instruction}\n\n{joined_text}\n"
        if kwargs:
            self.update_generation_config(**kwargs)
        st = time.perf_counter()
        self.logger.info(f"inference on {len(text_list)} texts ...")
        # Pass the config explicitly rather than relying on the model attribute;
        # NOTE(review): presumably needed because a torch.compile wrapper may not
        # forward attribute assignment to the inner model — confirm.
        result = self.aggregator(
            prompt,
            generation_config=self.aggregator.model.generation_config,
        )[0]["generated_text"]
        self.logger.info(f"Done. runtime:\t{round(time.perf_counter() - st, 2)}s")
        self.logger.info(
            f"Input tokens:\t{self.count_tokens(prompt)}. Output tokens:\t{self.count_tokens(result)}"
        )
        return result

    def count_tokens(self, text: str) -> int:
        """Count the tokens in ``text`` with the pipeline's tokenizer (0 for empty/None)."""
        if not text:
            return 0
        return len(
            self.aggregator.tokenizer.encode(text, truncation=False, padding=False)
        )