trminhnam20082002 committed on
Commit
55e492d
·
1 Parent(s): 59ae732

feat: add model

Browse files
Files changed (3) hide show
  1. .gitignore +2 -0
  2. app.py +150 -0
  3. utils.py +177 -0
.gitignore ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ cache
2
+ **/__pycache__/
app.py CHANGED
@@ -0,0 +1,150 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+
3
+ import streamlit as st
4
+ import pandas as pd
5
+ import torch
6
+
7
+ from utils import (
8
+ load_model,
9
+ load_tokenizer,
10
+ make_input_sentence_from_strings,
11
+ generate_description,
12
+ )
13
+
14
def _difference_or_blank(current, reference):
    """Return ``float(current) - float(reference)``, or "" when that is impossible.

    Streamlit re-executes this script on every widget interaction, so the
    text boxes are often empty or only partly typed. Returning an empty
    string in that case leaves the comparison box blank instead of aborting
    the whole page with a ValueError.
    """
    try:
        return float(current) - float(reference)
    except (TypeError, ValueError):
        return ""


st.set_page_config(
    page_title="Table-to-text generation",
    page_icon="📝",
    layout="wide",
    initial_sidebar_state="auto",
    # Links exposed through the app menu.
    menu_items={
        "Get Help": "https://huggingface.co/transformers/master/index.html",
        "Report a bug": "https://github.com",
    },
)

st.title("Table-to-text generation with multilingual pre-trained models")
st.markdown(
    """
This is a demo of table-to-text generation with multilingual pre-trained models.
The models are trained on our custom dataset, which is sampling from Viettel Report Template and generated description by ChatGPT.
"""
)

# --- Sidebar: model and decoding settings -----------------------------------
st.sidebar.title("Settings")
model_name = st.sidebar.selectbox(
    "Model name",
    [
        "vinai/bartpho-syllable",
        "vinai/bartpho-syllable-base",
        "google/byt5-base",
        "google/byt5-small",
        "facebook/mbart-large-50",
    ],
)

if torch.cuda.is_available():
    device = "cuda" if st.sidebar.checkbox("Use GPU", False) else "cpu"
else:
    # Keep the checkbox visible but greyed out when no GPU is present.
    st.sidebar.checkbox("Use GPU", False, disabled=True)
    device = "cpu"
max_len = st.sidebar.slider("Max length", 32, 512, 256, 32)
beam_size = st.sidebar.slider("Beam size", 1, 10, 3, 1)
tokenizer = load_tokenizer(model_name)
model = load_model(model_name, device)

# --- Main form: one input per column of the report table --------------------
# Columns, in order: CHỈ TIÊU | ĐƠN VỊ | ĐIỀU KIỆN | KPI mục tiêu tháng |
# Tháng 9.2022 | Đánh giá | T8.2022 | So sánh T8.2022 Tăng giảm |
# T9.2021 | So sánh T9.2021 Tăng giảm

objective_name = st.text_input("CHỈ TIÊU", "")
(unit_col, condition_col, kpi_target_col) = st.columns(3)
with unit_col:
    unit = st.text_input("ĐƠN VỊ", "")
with condition_col:
    condition = st.selectbox("ĐIỀU KIỆN", [">=", "<=", None])
with kpi_target_col:
    kpi_target = st.text_input("KPI mục tiêu tháng", "")

current_date_col, real_value_col, evaluation_col = st.columns(3)
with current_date_col:
    # No explicit ``value``: the widget then starts on its default date.
    # Passing value=None makes recent Streamlit releases return None until
    # the user picks a date, which would break the year/month extraction.
    current_date = st.date_input("Thời gian báo cáo")
# current_time is [year, month] of the reporting period.
current_time = [current_date.year, current_date.month]
with real_value_col:
    real_value = st.text_input(f"T{current_time[1]}.{current_time[0]} thực tế", "")
with evaluation_col:
    evaluation_value = st.selectbox(
        "Đánh giá",
        ["Đạt", "Không đạt", "Theo dõi"],
        # Without a target or a condition nothing can be judged, so the
        # third option ("Theo dõi") is preselected.
        index=2 if (kpi_target == "" or condition is None) else 0,
    )

# [year, month] of the month before the report; January wraps to December
# of the previous year.
previous_month = (
    [current_time[0], current_time[1] - 1]
    if current_time[1] > 1
    else [current_time[0] - 1, 12]
)

# [year, month] of the same month one year earlier.
previous_year = [current_time[0] - 1, current_time[1]]

(
    previous_month_value_col,
    previous_month_compare_col,
    previous_year_value_col,
    previous_year_compare_col,
) = st.columns(4)
with previous_month_value_col:
    previous_month_value = st.text_input(
        f"T{previous_month[1]}.{previous_month[0]}", ""
    )
with previous_month_compare_col:
    # Pre-filled with (actual - previous month) when both are numeric; the
    # user may still overwrite it.
    previous_month_compare = st.text_input(
        f"So sánh T{previous_month[1]}.{previous_month[0]} Tăng giảm",
        _difference_or_blank(real_value, previous_month_value),
    )
with previous_year_value_col:
    previous_year_value = st.text_input(f"T{previous_year[1]}.{previous_year[0]}", "")
with previous_year_compare_col:
    # Pre-filled with (actual - same month last year) when both are numeric.
    previous_year_compare = st.text_input(
        f"So sánh T{previous_year[1]}.{previous_year[0]} Tăng giảm",
        _difference_or_blank(real_value, previous_year_value),
    )


# Everything make_input_sentence_from_strings() needs. The period-dependent
# column names are stored twice: once as a fixed "... key" entry holding the
# dynamic name, and once under that dynamic name holding the value.
data = {
    "CHỈ TIÊU": objective_name,
    "ĐƠN VỊ": unit,
    "ĐIỀU KIỆN": condition,
    "KPI mục tiêu tháng": kpi_target,
    "Đánh giá": evaluation_value,
    "Thời gian báo cáo": current_time,
    f"T{current_time[1]}.{current_time[0]} thực tế": real_value,
    "Previous month value key": f"T{previous_month[1]}.{previous_month[0]}",
    f"T{previous_month[1]}.{previous_month[0]}": previous_month_value,
    "Previous year value key": f"T{previous_year[1]}.{previous_year[0]}",
    f"T{previous_year[1]}.{previous_year[0]}": previous_year_value,
    "Previous month compare key": f"So sánh T{previous_month[1]}.{previous_month[0]} Tăng giảm",
    f"So sánh T{previous_month[1]}.{previous_month[0]} Tăng giảm": previous_month_compare,
    "Previous year compare key": f"So sánh T{previous_year[1]}.{previous_year[0]} Tăng giảm",
    f"So sánh T{previous_year[1]}.{previous_year[0]} Tăng giảm": previous_year_compare,
    "Previous month": previous_month,
    "Previous year": previous_year,
}


if st.button("Generate"):
    with st.spinner("Generating..."):
        input_string = make_input_sentence_from_strings(data)
        print(input_string)  # server-side trace of the exact model input
        descriptions = generate_description(
            input_string, model, tokenizer, device, max_len, model_name, beam_size
        )

    # generate_description returns a list of decoded strings; join them so
    # the text is shown rather than the list's repr.
    st.success("\n\n".join(descriptions))
utils.py ADDED
@@ -0,0 +1,177 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+
3
+ import os
4
+ import re
5
+ import torch
6
+ from transformers import (
7
+ AutoTokenizer,
8
+ AutoModel,
9
+ T5ForConditionalGeneration,
10
+ MBartForConditionalGeneration,
11
+ AutoModelForSeq2SeqLM,
12
+ )
13
+ from tqdm.auto import tqdm
14
+ import streamlit as st
15
+ from typing import Dict, List
16
+
17
+
18
def get_model(args):
    """Build the seq2seq model described by ``args`` and place it on ``args.device``.

    Reads ``args.model_name``, ``args.device`` and ``args.load_model_path``.
    When ``load_model_path`` is truthy, the state dict stored at that path is
    loaded into the model (mapped onto ``args.device``).

    Returns the model.
    """
    print(f"Using model {args.model_name}")
    network = AutoModelForSeq2SeqLM.from_pretrained(args.model_name)
    network.to(args.device)

    # Nothing more to do without a fine-tuned checkpoint.
    if not args.load_model_path:
        return network

    print(f"Loading model from {args.load_model_path}")
    target_device = torch.device(args.device)
    weights = torch.load(args.load_model_path, map_location=target_device)
    network.load_state_dict(weights)
    return network
30
+
31
+
32
@st.cache(allow_output_mutation=True)
def load_model(model_name, device):
    """Load the pretrained model ``model_name`` and its fine-tuned weights.

    The pretrained model is fetched with ``cache`` as its download directory
    and moved to ``device``. The weights are then read from
    ``models/<last path component of model_name>-best_loss.bin``.

    Wrapped in ``st.cache`` so the result is reused for repeated calls with
    the same arguments.

    Returns the model.
    """
    print(f"Using model {model_name}")
    os.makedirs("cache", exist_ok=True)
    network = AutoModelForSeq2SeqLM.from_pretrained(model_name, cache_dir="cache")
    network.to(device)

    # "org/name" -> "name": checkpoints are stored under the bare model name.
    short_name = model_name.split("/")[-1]
    checkpoint_path = os.path.join("models", f"{short_name}-best_loss.bin")
    print(f"Loading model from {checkpoint_path}")

    weights = torch.load(checkpoint_path, map_location=torch.device(device))
    network.load_state_dict(weights)
    return network
47
+
48
+
49
@st.cache(allow_output_mutation=True)
def load_tokenizer(model_name):
    """Return the tokenizer for ``model_name``.

    Model names containing "mbart" (case-insensitive) get ``src_lang`` and
    ``tgt_lang`` set to ``"vi_VN"``; every other model uses the tokenizer's
    defaults. Wrapped in ``st.cache`` so repeated calls reuse the result.
    """
    print(f"Loading tokenizer {model_name}")

    language_options = {}
    if "mbart" in model_name.lower():
        language_options = {"src_lang": "vi_VN", "tgt_lang": "vi_VN"}

    return AutoTokenizer.from_pretrained(model_name, **language_options)
62
+
63
+
64
def prepare_batch_model_inputs(batch, tokenizer, max_len, is_train=False, device="cpu"):
    """Tokenize ``batch["src"]`` and move every resulting entry to ``device``.

    Args:
        batch: mapping with a ``"src"`` entry and, when ``is_train`` is true,
            a ``"tgt"`` entry that is passed to the tokenizer as
            ``text_target``.
        tokenizer: callable invoked with the sources plus padding/truncation
            options; its result must support ``items()`` and item assignment.
        max_len: ``max_length`` passed to the tokenizer.
        is_train: whether targets are tokenized as well.
        device: destination passed to each entry's ``.to()``.

    Returns:
        The tokenizer's result object, with each value replaced by
        ``value.to(device)``.
    """
    sources = batch["src"]
    targets = batch["tgt"] if is_train else None

    encoded = tokenizer(
        sources,
        text_target=targets,
        padding="longest",
        max_length=max_len,
        truncation=True,
        return_tensors="pt",
    )

    # Snapshot the entries first, then overwrite each one in place.
    for name, tensor in list(encoded.items()):
        encoded[name] = tensor.to(device)

    return encoded
78
+
79
+
80
def prepare_single_model_inputs(src, tokenizer, max_len, device="cpu"):
    """Tokenize ``src`` and move every resulting entry to ``device``.

    Args:
        src: input passed straight to the tokenizer.
        tokenizer: callable invoked with ``src`` plus padding/truncation
            options; its result must support ``items()`` and item assignment.
        max_len: ``max_length`` passed to the tokenizer.
        device: destination passed to each entry's ``.to()``.

    Returns:
        The tokenizer's result object, with each value replaced by
        ``value.to(device)``.
    """
    tokenizer_options = {
        "padding": "longest",
        "max_length": max_len,
        "truncation": True,
        "return_tensors": "pt",
    }
    encoded = tokenizer(src, **tokenizer_options)

    # Snapshot the entries first, then overwrite each one in place.
    for name, tensor in list(encoded.items()):
        encoded[name] = tensor.to(device)

    return encoded
93
+
94
+
95
def make_input_sentence_from_strings(data):
    """Serialize one report row into the flat key/value string fed to the model.

    ``data`` must provide these entries:

    * ``"CHỈ TIÊU"``, ``"ĐƠN VỊ"``, ``"ĐIỀU KIỆN"``, ``"KPI mục tiêu tháng"``,
      ``"Đánh giá"`` - values inserted as given.
    * ``"Thời gian báo cáo"``, ``"Previous month"``, ``"Previous year"`` -
      sequences indexed as ``[0]`` and ``[1]``; they are rendered as
      ``<[1]>.<[0]>`` in the column names.
    * ``"T<[1]>.<[0]> thực tế"`` (built from ``"Thời gian báo cáo"``) - the
      value for the reporting period.
    * ``"Previous month value key"``, ``"Previous year value key"``,
      ``"Previous month compare key"``, ``"Previous year compare key"`` -
      each names another entry of ``data`` that holds the actual value.

    Returns the formatted string. The result has no surrounding braces, e.g.
    ``"CHỈ TIÊU": "...", "ĐƠN VỊ": "%", ..., "So sánh T9.2021 Tăng giảm": 0.8``.

    Raises KeyError when a required entry is missing.
    """
    template_str = '"CHỈ TIÊU": "{}", "ĐƠN VỊ": "{}", "ĐIỀU KIỆN": "{}", "KPI mục tiêu tháng": {}, "Tháng {}.{}": {}, "Đánh giá": "{}", "T{}.{}": {}, "So sánh T{}.{} Tăng giảm": {}, "T{}.{}": {}, "So sánh T{}.{} Tăng giảm": {}'

    report_period = data["Thời gian báo cáo"]
    month_before = data["Previous month"]
    year_before = data["Previous year"]

    # Period-dependent columns are stored indirectly: a fixed "... key"
    # entry names the entry that carries the value.
    month_before_value = data[data["Previous month value key"]]
    year_before_value = data[data["Previous year value key"]]
    month_before_delta = data[data["Previous month compare key"]]
    year_before_delta = data[data["Previous year compare key"]]
    actual_value = data[f"T{report_period[1]}.{report_period[0]} thực tế"]

    # Positional arguments, in the order the placeholders appear.
    slots = [
        data["CHỈ TIÊU"],
        data["ĐƠN VỊ"],
        data["ĐIỀU KIỆN"],
        data["KPI mục tiêu tháng"],
        report_period[1],
        report_period[0],
        actual_value,
        data["Đánh giá"],
        month_before[1],
        month_before[0],
        month_before_value,
        month_before[1],
        month_before[0],
        month_before_delta,
        year_before[1],
        year_before[0],
        year_before_value,
        year_before[1],
        year_before[0],
        year_before_delta,
    ]
    return template_str.format(*slots)
158
+
159
+
160
@torch.no_grad()
def generate_description(
    input_string, model, tokenizer, device, max_len, model_name, beam_size
):
    """Generate text for ``input_string`` with beam search.

    The input is tokenized by ``prepare_single_model_inputs`` (truncated to
    ``max_len`` and moved to ``device``), and ``max_len`` is also the
    ``max_length`` passed to ``model.generate``. For model names containing
    "mbart" (case-insensitive), ``forced_bos_token_id`` is set to the
    tokenizer's id for ``"vi_VN"``.

    Runs under ``torch.no_grad()``.

    Returns:
        The list produced by ``tokenizer.batch_decode`` on the generated
        sequences, with special tokens skipped.
    """
    encoded = prepare_single_model_inputs(
        input_string, tokenizer, max_len=max_len, device=device
    )

    generation_options = {"max_length": max_len, "num_beams": beam_size}
    if "mbart" in model_name.lower():
        generation_options["forced_bos_token_id"] = tokenizer.lang_code_to_id["vi_VN"]

    generated = model.generate(**encoded, **generation_options)

    return tokenizer.batch_decode(
        generated, skip_special_tokens=True, clean_up_tokenization_spaces=True
    )