import html
from io import StringIO

import streamlit as st
from transformers import pipeline
# Page configuration is issued before anything else that may emit Streamlit
# output (the cached loader below can show a spinner on a cache miss).
st.set_page_config(layout="wide")


@st.cache_resource
def _load_unmasker():
    """Build the fill-mask pipeline once and reuse it across script reruns."""
    return pipeline('fill-mask', model='dsfsi/zabantu-ven-120m')


# Module-level handle that fill_mask() calls with one sentence at a time.
unmasker = _load_unmasker()
def fill_mask(sentences):
    """Run the fill-mask model over every sentence that contains one mask.

    Args:
        sentences: Iterable of sentence strings, typically one per input line.

    Returns:
        A ``(results, warnings)`` tuple. ``results`` maps each processed
        sentence (surrounding whitespace removed) to whatever ``unmasker``
        returned for it. ``warnings`` lists one message per non-blank
        sentence that was not sent to the model.
    """
    mask_token = "<mask>"
    results = {}
    warnings = []

    for raw_sentence in sentences:
        # Splitting input on "\n" leaves empty strings and stray "\r"
        # characters behind; neither should be reported to the user.
        sentence = raw_sentence.strip()
        if not sentence:
            continue

        mask_count = sentence.count(mask_token)
        if mask_count == 0:
            warnings.append(f"Warning: No <mask> token found in sentence: {sentence}")
        elif mask_count > 1:
            # NOTE(review): with several masks the pipeline output is
            # presumably nested per mask (confirm against the transformers
            # docs), while the rendering code in this file indexes each
            # prediction as a dict. Report such sentences instead of
            # letting the display code fail on them.
            warnings.append(
                f"Warning: More than one <mask> token found in sentence: {sentence}"
            )
        else:
            results[sentence] = unmasker(sentence)

    return results, warnings
def replace_mask(sentence, predicted_word):
    """Return ``sentence`` with each ``<mask>`` replaced by the bolded word.

    Args:
        sentence: Text that may contain ``<mask>`` placeholders.
        predicted_word: Token to insert; it is wrapped in ``**`` so that
            Markdown renders it in bold.

    Returns:
        The sentence with every ``<mask>`` occurrence substituted. A sentence
        without a mask is returned unchanged.
    """
    # Markdown only treats ``**word**`` as bold when the asterisks touch the
    # word, so whitespace around the token (some tokenizers may emit a
    # leading space) is removed before wrapping.
    word = str(predicted_word).strip()
    return sentence.replace("<mask>", f"**{word}**")
# Page heading followed by a short description of the model being demoed.
st.title("Fill Mask | Zabantu-ven-120m")
# Empty write; presumably kept as a vertical spacer under the title.
st.write("")

st.markdown("This is a variant of Zabantu pre-trained on a monolingual dataset of Tshivenda(ven) sentences on a transformer network with 120 million trainable parameters.")
# Two side-by-side columns: col1 hosts the input widgets, col2 the output.
col1, col2 = st.columns(2)

# Seed the session-state entries read further down, without overwriting
# values that already exist in the session.
for state_key, initial_value in (("text_input", ""), ("warnings", [])):
    if state_key not in st.session_state:
        st.session_state[state_key] = initial_value
def _split_sentences(text):
    """Split ``text`` into its non-blank lines, whatever the newline style."""
    return [line for line in text.splitlines() if line.strip()]


# Input column: lets the user type sentences, upload a file, or run the
# bundled example. Pressing a button binds ``result`` for the output column.
# NOTE(review): the original indentation was lost; everything below is
# assumed to live inside the bordered container - confirm the layout.
with col1:
    with st.container(border=True):
        st.markdown("Input :clipboard:")

        select_options = ['Choose option', 'Enter text input', 'Upload a file(csv/txt)']
        sample_sentence = "Vhana vhane vha kha ḓi bva u bebwa vha kha khombo ya u <mask> nga Listeriosis."

        option_selected = st.selectbox("Select an input option:", select_options, index=0)

        if option_selected == 'Enter text input':
            text_input = st.text_area(
                "Enter sentences with <mask> token(one sentence per line):",
                value=st.session_state['text_input'],
            )
            input_sentences = _split_sentences(text_input)

            if st.button("Submit", use_container_width=True):
                result, warnings = fill_mask(input_sentences)
                st.session_state['warnings'] = warnings

        elif option_selected == 'Upload a file(csv/txt)':
            uploaded_file = st.file_uploader("Choose a file-(one sentence per line)")
            if uploaded_file is not None:
                try:
                    # "utf-8-sig" also accepts plain UTF-8 and drops a leading
                    # byte-order mark, which would otherwise end up inside
                    # the first sentence.
                    string_data = uploaded_file.getvalue().decode("utf-8-sig")
                except UnicodeDecodeError:
                    st.error("The uploaded file could not be decoded as UTF-8 text.")
                else:
                    input_sentences = _split_sentences(string_data)

                    if st.button("Submit", use_container_width=True):
                        result, warnings = fill_mask(input_sentences)
                        st.session_state['warnings'] = warnings

        # Reserve the warnings slot above the example but fill it last, so
        # it reflects the outcome of whichever button was pressed this run.
        warnings_area = st.container()

        st.markdown("Example")
        st.code(sample_sentence, wrap_lines=True)
        if st.button("Test Example", use_container_width=True):
            result, warnings = fill_mask([sample_sentence])
            # Replace warnings left over from an earlier submission.
            st.session_state['warnings'] = warnings

        with warnings_area:
            for warning in st.session_state['warnings']:
                st.warning(warning)
def _prediction_bar_html(label, score, score_align):
    """Return the HTML for one confidence bar with its label and percentage.

    Args:
        label: Text shown under the bar; HTML-escaped because the markup is
            rendered with ``unsafe_allow_html=True``.
        score: Confidence as a percentage; used as the bar width.
        score_align: Value placed in the percentage cell's inline style.
    """
    return "\n".join([
        '<div class="bar">',
        f'    <div class="bar-fill" style="width: {score}%;"></div>',
        '</div>',
        '<div class="container">',
        f'    <div style="align-items: left;">{html.escape(label)}</div>',
        f'    <div style="align-items: {score_align};">{score:.2f}%</div>',
        '</div>',
    ])


# Output column: confidence bars first, then each sentence with the top
# prediction filled in.
# NOTE(review): the original indentation was lost; the sentence list is
# assumed to belong inside this container - confirm the layout.
with col2:
    with st.container(border=True):
        st.markdown("Output :bar_chart:")

        # ``result`` is only bound when a button was pressed during this run.
        predictions_by_sentence = locals().get('result') or {}

        if len(predictions_by_sentence) == 1:
            # Single sentence: show every candidate the model returned.
            for predictions in predictions_by_sentence.values():
                for prediction in predictions:
                    st.markdown(
                        _prediction_bar_html(
                            prediction['token_str'],
                            prediction['score'] * 100,
                            "center",
                        ),
                        unsafe_allow_html=True,
                    )
        else:
            # Several sentences: show only the top candidate of each one.
            for index, predictions in enumerate(predictions_by_sentence.values(), start=1):
                if predictions:
                    top_prediction = predictions[0]
                    st.markdown(
                        _prediction_bar_html(
                            f"{top_prediction['token_str']} (line {index})",
                            top_prediction['score'] * 100,
                            "right",
                        ),
                        unsafe_allow_html=True,
                    )

        for line, (sentence, predictions) in enumerate(predictions_by_sentence.items(), start=1):
            # Same emptiness guard as the bar list above; without it an empty
            # prediction list would raise IndexError.
            if not predictions:
                continue
            full_sentence = replace_mask(sentence, predictions[0]['token_str'])
            st.write(f"**Sentence {line}:** {full_sentence }")
# Page-wide stylesheet, injected once at the end of the script. It is a raw
# string so the backslashes in the escaped class selectors (``\:``) reach the
# browser unchanged and are not read as Python escape sequences.
css = r"""
<style>
footer {display:none !important;}

.gr-button-primary {
    z-index: 14;
    height: 43px;
    width: 130px;
    left: 0px;
    top: 0px;
    padding: 0px;
    cursor: pointer !important;
    background: none rgb(17, 20, 45) !important;
    border: none !important;
    text-align: center !important;
    font-family: Poppins !important;
    font-size: 14px !important;
    font-weight: 500 !important;
    color: rgb(255, 255, 255) !important;
    line-height: 1 !important;
    border-radius: 12px !important;
    transition: box-shadow 200ms ease 0s, background 200ms ease 0s !important;
    box-shadow: none !important;
}
.gr-button-primary:hover{
    z-index: 14;
    height: 43px;
    width: 130px;
    left: 0px;
    top: 0px;
    padding: 0px;
    cursor: pointer !important;
    background: none rgb(66, 133, 244) !important;
    border: none !important;
    text-align: center !important;
    font-family: Poppins !important;
    font-size: 14px !important;
    font-weight: 500 !important;
    color: rgb(255, 255, 255) !important;
    line-height: 1 !important;
    border-radius: 12px !important;
    transition: box-shadow 200ms ease 0s, background 200ms ease 0s !important;
    box-shadow: rgb(0 0 0 / 23%) 0px 1px 7px 0px !important;
}
.hover\:bg-orange-50:hover {
    --tw-bg-opacity: 1 !important;
    background-color: rgb(229,225,255) !important;
}
.to-orange-200 {
    --tw-gradient-to: rgb(37 56 133 / 37%) !important;
}
.from-orange-400 {
    --tw-gradient-from: rgb(17, 20, 45) !important;
    --tw-gradient-to: rgb(255 150 51 / 0);
    --tw-gradient-stops: var(--tw-gradient-from), var(--tw-gradient-to) !important;
}
.group-hover\:from-orange-500{
    --tw-gradient-from:rgb(17, 20, 45) !important;
    --tw-gradient-to: rgb(37 56 133 / 37%);
    --tw-gradient-stops: var(--tw-gradient-from), var(--tw-gradient-to) !important;
}
.group:hover .group-hover\:text-orange-500{
    --tw-text-opacity: 1 !important;
    color:rgb(37 56 133 / var(--tw-text-opacity)) !important;
}

.container {
    display: flex;
    justify-content: space-between;
    align-items: center;
    margin-bottom: 5px;
    width: 100%;
}
.bar {
    /* width: 70%; */
    background-color: #e6e6e6;
    border-radius: 12px;
    overflow: hidden;
    margin-right: 10px;
    height: 5px;
}
.bar-fill {
    background-color: #17152e;
    height: 100%;
    border-radius: 12px;
}

</style>
"""

st.markdown(css, unsafe_allow_html=True)