pszemraj commited on
Commit
e414859
·
1 Parent(s): 2d980d5

✨ mwe working aggregation

Browse files

Signed-off-by: peter szemraj <[email protected]>

Files changed (2) hide show
  1. aggregate.py +158 -67
  2. app.py +91 -10
aggregate.py CHANGED
@@ -1,10 +1,12 @@
1
- # imports
2
  import logging
3
  import time
4
 
5
  import torch
6
  from transformers import GenerationConfig, pipeline
7
 
 
 
8
  # Setting up logging
9
  logging.basicConfig(
10
  level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
@@ -12,94 +14,182 @@ logging.basicConfig(
12
 
13
 
14
  class BatchAggregator:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
15
  def __init__(
16
  self, model_name: str = "pszemraj/bart-large-mnli-dolly_hhrlhf-v1", **kwargs
17
  ):
 
 
18
  self.logger = logging.getLogger(__name__)
 
 
 
 
 
 
 
 
 
 
 
 
 
19
  self.model_name = model_name
20
- self.logger.info(f"Initializing aggregator with model {model_name}")
21
- self.aggregator = pipeline(
22
- "text2text-generation",
23
- model_name,
24
- device=0 if torch.cuda.is_available() else -1,
25
- torch_dtype=torch.float32,
26
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
27
 
 
 
 
28
  try:
29
- self.aggregator.model = torch.compile(self.aggregator.model)
 
 
 
 
 
 
 
 
30
  except Exception as e:
31
- self.logger.warning(f"Could not compile model with Torch 2.0: {e}")
 
 
 
 
 
 
32
  try:
33
- self.aggregator.model.generation_config = GenerationConfig.from_pretrained(
34
- self.model_name
35
- )
36
  except Exception as e:
37
- self.logger.warning(
38
- f"Could not load generation config, using defaults: {e}"
39
- )
40
- self.aggregator.model.generation_config = GenerationConfig(
41
- num_beams=4,
42
- early_stopping=True,
43
- do_sample=False,
44
- min_new_tokens=32,
45
- max_new_tokens=192,
46
- repetition_penalty=1.1,
47
- length_penalty=1.5,
48
- no_repeat_ngram_size=4,
49
- encoder_no_repeat_ngram_size=5,
50
- decoder_start_token_id=0,
51
- eos_token_id=1,
52
- pad_token_id=0,
53
- )
54
 
55
- if "bart" in model_name.lower():
56
- self.logger.info("Using BART model, updating generation config")
57
- upd = {
58
- "num_beams": 8,
59
- "repetition_penalty": 1.3,
60
- "length_penalty": 1.0,
61
- "_from_model_config": False,
62
- "max_new_tokens": 256,
63
- "min_new_tokens": 32,
64
- "no_repeat_ngram_size": 3,
65
- "encoder_no_repeat_ngram_size": 6,
66
- }
67
- self.aggregator.model.generation_config.update(**upd)
68
- if self.model_name != "pszemraj/bart-large-mnli-dolly_hhrlhf-v1":
69
- self.logger.info("Updating generation config with defaults")
70
- self.update_generation_config()
71
  self.logger.info(self.aggregator.model.generation_config.to_json_string())
72
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
73
  def update_generation_config(self, **kwargs):
74
- self.logger.info(f"Updating generation config with {kwargs}")
75
- default = GenerationConfig(
76
- num_beams=4,
77
- early_stopping=True,
78
- do_sample=False,
79
- min_new_tokens=32,
80
- max_new_tokens=192,
81
- repetition_penalty=1.1,
82
- length_penalty=1.5,
83
- no_repeat_ngram_size=4,
84
- encoder_no_repeat_ngram_size=5,
85
- decoder_start_token_id=0,
86
- eos_token_id=1,
87
- pad_token_id=0,
88
- ).to_dict()
89
- self.aggregator.model.generation_config.update(**default)
90
- def _replace_pipeline(model_name)
 
 
91
  def infer_aggregate(
92
  self,
93
  text_list: list,
94
- instruction: str = "Write a comprehensive yet concise summary in paragraph form that pulls together the main points of the following text:",
95
  **kwargs,
96
- ):
 
 
 
 
 
 
 
 
 
 
 
97
  joined_text = "\n".join(text_list)
98
  prompt = f"{instruction}\n\n{joined_text}\n"
99
  if kwargs:
100
  self.update_generation_config(**kwargs)
101
  st = time.perf_counter()
102
- self.logger.info(f"Running inference on {len(text_list)} texts")
103
  result = self.aggregator(
104
  prompt,
105
  generation_config=self.aggregator.model.generation_config,
@@ -110,7 +200,8 @@ class BatchAggregator:
110
  )
111
  return result
112
 
113
- def count_tokens(self, text: str):
 
114
  return (
115
  len(self.aggregator.tokenizer.encode(text, truncation=False, padding=False))
116
  if text
 
1
+ import pprint as pp
2
  import logging
3
  import time
4
 
5
  import torch
6
  from transformers import GenerationConfig, pipeline
7
 
8
+ from utils import compare_model_size
9
+
10
  # Setting up logging
11
  logging.basicConfig(
12
  level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
 
14
 
15
 
16
  class BatchAggregator:
17
+ CONFIGURED_MODELS = [
18
+ "pszemraj/bart-large-mnli-dolly_hhrlhf-v1"
19
+ ] # TODO: Add models here
20
+ DEFAULT_INSTRUCTION = "Write a comprehensive yet concise summary that pulls together the main points of the following text:"
21
+ GENERIC_CONFIG = GenerationConfig(
22
+ num_beams=8,
23
+ early_stopping=True,
24
+ do_sample=False,
25
+ min_new_tokens=32,
26
+ max_new_tokens=256,
27
+ repetition_penalty=1.1,
28
+ length_penalty=1.4,
29
+ no_repeat_ngram_size=4,
30
+ encoder_no_repeat_ngram_size=5,
31
+ )
32
+
33
  def __init__(
34
  self, model_name: str = "pszemraj/bart-large-mnli-dolly_hhrlhf-v1", **kwargs
35
  ):
36
+ self.device = None
37
+ self.is_compiled = False
38
  self.logger = logging.getLogger(__name__)
39
+ self.init_model(model_name)
40
+
41
+ def init_model(self, model_name: str) -> None:
42
+ """
43
+ Initialize the model.
44
+
45
+ :param model_name: The name of the model to use.
46
+ """
47
+ # Free up memory
48
+ if torch.cuda.is_available():
49
+ torch.cuda.empty_cache()
50
+
51
+ self.logger.info(f"Setting model to {model_name}")
52
  self.model_name = model_name
53
+ self.aggregator = self._create_pipeline(model_name)
54
+ self._configure_model()
55
+ # update the generation config with the specific tokenizer
56
+ tokenizer_params = {
57
+ "decoder_start_token_id": 0
58
+ if "t5" in model_name.lower()
59
+ else self.aggregator.tokenizer.eos_token_id,
60
+ "eos_token_id": 1
61
+ if "t5" in model_name.lower()
62
+ else self.aggregator.tokenizer.eos_token_id,
63
+ "pad_token_id": 0
64
+ if "t5" in model_name.lower()
65
+ else self.aggregator.tokenizer.pad_token_id,
66
+ }
67
+ self.update_generation_config(**tokenizer_params)
68
+
69
+ def _create_pipeline(
70
+ self, model_name: str = "pszemraj/bart-large-mnli-dolly_hhrlhf-v1"
71
+ ) -> pipeline:
72
+ """
73
+ _create_pipeline creates a pipeline for the model.
74
+
75
+ :param str model_name: model name to use, default: "pszemraj/bart-large-mnli-dolly_hhrlhf-v1"
76
+ :return pipeline: the pipeline for the model
77
 
78
+ :raises Exception: if the pipeline cannot be created
79
+ """
80
+ self.device = 0 if torch.cuda.is_available() else -1
81
  try:
82
+ self.logger.info(
83
+ f"Creating pipeline with model {model_name} on device {self.device}"
84
+ )
85
+ return pipeline(
86
+ "text2text-generation",
87
+ model_name,
88
+ device=self.device,
89
+ torch_dtype=torch.float32,
90
+ )
91
  except Exception as e:
92
+ self.logger.error(f"Failed to create pipeline: {e}")
93
+ raise
94
+
95
+ def _configure_model(self):
96
+ """
97
+ Configure the model for generation.
98
+ """
99
  try:
100
+ self.aggregator.model = torch.compile(self.aggregator.model)
101
+ self.is_compiled = True
 
102
  except Exception as e:
103
+ self.logger.warning(f"Could not compile model with Torch 2.0: {e}")
104
+
105
+ if self.model_name not in self.CONFIGURED_MODELS:
106
+ self.logger.info("Setting generation config to general defaults")
107
+ self._set_default_generation_config()
108
+ else:
109
+ try:
110
+ self.logger.info("Loading generation config from hub")
111
+ self.aggregator.model.generation_config = (
112
+ GenerationConfig.from_pretrained(self.model_name)
113
+ )
114
+ except Exception as e:
115
+ self.logger.warning(
116
+ f"Could not load generation config, using defaults: {e}"
117
+ )
118
+ self._set_default_generation_config()
 
119
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
120
  self.logger.info(self.aggregator.model.generation_config.to_json_string())
121
 
122
+ def _set_default_generation_config(self):
123
+ """
124
+ Set the default generation configuration for the model.
125
+ """
126
+ self.aggregator.model.generation_config = self.GENERIC_CONFIG
127
+
128
+ if "bart" in self.model_name.lower():
129
+ self.logger.info("Using BART model, updating generation config")
130
+ upd = {
131
+ "num_beams": 8,
132
+ "repetition_penalty": 1.3,
133
+ "length_penalty": 1.0,
134
+ "_from_model_config": False,
135
+ "max_new_tokens": 256,
136
+ "min_new_tokens": 32,
137
+ "no_repeat_ngram_size": 3,
138
+ "encoder_no_repeat_ngram_size": 6,
139
+ } # TODO: clean up
140
+ self.aggregator.model.generation_config.update(**upd)
141
+
142
+ if (
143
+ "large"
144
+ or "xl" in self.model_name.lower()
145
+ or compare_model_size(self.model_name, 500)
146
+ ):
147
+ upd = {"num_beams": 4}
148
+ self.update_generation_config(**upd)
149
+
150
  def update_generation_config(self, **kwargs):
151
+ """
152
+ Update the generation configuration with the specified parameters.
153
+
154
+ Args:
155
+ **kwargs: The parameters to update in the generation configuration.
156
+ """
157
+ self.logger.info(f"Updating generation config with {pp.pformat(kwargs)}")
158
+
159
+ self.aggregator.model.generation_config.update(**kwargs)
160
+
161
+ def update_loglevel(self, level: str = "INFO"):
162
+ """
163
+ Update the log level.
164
+
165
+ Args:
166
+ level (str): The log level to set. Defaults to "INFO".
167
+ """
168
+ self.logger.setLevel(level)
169
+
170
  def infer_aggregate(
171
  self,
172
  text_list: list,
173
+ instruction: str = DEFAULT_INSTRUCTION,
174
  **kwargs,
175
+ ) -> str:
176
+ f"""
177
+ Generate a summary of the specified texts.
178
+
179
+ Args:
180
+ text_list (list): The texts to summarize.
181
+ instruction (str): The instruction for the summary. Defaults to {self.DEFAULT_INSTRUCTION}.
182
+ **kwargs: Additional parameters to update in the generation configuration.
183
+
184
+ Returns:
185
+ The generated summary.
186
+ """
187
  joined_text = "\n".join(text_list)
188
  prompt = f"{instruction}\n\n{joined_text}\n"
189
  if kwargs:
190
  self.update_generation_config(**kwargs)
191
  st = time.perf_counter()
192
+ self.logger.info(f"inference on {len(text_list)} texts ...")
193
  result = self.aggregator(
194
  prompt,
195
  generation_config=self.aggregator.model.generation_config,
 
200
  )
201
  return result
202
 
203
+ def count_tokens(self, text: str) -> int:
204
+ """count the number of tokens in a text"""
205
  return (
206
  len(self.aggregator.tokenizer.encode(text, truncation=False, padding=False))
207
  if text
app.py CHANGED
@@ -1,5 +1,5 @@
1
  """
2
- app.py - the main module for the gradio app
3
 
4
  Usage:
5
  python app.py
@@ -19,6 +19,7 @@ import random
19
  import re
20
  import time
21
  from pathlib import Path
 
22
 
23
  os.environ["USE_TORCH"] = "1"
24
  os.environ["TOKENIZERS_PARALLELISM"] = "false"
@@ -31,16 +32,18 @@ logging.basicConfig(
31
  import gradio as gr
32
  import nltk
33
  import torch
 
34
  from cleantext import clean
35
  from doctr.models import ocr_predictor
36
-
37
  from pdf2text import convert_PDF_to_Text
38
  from summarize import load_model_and_tokenizer, summarize_via_tokenbatches
39
  from utils import (
 
40
  load_example_filenames,
41
  saves_summary,
42
  textlist2html,
43
  truncate_word_count,
 
44
  )
45
 
46
  _here = Path(__file__).parent
@@ -57,10 +60,76 @@ MODEL_OPTIONS = [
57
  "pszemraj/pegasus-x-large-book-summary",
58
  ] # models users can choose from
59
 
 
 
60
  # if duplicating space,, uncomment this line to adjust the max words
61
  # os.environ["APP_MAX_WORDS"] = str(2048) # set the max words to 2048
62
  # os.environ["APP_OCR_MAX_PAGES"] = str(40) # set the max pages to 40
63
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
64
 
65
  def predict(
66
  input_text: str,
@@ -128,6 +197,7 @@ def proc_submission(
128
  str in HTML format, string of the summary, str of score
129
  """
130
 
 
131
  settings = {
132
  "length_penalty": float(length_penalty),
133
  "repetition_penalty": float(repetition_penalty),
@@ -208,7 +278,6 @@ def proc_submission(
208
  # save to file
209
  settings["model_name"] = model_name
210
  saved_file = saves_summary(summarize_output=_summaries, outpath=None, **settings)
211
-
212
  return html, full_summary, scores_out, saved_file
213
 
214
 
@@ -361,7 +430,7 @@ if __name__ == "__main__":
361
  summarize_button = gr.Button(
362
  "Summarize!",
363
  variant="primary",
364
- )
365
  output_text = gr.HTML("<p><em>Output will appear below:</em></p>")
366
  with gr.Column():
367
  gr.Markdown("#### Results & Scores")
@@ -384,11 +453,19 @@ if __name__ == "__main__":
384
  label="Summary Scores",
385
  placeholder="Summary scores will appear here",
386
  )
 
 
 
 
 
 
 
 
 
 
 
 
387
 
388
- gr.Markdown("#### **Summary Output**")
389
- summary_text = gr.HTML(
390
- label="Summary", value="<i>Summary will appear here!</i>"
391
- )
392
  gr.Markdown("---")
393
  with gr.Column():
394
  gr.Markdown("### Advanced Settings")
@@ -456,5 +533,9 @@ if __name__ == "__main__":
456
  ],
457
  outputs=[output_text, summary_text, summary_scores, text_file],
458
  )
459
-
460
- demo.launch(enable_queue=True)
 
 
 
 
 
1
  """
2
+ app.py - the main module for the gradio app for summarization
3
 
4
  Usage:
5
  python app.py
 
19
  import re
20
  import time
21
  from pathlib import Path
22
+ import pprint as pp
23
 
24
  os.environ["USE_TORCH"] = "1"
25
  os.environ["TOKENIZERS_PARALLELISM"] = "false"
 
32
  import gradio as gr
33
  import nltk
34
  import torch
35
+ from aggregate import BatchAggregator
36
  from cleantext import clean
37
  from doctr.models import ocr_predictor
 
38
  from pdf2text import convert_PDF_to_Text
39
  from summarize import load_model_and_tokenizer, summarize_via_tokenbatches
40
  from utils import (
41
+ extract_batches,
42
  load_example_filenames,
43
  saves_summary,
44
  textlist2html,
45
  truncate_word_count,
46
+ remove_stagnant_files,
47
  )
48
 
49
  _here = Path(__file__).parent
 
60
  "pszemraj/pegasus-x-large-book-summary",
61
  ] # models users can choose from
62
 
63
+ SUMMARY_PLACEHOLDER = "<p><em>Output will appear below:</em></p>"
64
+
65
  # if duplicating space,, uncomment this line to adjust the max words
66
  # os.environ["APP_MAX_WORDS"] = str(2048) # set the max words to 2048
67
  # os.environ["APP_OCR_MAX_PAGES"] = str(40) # set the max pages to 40
68
 
69
+ aggregator = BatchAggregator("MBZUAI/LaMini-Flan-T5-783M")
70
+
71
+
72
+ def aggregate_text(
73
+ summary_text: str,
74
+ text_file: gr.inputs.File = None,
75
+ ):
76
+ """
77
+ Aggregate the text from the batches.
78
+
79
+ NOTE: you should probably include passing the BatchAggregator object as a parameter if using this code
80
+ outside of this file.
81
+ :param batches_html: The batches to aggregate, in html format
82
+ """
83
+ if summary_text is None or summary_text == SUMMARY_PLACEHOLDER:
84
+ logging.error("No text provided. Make sure a summary has been generated first.")
85
+ return "Error: No text provided. Make sure a summary has been generated first."
86
+
87
+ try:
88
+ extracted_batches = extract_batches(summary_text)
89
+ except Exception as e:
90
+ logging.info(summary_text)
91
+ logging.info(f"the batches html is: {type(summary_text)}")
92
+ return f"Error: unable to extract batches - check input: {e}"
93
+ if not extracted_batches:
94
+ logging.error("unable to extract batches - check input")
95
+ return "Error: unable to extract batches - check input"
96
+
97
+ out_path = None
98
+ if text_file is not None:
99
+ out_path = text_file.name # assuming name attribute stores the file path
100
+
101
+ content_batches = [batch["content"] for batch in extracted_batches]
102
+ full_summary = aggregator.infer_aggregate(content_batches)
103
+
104
+ # if a path that exists is provided, save the summary with markdown formatting
105
+ if out_path:
106
+ out_path = Path(out_path)
107
+
108
+ try:
109
+ with open(out_path, "a", encoding="utf-8") as f:
110
+ f.write("\n\n### Aggregate Summary\n\n")
111
+ f.write(
112
+ "- This is an instruction-based LLM aggregation of the previous 'summary batches'.\n"
113
+ )
114
+ f.write(f"- Aggregation model: {aggregator.model_name}\n\n")
115
+ f.write(f"{full_summary}\n\n")
116
+ logging.info(f"Updated {out_path} with aggregate summary")
117
+ except Exception as e:
118
+ logging.error(f"unable to update {out_path} with aggregate summary: {e}")
119
+
120
+ full_summary_html = f"""
121
+ <div style="
122
+ margin-bottom: 20px;
123
+ font-size: 18px;
124
+ line-height: 1.5em;
125
+ color: #333;
126
+ ">
127
+ <h2 style="font-size: 22px; color: #555;">Aggregate Summary:</h2>
128
+ <p style="white-space: pre-line;">{full_summary}</p>
129
+ </div>
130
+ """
131
+ return full_summary_html
132
+
133
 
134
  def predict(
135
  input_text: str,
 
197
  str in HTML format, string of the summary, str of score
198
  """
199
 
200
+ remove_stagnant_files() # clean up old files
201
  settings = {
202
  "length_penalty": float(length_penalty),
203
  "repetition_penalty": float(repetition_penalty),
 
278
  # save to file
279
  settings["model_name"] = model_name
280
  saved_file = saves_summary(summarize_output=_summaries, outpath=None, **settings)
 
281
  return html, full_summary, scores_out, saved_file
282
 
283
 
 
430
  summarize_button = gr.Button(
431
  "Summarize!",
432
  variant="primary",
433
+ ) # TODO: collapse button to be on same line as something else
434
  output_text = gr.HTML("<p><em>Output will appear below:</em></p>")
435
  with gr.Column():
436
  gr.Markdown("#### Results & Scores")
 
453
  label="Summary Scores",
454
  placeholder="Summary scores will appear here",
455
  )
456
+ with gr.Column():
457
+ gr.Markdown("#### **Summary Output**")
458
+ summary_text = gr.HTML(
459
+ label="Summary", value="<i>Summary will appear here!</i>"
460
+ )
461
+ with gr.Column():
462
+ gr.Markdown("##### **Aggregate Summary Batches**")
463
+ aggregate_button = gr.Button(
464
+ "Aggregate!",
465
+ variant="primary",
466
+ ) # TODO: collapse button to be on same line as something else
467
+ aggregated_summary = gr.HTML(label="Aggregate Summary", value="")
468
 
 
 
 
 
469
  gr.Markdown("---")
470
  with gr.Column():
471
  gr.Markdown("### Advanced Settings")
 
533
  ],
534
  outputs=[output_text, summary_text, summary_scores, text_file],
535
  )
536
+ aggregate_button.click(
537
+ fn=aggregate_text,
538
+ inputs=[summary_text, text_file],
539
+ outputs=[aggregated_summary],
540
+ )
541
+ demo.launch(enable_queue=True, share=True)