# imports
import logging
import time

import torch
from transformers import GenerationConfig, pipeline

# Set up logging
logging.basicConfig(
    level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
)


class BatchAggregator:
    """Aggregate a batch of summaries into a single text with a text2text-generation model."""

    def __init__(
        self, model_name: str = "pszemraj/bart-large-mnli-dolly_hhrlhf-v1", **kwargs
    ):
        self.logger = logging.getLogger(__name__)
        self.model_name = model_name
        self.logger.info(f"Initializing aggregator with model {model_name}")
        self.aggregator = pipeline(
            "text2text-generation",
            model_name,
            device=0 if torch.cuda.is_available() else -1,
            torch_dtype=torch.float32,
        )
        # torch.compile (PyTorch 2.0+) is optional; fall back to eager mode on failure
        try:
            self.aggregator.model = torch.compile(self.aggregator.model)
        except Exception as e:
            self.logger.warning(f"Could not compile model with Torch 2.0: {e}")
        # Prefer the model's saved generation config; fall back to sensible defaults
        try:
            self.aggregator.model.generation_config = GenerationConfig.from_pretrained(
                self.model_name
            )
        except Exception as e:
            self.logger.warning(
                f"Could not load generation config, using defaults: {e}"
            )
            self.aggregator.model.generation_config = GenerationConfig(
                num_beams=4,
                early_stopping=True,
                do_sample=False,
                min_new_tokens=32,
                max_new_tokens=192,
                repetition_penalty=1.1,
                length_penalty=1.5,
                no_repeat_ngram_size=4,
                encoder_no_repeat_ngram_size=5,
                decoder_start_token_id=0,
                eos_token_id=1,
                pad_token_id=0,
            )
if "bart" in model_name.lower():
self.logger.info("Using BART model, updating generation config")
upd = {
"num_beams": 8,
"repetition_penalty": 1.3,
"length_penalty": 1.0,
"_from_model_config": False,
"max_new_tokens": 256,
"min_new_tokens": 32,
"no_repeat_ngram_size": 3,
"encoder_no_repeat_ngram_size": 6,
}
self.aggregator.model.generation_config.update(**upd)
if self.model_name != "pszemraj/bart-large-mnli-dolly_hhrlhf-v1":
self.logger.info("Updating generation config with defaults")
self.update_generation_config()
self.logger.info(self.aggregator.model.generation_config.to_json_string())

    def update_generation_config(self, **kwargs):
        """Reset generation settings to defaults, then apply any overrides in kwargs."""
        self.logger.info(f"Updating generation config with {kwargs}")
        default = GenerationConfig(
            num_beams=4,
            early_stopping=True,
            do_sample=False,
            min_new_tokens=32,
            max_new_tokens=192,
            repetition_penalty=1.1,
            length_penalty=1.5,
            no_repeat_ngram_size=4,
            encoder_no_repeat_ngram_size=5,
            decoder_start_token_id=0,
            eos_token_id=1,
            pad_token_id=0,
        ).to_dict()
        default.update(kwargs)  # caller overrides win over the defaults
        self.aggregator.model.generation_config.update(**default)

    def _replace_pipeline(self, model_name: str):
        """Placeholder for swapping the pipeline to a different model (not implemented)."""
        raise NotImplementedError

    def infer_aggregate(
        self,
        text_list: list,
        instruction: str = "Write a comprehensive yet concise summary in paragraph form that pulls together the main points of the following text:",
        **kwargs,
    ):
        """Join the input texts under the instruction prompt and generate one summary."""
        joined_text = "\n".join(text_list)
        prompt = f"{instruction}\n\n{joined_text}\n"
        if kwargs:
            self.update_generation_config(**kwargs)
        st = time.perf_counter()
        self.logger.info(f"Running inference on {len(text_list)} texts")
        result = self.aggregator(
            prompt,
            generation_config=self.aggregator.model.generation_config,
        )[0]["generated_text"]
        self.logger.info(f"Done. Runtime:\t{round(time.perf_counter() - st, 2)}s")
        self.logger.info(
            f"Input tokens:\t{self.count_tokens(prompt)}. Output tokens:\t{self.count_tokens(result)}"
        )
        return result

    def count_tokens(self, text: str) -> int:
        """Count tokens in text with the pipeline's tokenizer (0 for empty input)."""
        return (
            len(self.aggregator.tokenizer.encode(text, truncation=False, padding=False))
            if text
            else 0
        )
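

# Minimal usage sketch (assumes the default model can be downloaded; the input
# texts below are placeholders, not from the original project):
if __name__ == "__main__":
    aggregator = BatchAggregator()
    chunk_summaries = [
        "The report's first section outlines the project goals and timeline.",
        "The second section details the budget and identifies the main risks.",
    ]
    # kwargs are forwarded to update_generation_config as GenerationConfig overrides
    print(aggregator.infer_aggregate(chunk_summaries, max_new_tokens=128))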