Spaces:
Runtime error
Runtime error
# -*- coding: utf-8 -*-
import datetime

import pandas as pd
import streamlit as st
import torch

from utils import (
    load_model,
    load_tokenizer,
    make_input_sentence_from_strings,
    generate_description,
)
# Page-level configuration. Kept as the very first st.* call because many
# Streamlit versions reject set_page_config after other Streamlit commands.
st.set_page_config(
    page_title="Table-to-text generation",
    page_icon="📝",
    layout="wide",
    initial_sidebar_state="auto",
    # Links shown in the app's top-right menu. Note: this does NOT hide the
    # "Made with Streamlit" footer; it only customizes the menu entries.
    # TODO: replace the placeholder bug-report URL with the project's tracker.
    menu_items={
        "Get Help": "https://huggingface.co/transformers/master/index.html",
        "Report a bug": "https://github.com",
    },
)
# Page header: title plus a short description of the demo.
st.title("Table-to-text generation with multilingual pre-trained models")
st.markdown(
    """
This is a demo of table-to-text generation with multilingual pre-trained models.
The models are trained on our custom dataset, which is sampled from Viettel report templates, with descriptions generated by ChatGPT.
"""
)
# ----- Sidebar: model and decoding settings -----
st.sidebar.title("Settings")

# Checkpoint name; forwarded to load_tokenizer / load_model and to
# generate_description further down the script.
model_name = st.sidebar.selectbox(
    "Model name",
    [
        "vinai/bartpho-syllable",
        "vinai/bartpho-syllable-base",
        "google/byt5-base",
        "google/byt5-small",
        "facebook/mbart-large-50",
    ],
)

# The "Use GPU" box is always rendered, but it is only interactive when CUDA
# is present; without CUDA it is shown disabled and inference stays on CPU.
cuda_available = torch.cuda.is_available()
use_gpu = st.sidebar.checkbox("Use GPU", False, disabled=not cuda_available)
device = "cuda" if (cuda_available and use_gpu) else "cpu"

# Decoding parameters passed to generate_description.
max_len = st.sidebar.slider("Max length", 32, 512, 256, 32)
beam_size = st.sidebar.slider("Beam size", 1, 10, 3, 1)
# ----- Main form: current-period KPI inputs -----
# One input per column of the report table: indicator (CHỈ TIÊU), unit
# (ĐƠN VỊ), condition (ĐIỀU KIỆN), monthly KPI target (KPI mục tiêu tháng),
# this month's actual value, evaluation (Đánh giá), then previous-month and
# same-month-last-year values with their changes (So sánh ... Tăng giảm).


def _default_evaluation_index(actual: str, target: str, cond: "str | None") -> int:
    """Return the default option index for the "Đánh giá" (evaluation) box.

    Options are 0 = "Đạt" (met), 1 = "Không đạt" (not met),
    2 = "Theo dõi" (monitor).

    Returns 2 when there is no target or no condition. When both the actual
    value and the target parse as numbers, returns 0 or 1 depending on whether
    ``cond`` (">=" or "<=") holds. Otherwise it falls back to 0.
    """
    if target == "" or cond is None:
        return 2
    try:
        actual_num = float(actual)
        target_num = float(target)
    except ValueError:
        # Not enough numeric data to judge automatically.
        return 0
    met = actual_num >= target_num if cond == ">=" else actual_num <= target_num
    return 0 if met else 1


objective_name = st.text_input("CHỈ TIÊU", "")
unit_col, condition_col, kpi_target_col = st.columns(3)
with unit_col:
    unit = st.text_input("ĐƠN VỊ", "")
with condition_col:
    # None means the KPI has no pass/fail comparison.
    condition = st.selectbox("ĐIỀU KIỆN", [">=", "<=", None])
with kpi_target_col:
    kpi_target = st.text_input("KPI mục tiêu tháng", "")

current_date_col, real_value_col, evaluation_col = st.columns(3)
with current_date_col:
    # Default explicitly to today: on recent Streamlit releases value=None
    # renders an empty picker that returns None, which would make the
    # year/month extraction below fail.
    current_date = st.date_input(
        "Thời gian báo cáo",
        value=datetime.date.today(),
        min_value=None,
        max_value=None,
        key=None,
    )
    if current_date is None:
        current_date = datetime.date.today()
    # Reporting period as [year, month].
    current_time = [current_date.year, current_date.month]
with real_value_col:
    real_value = st.text_input(f"T{current_time[1]}.{current_time[0]} thực tế", "")
with evaluation_col:
    evaluation_value = st.selectbox(
        "Đánh giá",
        ["Đạt", "Không đạt", "Theo dõi"],
        index=_default_evaluation_index(real_value, kpi_target, condition),
    )
# ----- Comparison periods -----
# current_time is [year, month]. Derive the previous month (January wraps
# to December of the previous year) and the same month one year earlier,
# both in the same [year, month] shape.
previous_month = (
    [current_time[0], current_time[1] - 1]
    if current_time[1] > 1
    else [current_time[0] - 1, 12]
)
previous_year = [current_time[0] - 1, current_time[1]]


def _difference_text(current: str, previous: str) -> str:
    """Return ``current - previous`` as text, or "" if either is not numeric.

    Used to pre-fill the "So sánh ... Tăng giảm" (change) boxes. An empty or
    non-numeric value yields "" instead of raising ValueError, so partially
    filled forms don't crash the app.
    """
    try:
        return str(float(current) - float(previous))
    except ValueError:
        return ""


(
    previous_month_value_col,
    previous_month_compare_col,
    previous_year_value_col,
    previous_year_compare_col,
) = st.columns(4)
with previous_month_value_col:
    previous_month_value = st.text_input(
        f"T{previous_month[1]}.{previous_month[0]}", ""
    )
with previous_month_compare_col:
    # Pre-filled with the month-over-month change; the user may override it.
    previous_month_compare = st.text_input(
        f"So sánh T{previous_month[1]}.{previous_month[0]} Tăng giảm",
        _difference_text(real_value, previous_month_value),
    )
with previous_year_value_col:
    previous_year_value = st.text_input(f"T{previous_year[1]}.{previous_year[0]}", "")
with previous_year_compare_col:
    # Pre-filled with the year-over-year change; the user may override it.
    previous_year_compare = st.text_input(
        f"So sánh T{previous_year[1]}.{previous_year[0]} Tăng giảm",
        _difference_text(real_value, previous_year_value),
    )
# Collect all form values into the record consumed by
# make_input_sentence_from_strings. Several column names depend on the
# reporting month, so each such name is also stored under a fixed
# "... key" entry. Presumably this lets the consumer locate the
# date-dependent columns without rebuilding the labels (confirm in utils).
current_label = f"T{current_time[1]}.{current_time[0]} thực tế"
month_label = f"T{previous_month[1]}.{previous_month[0]}"
year_label = f"T{previous_year[1]}.{previous_year[0]}"
month_compare_label = f"So sánh {month_label} Tăng giảm"
year_compare_label = f"So sánh {year_label} Tăng giảm"

data = {
    "CHỈ TIÊU": objective_name,
    "ĐƠN VỊ": unit,
    "ĐIỀU KIỆN": condition,
    "KPI mục tiêu tháng": kpi_target,
    "Đánh giá": evaluation_value,
    "Thời gian báo cáo": current_time,
    current_label: real_value,
    "Previous month value key": month_label,
    month_label: previous_month_value,
    "Previous year value key": year_label,
    year_label: previous_year_value,
    "Previous month compare key": month_compare_label,
    month_compare_label: previous_month_compare,
    "Previous year compare key": year_compare_label,
    year_compare_label: previous_year_compare,
    "Previous month": previous_month,
    "Previous year": previous_year,
}
# ----- Model loading and generation -----
# The tokenizer and model are requested on every script rerun. Whether that
# is cheap depends on load_tokenizer / load_model in utils. TODO confirm
# they are cached (e.g. with st.cache_resource).
tokenizer = load_tokenizer(model_name)
model = load_model(model_name, device)

if st.button("Generate"):
    # Reject blank or whitespace-only required fields before running the model.
    if not objective_name.strip():
        st.error("Please input objective name")
    elif not unit.strip():
        st.error("Please input unit")
    else:
        with st.spinner("Generating..."):
            input_string = make_input_sentence_from_strings(data)
            # Debug aid: log the model input string built from `data` to stdout.
            print(input_string)
            descriptions = generate_description(
                input_string, model, tokenizer, device, max_len, model_name, beam_size
            )
            st.success(descriptions)