Spaces:
Build error
Build error
danseith
committed on
Commit
·
c16370c
1
Parent(s):
2ce1788
Added edit slider and changed sampling back to multinomial.
Browse files
app.py
CHANGED
@@ -6,12 +6,16 @@ from transformers.pipelines import PIPELINE_REGISTRY, FillMaskPipeline
|
|
6 |
from transformers import AutoModelForMaskedLM
|
7 |
|
8 |
# unmasker = pipeline("temp-scale", model="anferico/bert-for-patents")
|
9 |
-
|
10 |
-
|
11 |
-
|
12 |
|
13 |
def add_mask(text, size=1):
|
14 |
split_text = text.split()
|
|
|
|
|
|
|
|
|
15 |
idx = np.random.randint(len(split_text), size=size)
|
16 |
for i in idx:
|
17 |
split_text[i] = '[MASK]'
|
@@ -114,32 +118,38 @@ PIPELINE_REGISTRY.register_pipeline(
|
|
114 |
)
|
115 |
scrambler = pipeline("temp-scale", model="anferico/bert-for-patents")
|
116 |
|
117 |
-
|
118 |
-
|
119 |
-
|
120 |
-
|
121 |
-
|
122 |
-
|
123 |
-
|
124 |
-
|
125 |
-
|
126 |
-
|
127 |
-
|
128 |
-
|
129 |
-
|
130 |
-
|
131 |
-
|
132 |
-
|
|
|
|
|
|
|
|
|
|
|
133 |
|
134 |
textbox = gr.Textbox(label="Type language here", lines=5)
|
135 |
-
textbox2 = gr.Textbox(placeholder="
|
136 |
-
temp_slider = gr.Slider(1.0, 2.0, value=1.0, label='
|
|
|
137 |
|
138 |
demo = gr.Interface(
|
139 |
fn=unmask,
|
140 |
-
inputs=[textbox, temp_slider],
|
141 |
-
outputs=[
|
142 |
-
examples=
|
143 |
)
|
144 |
|
145 |
demo.launch()
|
|
|
6 |
from transformers import AutoModelForMaskedLM
|
7 |
|
8 |
# unmasker = pipeline("temp-scale", model="anferico/bert-for-patents")
|
9 |
+
# Sample rows shown beneath the demo UI. Each row is
# [input text containing a '[MASK]' placeholder, temperature value].
examples = [
    [
        'A crustless [MASK] made from two slices of baked bread.',
        1.2,
    ],
    [
        'The invention provides a method for altering or modifying [MASK] of one or more gene products.',
        1.1,
    ],
    [
        'The graphite [MASK] is composed of a two-dimensional hexagonal lattice of carbon atoms.',
        1.4,
    ],
]
|
12 |
|
13 |
def add_mask(text, size=1):
|
14 |
split_text = text.split()
|
15 |
+
|
16 |
+
# If the user supplies a mask, don't add more
|
17 |
+
if '[MASK]' in split_text:
|
18 |
+
return text
|
19 |
idx = np.random.randint(len(split_text), size=size)
|
20 |
for i in idx:
|
21 |
split_text[i] = '[MASK]'
|
|
|
118 |
)
|
119 |
scrambler = pipeline("temp-scale", model="anferico/bert-for-patents")
|
120 |
|
121 |
+
|
122 |
+
def unmask(text, temp, rounds):
    """Rewrite ``text`` by repeatedly masking one word and re-filling it.

    Each round calls ``add_mask`` to obtain a text containing a mask token,
    asks ``scrambler`` for the top 10 candidate fills, samples one of them
    and writes it over the whitespace-delimited token that contains the mask.

    Args:
        text: Input string; words are whitespace-delimited.
        temp: Value forwarded to ``scrambler`` as its ``temp`` keyword
            (presumably a softmax temperature -- confirm against the
            "temp-scale" pipeline definition).
        rounds: Number of mask-and-fill edits to perform. Coerced to ``int``
            so a float coming from a UI slider does not break ``range``.

    Returns:
        The edited text. Returned unchanged when ``rounds`` is 0 or when the
        text contains no words.
    """
    # Switch between score-weighted ('multi') and uniform candidate sampling.
    sampling = 'multi'

    for _ in range(int(rounds)):
        # Nothing to mask in an empty / whitespace-only string.
        if not text.split():
            break

        text = add_mask(text, size=1)
        split_text = text.split()
        res = scrambler(text, temp=temp, top_k=10)

        # First whitespace token carrying the mask placeholder.
        mask_pos = next(
            (i for i, tok in enumerate(split_text) if 'MASK' in tok), None)
        if mask_pos is None or not res:
            break

        # Keep tokens and scores as parallel sequences; keying a dict by the
        # float score would merge candidates that happen to share a score.
        tokens = [item["token_str"] for item in res]
        scores = np.asarray([item["score"] for item in res], dtype=float)

        if sampling == 'multi':
            # Top-k scores are only a slice of the distribution, so they do
            # not sum to 1. np.random.multinomial silently hands any missing
            # mass to the LAST entry (and raises if the leading entries sum
            # above 1), so renormalise before sampling.
            total = scores.sum()
            if np.isfinite(total) and total > 0:
                probs = scores / total
            else:
                probs = np.full(len(tokens), 1.0 / len(tokens))
            idx = int(np.argmax(np.random.multinomial(1, probs)))
        else:
            idx = np.random.randint(0, len(tokens))

        split_text[mask_pos] = tokens[idx]
        text = ' '.join(split_text)
    return text
|
142 |
|
143 |
# --- Gradio UI wiring -------------------------------------------------------
# Input widgets, in the order their values are passed to `unmask`.
textbox = gr.Textbox(lines=5, label="Type language here")
temp_slider = gr.Slider(minimum=1.0, maximum=2.0, value=1.0, label='Creativity')
edit_slider = gr.Slider(
    minimum=1,
    maximum=50,
    step=1,
    value=1.0,
    label='Number of edits',
)

# Output widget that displays the string returned by `unmask`.
textbox2 = gr.Textbox(placeholder="", lines=4)

demo = gr.Interface(
    fn=unmask,
    inputs=[textbox, temp_slider, edit_slider],
    outputs=[textbox2],
    examples=examples,
)

demo.launch()
|