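"""Aggregate batches of text into a single summary using an instruction-tuned
text2text-generation model (default: pszemraj/bart-large-mnli-dolly_hhrlhf-v1)."""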
# imports
import logging
import time

import torch
from transformers import GenerationConfig, pipeline

# Setting up logging
logging.basicConfig(
    level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
)


class BatchAggregator:
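    """Aggregates multiple chunks of text into one coherent summary via a
    Hugging Face text2text-generation pipeline."""
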
    def __init__(
        self, model_name: str = "pszemraj/bart-large-mnli-dolly_hhrlhf-v1", **kwargs
    ):
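        """
        :param model_name: Hugging Face model loaded into the pipeline.
        :param kwargs: extra keyword arguments forwarded to `transformers.pipeline`.
        """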
        self.logger = logging.getLogger(__name__)
        self.model_name = model_name
        self.logger.info(f"Initializing aggregator with model {model_name}")
        self.aggregator = pipeline(
            "text2text-generation",
            model_name,
            device=0 if torch.cuda.is_available() else -1,
            torch_dtype=torch.float32,
            **kwargs,  # forward any extra pipeline arguments (e.g. batch_size)
        )

        try:
            self.aggregator.model = torch.compile(self.aggregator.model)
        except Exception as e:
            self.logger.warning(f"Could not compile model with Torch 2.0: {e}")
        try:
            self.aggregator.model.generation_config = GenerationConfig.from_pretrained(
                self.model_name
            )
        except Exception as e:
            self.logger.warning(
                f"Could not load generation config, using defaults: {e}"
            )
            self.aggregator.model.generation_config = GenerationConfig(
                num_beams=4,
                early_stopping=True,
                do_sample=False,
                min_new_tokens=32,
                max_new_tokens=192,
                repetition_penalty=1.1,
                length_penalty=1.5,
                no_repeat_ngram_size=4,
                encoder_no_repeat_ngram_size=5,
                decoder_start_token_id=0,
                eos_token_id=1,
                pad_token_id=0,
            )

            if "bart" in model_name.lower():
                self.logger.info("Using BART model, updating generation config")
                upd = {
                    "num_beams": 8,
                    "repetition_penalty": 1.3,
                    "length_penalty": 1.0,
                    "_from_model_config": False,
                    "max_new_tokens": 256,
                    "min_new_tokens": 32,
                    "no_repeat_ngram_size": 3,
                    "encoder_no_repeat_ngram_size": 6,
                }
                self.aggregator.model.generation_config.update(**upd)
        if self.model_name != "pszemraj/bart-large-mnli-dolly_hhrlhf-v1":
            self.logger.info("Updating generation config with defaults")
            self.update_generation_config()
        self.logger.info(self.aggregator.model.generation_config.to_json_string())

    def update_generation_config(self, **kwargs):
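        """Reset generation settings to the class defaults, then apply any
        keyword overrides supplied by the caller."""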
        self.logger.info(f"Updating generation config with {kwargs}")
        default = GenerationConfig(
            num_beams=4,
            early_stopping=True,
            do_sample=False,
            min_new_tokens=32,
            max_new_tokens=192,
            repetition_penalty=1.1,
            length_penalty=1.5,
            no_repeat_ngram_size=4,
            encoder_no_repeat_ngram_size=5,
            decoder_start_token_id=0,
            eos_token_id=1,
            pad_token_id=0,
        ).to_dict()
        default.update(kwargs)  # caller-supplied overrides take precedence
        self.aggregator.model.generation_config.update(**default)

    def _replace_pipeline(self, model_name: str) -> None:
        """Swap the underlying pipeline for a different model.

        Note: the body was missing in the original source; this is a minimal
        reconstruction that rebuilds the pipeline around the new model.
        """
        self.model_name = model_name
        self.logger.info(f"Replacing pipeline with model {model_name}")
        self.aggregator = pipeline(
            "text2text-generation",
            model_name,
            device=0 if torch.cuda.is_available() else -1,
        )

    def infer_aggregate(
        self,
        text_list: list,
        instruction: str = "Write a comprehensive yet concise summary in paragraph form that pulls together the main points of the following text:",
        **kwargs,
    ):
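        """
        Aggregate a list of text chunks into a single generated summary.

        :param text_list: chunks to aggregate; joined with newlines.
        :param instruction: prompt prepended to the joined text.
        :param kwargs: generation overrides applied via update_generation_config.
        :return: the generated text.
        """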
        joined_text = "\n".join(text_list)
        prompt = f"{instruction}\n\n{joined_text}\n"
        if kwargs:
            self.update_generation_config(**kwargs)
        st = time.perf_counter()
        self.logger.info(f"Running inference on {len(text_list)} texts")
        result = self.aggregator(
            prompt,
            generation_config=self.aggregator.model.generation_config,
        )[0]["generated_text"]
        self.logger.info(f"Done. runtime:\t{round(time.perf_counter() - st, 2)}s")
        self.logger.info(
            f"Input tokens:\t{self.count_tokens(prompt)}. Output tokens:\t{self.count_tokens(result)}"
        )
        return result

    def count_tokens(self, text: str):
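        """Return the number of tokens in `text` per the pipeline tokenizer (0 if empty)."""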
        return (
            len(self.aggregator.tokenizer.encode(text, truncation=False, padding=False))
            if text
            else 0
        )
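

if __name__ == "__main__":
    # Minimal usage sketch (illustrative, not from the original file):
    # aggregate two short chunks into one summary with the default model.
    agg = BatchAggregator()
    chunks = [
        "The first chapter introduces the dataset and the cleaning pipeline.",
        "The second chapter benchmarks three baseline models on that data.",
    ]
    print(agg.infer_aggregate(chunks))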