pszemraj commited on
Commit
2b1a300
·
1 Parent(s): c2d711c

✨ add aggregator

Browse files

Signed-off-by: peter szemraj <[email protected]>

Files changed (1) hide show
  1. aggregate.py +118 -0
aggregate.py ADDED
@@ -0,0 +1,118 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # imports
2
+ import logging
3
+ import time
4
+
5
+ import torch
6
+ from transformers import GenerationConfig, pipeline
7
+
8
+ # Setting up logging
9
+ logging.basicConfig(
10
+ level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
11
+ )
12
+
13
+
14
+ class BatchAggregator:
15
+ def __init__(
16
+ self, model_name: str = "pszemraj/bart-large-mnli-dolly_hhrlhf-v1", **kwargs
17
+ ):
18
+ self.logger = logging.getLogger(__name__)
19
+ self.model_name = model_name
20
+ self.logger.info(f"Initializing aggregator with model {model_name}")
21
+ self.aggregator = pipeline(
22
+ "text2text-generation",
23
+ model_name,
24
+ device=0 if torch.cuda.is_available() else -1,
25
+ torch_dtype=torch.float32,
26
+ )
27
+
28
+ try:
29
+ self.aggregator.model = torch.compile(self.aggregator.model)
30
+ except Exception as e:
31
+ self.logger.warning(f"Could not compile model with Torch 2.0: {e}")
32
+ try:
33
+ self.aggregator.model.generation_config = GenerationConfig.from_pretrained(
34
+ self.model_name
35
+ )
36
+ except Exception as e:
37
+ self.logger.warning(
38
+ f"Could not load generation config, using defaults: {e}"
39
+ )
40
+ self.aggregator.model.generation_config = GenerationConfig(
41
+ num_beams=4,
42
+ early_stopping=True,
43
+ do_sample=False,
44
+ min_new_tokens=32,
45
+ max_new_tokens=192,
46
+ repetition_penalty=1.1,
47
+ length_penalty=1.5,
48
+ no_repeat_ngram_size=4,
49
+ encoder_no_repeat_ngram_size=5,
50
+ decoder_start_token_id=0,
51
+ eos_token_id=1,
52
+ pad_token_id=0,
53
+ )
54
+
55
+ if "bart" in model_name.lower():
56
+ self.logger.info("Using BART model, updating generation config")
57
+ upd = {
58
+ "num_beams": 8,
59
+ "repetition_penalty": 1.3,
60
+ "length_penalty": 1.0,
61
+ "_from_model_config": False,
62
+ "max_new_tokens": 256,
63
+ "min_new_tokens": 32,
64
+ "no_repeat_ngram_size": 3,
65
+ "encoder_no_repeat_ngram_size": 6,
66
+ }
67
+ self.aggregator.model.generation_config.update(**upd)
68
+ if self.model_name != "pszemraj/bart-large-mnli-dolly_hhrlhf-v1":
69
+ self.logger.info("Updating generation config with defaults")
70
+ self.update_generation_config()
71
+ self.logger.info(self.aggregator.model.generation_config.to_json_string())
72
+
73
+ def update_generation_config(self, **kwargs):
74
+ self.logger.info(f"Updating generation config with {kwargs}")
75
+ default = GenerationConfig(
76
+ num_beams=4,
77
+ early_stopping=True,
78
+ do_sample=False,
79
+ min_new_tokens=32,
80
+ max_new_tokens=192,
81
+ repetition_penalty=1.1,
82
+ length_penalty=1.5,
83
+ no_repeat_ngram_size=4,
84
+ encoder_no_repeat_ngram_size=5,
85
+ decoder_start_token_id=0,
86
+ eos_token_id=1,
87
+ pad_token_id=0,
88
+ ).to_dict()
89
+ self.aggregator.model.generation_config.update(**default)
90
+ def _replace_pipeline(model_name)
91
+ def infer_aggregate(
92
+ self,
93
+ text_list: list,
94
+ instruction: str = "Write a comprehensive yet concise summary in paragraph form that pulls together the main points of the following text:",
95
+ **kwargs,
96
+ ):
97
+ joined_text = "\n".join(text_list)
98
+ prompt = f"{instruction}\n\n{joined_text}\n"
99
+ if kwargs:
100
+ self.update_generation_config(**kwargs)
101
+ st = time.perf_counter()
102
+ self.logger.info(f"Running inference on {len(text_list)} texts")
103
+ result = self.aggregator(
104
+ prompt,
105
+ generation_config=self.aggregator.model.generation_config,
106
+ )[0]["generated_text"]
107
+ self.logger.info(f"Done. runtime:\t{round(time.perf_counter() - st, 2)}s")
108
+ self.logger.info(
109
+ f"Input tokens:\t{self.count_tokens(prompt)}. Output tokens:\t{self.count_tokens(result)}"
110
+ )
111
+ return result
112
+
113
+ def count_tokens(self, text: str):
114
+ return (
115
+ len(self.aggregator.tokenizer.encode(text, truncation=False, padding=False))
116
+ if text
117
+ else 0
118
+ )