File size: 5,010 Bytes
3411193 1f96fb2 3411193 2df020e 3411193 2df020e 3411193 8619f75 2df020e 400e02d 3411193 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 |
import html

import pandas as pd
import requests
import streamlit as st
import tokenizers
from transformers import pipeline
from transformers.tokenization_utils import TruncationStrategy
# Browser-tab title and icon for the demo; the sidebar starts out expanded.
st.set_page_config(
    initial_sidebar_state="expanded",
    page_icon="🥙",
    page_title='AlephBERT Demo',
)
# Registry of selectable checkpoints, keyed by the label shown in the sidebar.
# Each entry carries the identifier passed to `pipeline(...)` ("name_or_path")
# and a short human-readable description.
models = {
    label: {"name_or_path": checkpoint, "description": summary}
    for label, checkpoint, summary in (
        ("AlephBERT-base", "onlplab/alephbert-base", "AlephBERT base model"),
        ("HeBERT-base-TAU", "avichr/heBERT", "HeBERT model created by TAU"),
        (
            "mBERT-base-multilingual-cased",
            "bert-base-multilingual-cased",
            "Multilingual BERT model",
        ),
    )
}
@st.cache(show_spinner=False)
def get_json_from_url(url):
    """Fetch ``url`` and return its decoded JSON body.

    Previously an early ``return models`` made the HTTP request unreachable,
    so ``url`` was silently ignored. The request is now actually performed;
    the in-file ``models`` registry is kept as the fallback so the function
    still always returns a usable mapping.

    Args:
        url: Address of a JSON document.

    Returns:
        The parsed JSON on success, otherwise the module-level ``models`` dict
        (network error, non-2xx status, or a body that is not valid JSON).
    """
    try:
        # Without a timeout a stalled connection would hang the app indefinitely.
        response = requests.get(url, timeout=10)
        response.raise_for_status()
        return response.json()
    except (requests.RequestException, ValueError):
        # ValueError covers JSON decoding failures on every requests version.
        return models
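# Loading the model registry from the Space's models.json is currently disabled;
# the in-file `models` dict above is used instead. To re-enable:
# models = get_json_from_url('https://huggingface.co/spaces/biu-nlp/AlephBERT/raw/main/models.json')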
@st.cache(show_spinner=False, hash_funcs={tokenizers.Tokenizer: str})
def load_model(model):
    """Create the fill-mask pipeline for the ``models`` entry named ``model``.

    Args:
        model: A key of the module-level ``models`` registry.

    Returns:
        A ``(pipeline, tokenize)`` pair. ``tokenize`` encodes raw text with the
        pipeline's own tokenizer; the pipeline's ``_parse_and_tokenize`` hook is
        replaced so it can also be handed input that is already encoded.
    """
    fill_mask = pipeline('fill-mask', models[model]['name_or_path'])

    def do_tokenize(inputs):
        """Encode ``inputs`` with special tokens and padding, never truncating."""
        encode_options = dict(
            add_special_tokens=True,
            return_tensors=fill_mask.framework,
            padding=True,
            truncation=TruncationStrategy.DO_NOT_TRUNCATE,
        )
        return fill_mask.tokenizer(inputs, **encode_options)

    def _parse_and_tokenize(inputs, tokenized=False, **kwargs):
        """Return pre-encoded ``inputs`` untouched; encode anything else."""
        if tokenized:
            return inputs
        return do_tokenize(inputs)

    # Swap in the hook above so callers may pass `tokenized=True` together
    # with an encoding they have already edited (e.g. with a mask id inserted).
    fill_mask._parse_and_tokenize = _parse_and_tokenize
    return fill_mask, do_tokenize
st.title('AlephBERT🥙')

# Sidebar header: ONLP Lab logo linking to the lab page, plus a caption line.
# Raw HTML is used for the styling, hence `unsafe_allow_html`.
sidebar_header_html = """<div><a target="_blank" href="https://nlp.biu.ac.il/~rtsarfaty/onlp#"><img src="https://nlp.biu.ac.il/~rtsarfaty/static/landing_static/img/onlp_logo.png" style="filter: invert(100%);display: block;margin-left: auto;margin-right: auto;
width: 70%;"></a>
<p style="color:white; font-size:13px; font-family:monospace; text-align: center">AlephBERT Demo • <a href="https://nlp.biu.ac.il/~rtsarfaty/onlp#" style="text-decoration: none;color: white;" target="_blank">ONLP Lab</a></p></div>
<br>"""
st.sidebar.markdown(sidebar_header_html, unsafe_allow_html=True)
mode = 'Models'

if mode == 'Models':
    # ---- Sidebar controls ----
    model = st.sidebar.selectbox('Select Model', list(models.keys()))
    masking_level = st.sidebar.selectbox('Masking Level:', ['Tokens', 'SubWords'])
    n_res = st.sidebar.number_input(
        'Number Of Results',
        format='%d',
        value=5,
        min_value=1,
        max_value=100,
    )

    # Show the selected registry key as a row of badges, one per '-' separated
    # part, with the first part prefixed by "Model:".
    model_tags = model.split('-')
    model_tags[0] = 'Model:' + model_tags[0]
    st.markdown(
        ''.join(
            f'<span style="color:white; font-size:13px; font-family:monospace; background-color: #f63766;margin:3px;padding:8px;border-radius: 5px;">{tag}</span>'
            for tag in model_tags
        ),
        unsafe_allow_html=True,
    )
    st.markdown('___')

    unmasker, tokenize = load_model(model)

    input_text = st.text_input('Insert text you want to mask', '')
    if input_text:
        # `input_masked` stays None until the user actually picks something to
        # mask; `display_input` is the human-readable form of the masked input.
        input_masked = None
        display_input = None
        tokenized = tokenize(input_text)

        if masking_level == 'Tokens':
            # Whitespace-level masking: replace one whole word with the mask token.
            tokens = str(input_text).split()
            mask_idx = st.selectbox(
                'Select token to mask:',
                [None] + list(range(len(tokens))),
                # Compare against None explicitly: index 0 is falsy, and a plain
                # truthiness test rendered the first word as a blank label.
                format_func=lambda i: '' if i is None else tokens[i],
            )
            if mask_idx is not None:
                # Use the tokenizer's own mask token rather than a hard-coded
                # '[MASK]' so models with a different mask string keep working.
                mask_token = unmasker.tokenizer.mask_token
                input_masked = ' '.join(
                    mask_token if i == mask_idx else token
                    for i, token in enumerate(tokens)
                )
                display_input = input_masked
        elif masking_level == 'SubWords':
            # Sub-word masking: overwrite one position of the encoded input.
            ids = tokenized['input_ids'].tolist()[0]
            tokens = unmasker.tokenizer.convert_ids_to_tokens(ids)
            idx = st.selectbox(
                'Select token to mask:',
                # The first and last positions are skipped; they appear to hold
                # the special tokens added by the tokenizer (the display below
                # likewise drops them via ids[1:-1]). None is "nothing chosen".
                [None] + list(range(1, len(tokens) - 1)),
                format_func=lambda i: '' if i is None else tokens[i],
            )
            if idx is not None:
                # Only touch the encoding once a real position was chosen.
                tokenized['input_ids'][0][idx] = unmasker.tokenizer.mask_token_id
                ids = tokenized['input_ids'].tolist()[0]
                display_input = ' '.join(
                    unmasker.tokenizer.convert_ids_to_tokens(ids[1:-1])
                )
                input_masked = tokenized

        if input_masked is not None:
            st.markdown('#### Input:')
            # The text comes straight from the user and is rendered as raw HTML,
            # so escape it to keep typed markup from being injected into the page.
            st.markdown(
                f'<p dir="rtl">{html.escape(display_input)}</p>',
                unsafe_allow_html=True,
            )
            st.markdown('#### Outputs:')
            with st.spinner(f'Running {model_tags[0]} (may take a minute)...'):
                # In SubWords mode the input is already encoded; `tokenized=True`
                # tells the hook installed by `load_model` to skip re-encoding.
                res = unmasker(
                    input_masked,
                    tokenized=masking_level == 'SubWords',
                    top_k=n_res,
                )
            if res:
                res = [
                    {
                        'Prediction': r['token_str'],
                        'Completed Sentence': r['sequence'].replace('[SEP]', '').replace('[CLS]', ''),
                        'Score': r['score'],
                    }
                    for r in res
                ]
                res_table = pd.DataFrame(res)
                st.table(res_table)
|