Spaces: Runtime error
trminhnam20082002 committed
Commit 55e492d · 1 Parent(s): 59ae732
feat: add model
Browse files
- .gitignore +2 -0
- app.py +150 -0
- utils.py +177 -0
.gitignore
ADDED
@@ -0,0 +1,2 @@
+cache
+**/__pycache__/
app.py
CHANGED
@@ -0,0 +1,150 @@
+# -*- coding: utf-8 -*-
+
+import streamlit as st
+import pandas as pd
+import torch
+
+from utils import (
+    load_model,
+    load_tokenizer,
+    make_input_sentence_from_strings,
+    generate_description,
+)
+
+st.set_page_config(
+    page_title="Table-to-text generation",
+    page_icon="📝",
+    layout="wide",
+    initial_sidebar_state="auto",
+    menu_items={
+        "Get Help": "https://huggingface.co/transformers/master/index.html",
+        "Report a bug": "https://github.com",
+    },  # hide the "Made with Streamlit" footer
+)
+
+st.title("Table-to-text generation with multilingual pre-trained models")
+st.markdown(
+    """
+    This is a demo of table-to-text generation with multilingual pre-trained models.
+    The models are trained on our custom dataset, sampled from the Viettel Report Template, with descriptions generated by ChatGPT.
+    """
+)
+
+st.sidebar.title("Settings")
+model_name = st.sidebar.selectbox(
+    "Model name",
+    [
+        "vinai/bartpho-syllable",
+        "vinai/bartpho-syllable-base",
+        "google/byt5-base",
+        "google/byt5-small",
+        "facebook/mbart-large-50",
+    ],
+)
+
+if torch.cuda.is_available():
+    device = "cuda" if st.sidebar.checkbox("Use GPU", False) else "cpu"
+else:
+    st.sidebar.checkbox("Use GPU", False, disabled=True)
+    device = "cpu"
+max_len = st.sidebar.slider("Max length", 32, 512, 256, 32)
+beam_size = st.sidebar.slider("Beam size", 1, 10, 3, 1)
+tokenizer = load_tokenizer(model_name)
+model = load_model(model_name, device)
+
+# Create an input widget for each of the following columns:
+# CHỈ TIÊU | ĐƠN VỊ | ĐIỀU KIỆN | KPI mục tiêu tháng | Tháng 9.2022 | Đánh giá | T8.2022 | So sánh T8.2022 Tăng giảm | T9.2021 | So sánh T9.2021 Tăng giảm
+
+objective_name = st.text_input("CHỈ TIÊU", "")
+(unit_col, condition_col, kpi_target_col) = st.columns(3)
+with unit_col:
+    unit = st.text_input("ĐƠN VỊ", "")
+with condition_col:
+    condition = st.selectbox("ĐIỀU KIỆN", [">=", "<=", None])
+with kpi_target_col:
+    kpi_target = st.text_input("KPI mục tiêu tháng", "")
+
+current_date_col, real_value_col, evaluation_col = st.columns(3)
+with current_date_col:
+    current_date = st.date_input(
+        "Thời gian báo cáo", value=None, min_value=None, max_value=None, key=None
+    )
+    current_time = [int(x) for x in str(current_date).split("-")[:2]]
+with real_value_col:
+    real_value = st.text_input(f"T{current_time[1]}.{current_time[0]} thực tế", "")
+with evaluation_col:
+    evaluation_value = st.selectbox(
+        "Đánh giá",
+        ["Đạt", "Không đạt", "Theo dõi"],
+        index=2 if (kpi_target == "" or condition is None) else 0,
+    )
+# current_time is in format [year, month]
+
+previous_month = (
+    [current_time[0], current_time[1] - 1]
+    if current_time[1] > 1
+    else [current_time[0] - 1, 12]
+)
+
+previous_year = [current_time[0] - 1, current_time[1]]
+
+(
+    previous_month_value_col,
+    previous_month_compare_col,
+    previous_year_value_col,
+    previous_year_compare_col,
+) = st.columns(4)
+with previous_month_value_col:
+    previous_month_value = st.text_input(
+        f"T{previous_month[1]}.{previous_month[0]}", ""
+    )
+with previous_month_compare_col:
+    previous_month_compare = st.text_input(
+        f"So sánh T{previous_month[1]}.{previous_month[0]} Tăng giảm",
+        float(real_value) - float(previous_month_value)
+        if previous_month_value != "" and real_value != ""
+        else "",
+        # disabled=True,
+    )
+with previous_year_value_col:
+    previous_year_value = st.text_input(f"T{previous_year[1]}.{previous_year[0]}", "")
+with previous_year_compare_col:
+    previous_year_compare = st.text_input(
+        f"So sánh T{previous_year[1]}.{previous_year[0]} Tăng giảm",
+        float(real_value) - float(previous_year_value)
+        if previous_year_value != "" and real_value != ""
+        else "",
+        # disabled=True,
+    )
+
+
+data = {
+    "CHỈ TIÊU": objective_name,
+    "ĐƠN VỊ": unit,
+    "ĐIỀU KIỆN": condition,
+    "KPI mục tiêu tháng": kpi_target,
+    "Đánh giá": evaluation_value,
+    "Thời gian báo cáo": current_time,
+    f"T{current_time[1]}.{current_time[0]} thực tế": real_value,
+    "Previous month value key": f"T{previous_month[1]}.{previous_month[0]}",
+    f"T{previous_month[1]}.{previous_month[0]}": previous_month_value,
+    "Previous year value key": f"T{previous_year[1]}.{previous_year[0]}",
+    f"T{previous_year[1]}.{previous_year[0]}": previous_year_value,
+    "Previous month compare key": f"So sánh T{previous_month[1]}.{previous_month[0]} Tăng giảm",
+    f"So sánh T{previous_month[1]}.{previous_month[0]} Tăng giảm": previous_month_compare,
+    "Previous year compare key": f"So sánh T{previous_year[1]}.{previous_year[0]} Tăng giảm",
+    f"So sánh T{previous_year[1]}.{previous_year[0]} Tăng giảm": previous_year_compare,
+    "Previous month": previous_month,
+    "Previous year": previous_year,
+}
+
+
+if st.button("Generate"):
+    with st.spinner("Generating..."):
+        input_string = make_input_sentence_from_strings(data)
+        print(input_string)
+        descriptions = generate_description(
+            input_string, model, tokenizer, device, max_len, model_name, beam_size
+        )
+
+    st.success(descriptions)
utils.py
ADDED
@@ -0,0 +1,177 @@
+# -*- coding: utf-8 -*-
+
+import os
+import re
+import torch
+from transformers import (
+    AutoTokenizer,
+    AutoModel,
+    T5ForConditionalGeneration,
+    MBartForConditionalGeneration,
+    AutoModelForSeq2SeqLM,
+)
+from tqdm.auto import tqdm
+import streamlit as st
+from typing import Dict, List
+
+
+def get_model(args):
+    print(f"Using model {args.model_name}")
+    model = AutoModelForSeq2SeqLM.from_pretrained(args.model_name)
+    model.to(args.device)
+
+    if args.load_model_path:
+        print(f"Loading model from {args.load_model_path}")
+        model.load_state_dict(
+            torch.load(args.load_model_path, map_location=torch.device(args.device))
+        )
+
+    return model
+
+
+@st.cache(allow_output_mutation=True)
+def load_model(model_name, device):
+    print(f"Using model {model_name}")
+    os.makedirs("cache", exist_ok=True)
+    model = AutoModelForSeq2SeqLM.from_pretrained(model_name, cache_dir="cache")
+    model.to(device)
+
+    model_name = model_name.split("/")[-1]
+    load_model_path = os.path.join("models", f"{model_name}-best_loss.bin")
+    print(f"Loading model from {load_model_path}")
+    model.load_state_dict(
+        torch.load(load_model_path, map_location=torch.device(device))
+    )
+
+    return model
+
+
+@st.cache(allow_output_mutation=True)
+def load_tokenizer(model_name):
+    print(f"Loading tokenizer {model_name}")
+    if "mbart" in model_name.lower():
+        tokenizer = AutoTokenizer.from_pretrained(
+            model_name, src_lang="vi_VN", tgt_lang="vi_VN"
+        )
+        # tokenizer.src_lang = "vi_VN"
+        # tokenizer.tgt_lang = "vi_VN"
+    else:
+        tokenizer = AutoTokenizer.from_pretrained(model_name)
+
+    return tokenizer
+
+
+def prepare_batch_model_inputs(batch, tokenizer, max_len, is_train=False, device="cpu"):
+    inputs = tokenizer(
+        batch["src"],
+        text_target=batch["tgt"] if is_train else None,
+        padding="longest",
+        max_length=max_len,
+        truncation=True,
+        return_tensors="pt",
+    )
+
+    for k, v in inputs.items():
+        inputs[k] = v.to(device)
+
+    return inputs
+
+
+def prepare_single_model_inputs(src, tokenizer, max_len, device="cpu"):
+    inputs = tokenizer(
+        src,
+        padding="longest",
+        max_length=max_len,
+        truncation=True,
+        return_tensors="pt",
+    )
+
+    for k, v in inputs.items():
+        inputs[k] = v.to(device)
+
+    return inputs
+
+
+def make_input_sentence_from_strings(data):
+    # `data` is the dict built in app.py:
+    # data = {
+    #     "CHỈ TIÊU": objective_name,
+    #     "ĐƠN VỊ": unit,
+    #     "ĐIỀU KIỆN": condition,
+    #     "KPI mục tiêu tháng": kpi_target,
+    #     "Đánh giá": evaluation_value,
+    #     "Thời gian báo cáo": current_time,
+    #     f"T{current_time[1]}.{current_time[0]} thực tế": real_value,
+    #     "Previous month value key": f"T{previous_month[1]}.{previous_month[0]}",
+    #     f"T{previous_month[1]}.{previous_month[0]}": previous_month_value,
+    #     "Previous year value key": f"T{previous_year[1]}.{previous_year[0]}",
+    #     f"T{previous_year[1]}.{previous_year[0]}": previous_year_value,
+    #     "Previous month compare key": f"So sánh T{previous_month[1]}.{previous_month[0]} Tăng giảm",
+    #     f"So sánh T{previous_month[1]}.{previous_month[0]} Tăng giảm": previous_month_compare,
+    #     "Previous year compare key": f"So sánh T{previous_year[1]}.{previous_year[0]} Tăng giảm",
+    #     "Previous month": previous_month,
+    #     "Previous year": previous_year,
+    # }
+
+    previous_month_value_key = data["Previous month value key"]
+    previous_year_value_key = data["Previous year value key"]
+    objective_name = data["CHỈ TIÊU"]
+    unit = data["ĐƠN VỊ"]
+    condition = data["ĐIỀU KIỆN"]
+    kpi_target = data["KPI mục tiêu tháng"]
+    current_time = data["Thời gian báo cáo"]
+    real_value = data[f"T{current_time[1]}.{current_time[0]} thực tế"]
+    evaluation_value = data["Đánh giá"]
+    previous_month_value = data[previous_month_value_key]
+    previous_year_value = data[previous_year_value_key]
+    previous_month_compare_key = data["Previous month compare key"]
+    previous_year_compare_key = data["Previous year compare key"]
+    previous_month_compare = data[previous_month_compare_key]
+    previous_year_compare = data[previous_year_compare_key]
+    previous_month = data["Previous month"]
+    previous_year = data["Previous year"]
+
+    # make a template string from the following example:
+    # """{"CHỈ TIÊU": "Tỷ lệ kết nối thành công đến tổng đài - KHCN_Di động Vip", "ĐƠN VỊ": "%", "ĐIỀU KIỆN": ">=", "KPI mục tiêu tháng": 95.0, "Tháng 9.2022": 97.5, "Đánh giá": "Đạt", "T8.2022": 96.6, "So sánh T8.2022 Tăng giảm": 1.0, "T9.2021": 96.8, "So sánh T9.2021 Tăng giảm": 0.8}"""
+    template_str = '"CHỈ TIÊU": "{}", "ĐƠN VỊ": "{}", "ĐIỀU KIỆN": "{}", "KPI mục tiêu tháng": {}, "Tháng {}.{}": {}, "Đánh giá": "{}", "T{}.{}": {}, "So sánh T{}.{} Tăng giảm": {}, "T{}.{}": {}, "So sánh T{}.{} Tăng giảm": {}'
+    return template_str.format(
+        objective_name,
+        unit,
+        condition,
+        kpi_target,
+        current_time[1],
+        current_time[0],
+        real_value,
+        evaluation_value,
+        previous_month[1],
+        previous_month[0],
+        previous_month_value,
+        previous_month[1],
+        previous_month[0],
+        previous_month_compare,
+        previous_year[1],
+        previous_year[0],
+        previous_year_value,
+        previous_year[1],
+        previous_year[0],
+        previous_year_compare,
+    )
+
+
+@torch.no_grad()
+def generate_description(
+    input_string, model, tokenizer, device, max_len, model_name, beam_size
+):
+    inputs = prepare_single_model_inputs(
+        input_string, tokenizer, max_len=max_len, device=device
+    )
+    if "mbart" in model_name.lower():
+        inputs["forced_bos_token_id"] = tokenizer.lang_code_to_id["vi_VN"]
+    outputs = model.generate(
+        **inputs,
+        max_length=max_len,
+        num_beams=beam_size,
+        # early_stopping=True,
+    )
+    return tokenizer.batch_decode(
+        outputs, skip_special_tokens=True, clean_up_tokenization_spaces=True
+    )