# Spaces: Runtime error
# imports | |
import logging | |
import time | |
import torch | |
from transformers import GenerationConfig, pipeline | |
# Configure the root logger at import time: INFO and above, each record
# prefixed with its timestamp and level name.
logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(message)s",
    level=logging.INFO,
)
class BatchAggregator:
    """Aggregate (summarize) a batch of texts with a text2text-generation pipeline.

    The texts are joined into one prompt behind an instruction and run through
    a ``transformers`` ``text2text-generation`` pipeline built for ``model_name``.
    """

    # Model used when none is given. ``__init__`` also compares against it to
    # decide whether the class-wide generation defaults are applied.
    DEFAULT_MODEL = "pszemraj/bart-large-mnli-dolly_hhrlhf-v1"

    def __init__(self, model_name: str = DEFAULT_MODEL, **kwargs):
        """Build the pipeline for ``model_name`` and set up its generation config.

        Args:
            model_name: Model identifier passed to ``transformers.pipeline``.
            **kwargs: Accepted for call-site compatibility; currently unused.
        """
        self.logger = logging.getLogger(__name__)
        self.model_name = model_name
        self.logger.info("Initializing aggregator with model %s", model_name)
        self._replace_pipeline(model_name)
        self._configure_generation()
        self.logger.info(self.aggregator.model.generation_config.to_json_string())

    def _replace_pipeline(self, model_name: str) -> None:
        """Build a ``text2text-generation`` pipeline for ``model_name`` and install it.

        Sets ``self.model_name`` and ``self.aggregator``. Runs on CUDA device 0
        when available (otherwise ``device=-1``) with float32 weights, then
        tries to wrap the model with ``torch.compile`` on a best-effort basis.
        """
        self.model_name = model_name
        self.aggregator = pipeline(
            "text2text-generation",
            model_name,
            device=0 if torch.cuda.is_available() else -1,
            torch_dtype=torch.float32,
        )
        try:
            self.aggregator.model = torch.compile(self.aggregator.model)
        except Exception as e:
            # Best effort: e.g. torch < 2.0 has no ``torch.compile``. Compilation
            # is lazy, so some failures may only appear at first inference.
            self.logger.warning("Could not compile model with Torch 2.0: %s", e)

    @staticmethod
    def _default_generation_config():
        """Return this class's fallback/default ``GenerationConfig``."""
        # NOTE(review): the special-token ids below (decoder_start=0, eos=1,
        # pad=0) are hard-coded; confirm they match the tokenizer of every
        # model these defaults are applied to.
        return GenerationConfig(
            num_beams=4,
            early_stopping=True,
            do_sample=False,
            min_new_tokens=32,
            max_new_tokens=192,
            repetition_penalty=1.1,
            length_penalty=1.5,
            no_repeat_ngram_size=4,
            encoder_no_repeat_ngram_size=5,
            decoder_start_token_id=0,
            eos_token_id=1,
            pad_token_id=0,
        )

    def _configure_generation(self) -> None:
        """Load the model's generation config, then apply this class's overrides."""
        try:
            self.aggregator.model.generation_config = GenerationConfig.from_pretrained(
                self.model_name
            )
        except Exception as e:
            self.logger.warning(
                "Could not load generation config, using defaults: %s", e
            )
            self.aggregator.model.generation_config = self._default_generation_config()

        if "bart" in self.model_name.lower():
            self.logger.info("Using BART model, updating generation config")
            bart_overrides = {
                "num_beams": 8,
                "repetition_penalty": 1.3,
                "length_penalty": 1.0,
                "_from_model_config": False,
                "max_new_tokens": 256,
                "min_new_tokens": 32,
                "no_repeat_ngram_size": 3,
                "encoder_no_repeat_ngram_size": 6,
            }
            self.aggregator.model.generation_config.update(**bart_overrides)

        if self.model_name != self.DEFAULT_MODEL:
            # NOTE(review): the full default config is applied after the BART
            # overrides, so for any non-default model those overrides are
            # replaced. Confirm this ordering is intended.
            self.logger.info("Updating generation config with defaults")
            self.update_generation_config()

    def update_generation_config(self, **kwargs):
        """Update the pipeline's generation config in place.

        Args:
            **kwargs: ``GenerationConfig`` attributes to set, e.g.
                ``max_new_tokens=128``. When no arguments are given, the class
                defaults from ``_default_generation_config`` are applied instead.

        The change persists for all later calls to ``infer_aggregate``.
        Setting names the config does not recognise are logged as a warning.
        """
        self.logger.info("Updating generation config with %s", kwargs)
        settings = kwargs if kwargs else self._default_generation_config().to_dict()
        # GenerationConfig.update returns the key/value pairs it did not use.
        unused = self.aggregator.model.generation_config.update(**settings)
        if unused:
            self.logger.warning("Ignored unknown generation settings: %s", unused)

    def infer_aggregate(
        self,
        text_list: "list[str] | str",
        instruction: str = "Write a comprehensive yet concise summary in paragraph form that pulls together the main points of the following text:",
        **kwargs,
    ) -> str:
        """Generate one aggregated text (e.g. a summary) from ``text_list``.

        Args:
            text_list: Texts to aggregate; they are joined with newlines. A
                single string is treated as one text rather than being split
                into characters.
            instruction: Prompt prefix placed before the joined texts.
            **kwargs: Generation settings applied via ``update_generation_config``
                before running. They persist for later calls.

        Returns:
            The ``generated_text`` of the pipeline's first output.
        """
        if isinstance(text_list, str):
            text_list = [text_list]
        joined_text = "\n".join(text_list)
        prompt = f"{instruction}\n\n{joined_text}\n"
        if kwargs:
            self.update_generation_config(**kwargs)
        st = time.perf_counter()
        self.logger.info("Running inference on %d texts", len(text_list))
        # Pass the config explicitly so the pipeline uses exactly the object
        # this class has been updating.
        result = self.aggregator(
            prompt,
            generation_config=self.aggregator.model.generation_config,
        )[0]["generated_text"]
        self.logger.info("Done. runtime:\t%ss", round(time.perf_counter() - st, 2))
        self.logger.info(
            "Input tokens:\t%d. Output tokens:\t%d",
            self.count_tokens(prompt),
            self.count_tokens(result),
        )
        return result

    def count_tokens(self, text: str) -> int:
        """Return the length of the pipeline tokenizer's ``encode`` output for ``text``.

        Empty or falsy input returns 0 without calling the tokenizer. Truncation
        and padding are disabled so the count reflects the whole text. The count
        presumably includes any special tokens ``encode`` adds by default.
        """
        if not text:
            return 0
        return len(
            self.aggregator.tokenizer.encode(text, truncation=False, padding=False)
        )