danseith commited on
Commit
c16370c
·
1 Parent(s): 2ce1788

Added edit slider and changed sampling back to multinomial.

Browse files
Files changed (1) hide show
  1. app.py +34 -24
app.py CHANGED
@@ -6,12 +6,16 @@ from transformers.pipelines import PIPELINE_REGISTRY, FillMaskPipeline
6
  from transformers import AutoModelForMaskedLM
7
 
8
  # unmasker = pipeline("temp-scale", model="anferico/bert-for-patents")
9
- example = 'A crustless [MASK] made from two slices of baked bread.'
10
- example = 'The invention provides a method for altering or modifying [MASK] of one or more gene products.'
11
- example = 'The graphite [MASK] is composed of a two-dimensional hexagonal lattice of carbon atoms.'
12
 
13
  def add_mask(text, size=1):
14
  split_text = text.split()
 
 
 
 
15
  idx = np.random.randint(len(split_text), size=size)
16
  for i in idx:
17
  split_text[i] = '[MASK]'
@@ -114,32 +118,38 @@ PIPELINE_REGISTRY.register_pipeline(
114
  )
115
  scrambler = pipeline("temp-scale", model="anferico/bert-for-patents")
116
 
117
- def unmask(text, temp, sampling='uniform'):
118
- # text = add_mask(text)
119
- split_text = text.split()
120
- res = scrambler(text, temp=temp, top_k=3)
121
- mask_pos = [i for i, t in enumerate(split_text) if 'MASK' in t][0]
122
- out = {item["token_str"]: item["score"] for item in res}
123
- score_to_str = {out[k]:k for k in out.keys()}
124
- score_list = list(score_to_str.keys())
125
- if sampling == 'multi':
126
- idx = np.argmax(np.random.multinomial(1, score_list, 1))
127
- else:
128
- idx = np.random.randint(0, len(score_list))
129
- score = score_list[idx]
130
- new_token = score_to_str[score]
131
- split_text[mask_pos] = new_token
132
- return out, ' '.join(split_text)
 
 
 
 
 
133
 
134
  textbox = gr.Textbox(label="Type language here", lines=5)
135
- textbox2 = gr.Textbox(placeholder="Type here...", lines=4)
136
- temp_slider = gr.Slider(1.0, 2.0, value=1.0, label='Temperature Scale')
 
137
 
138
  demo = gr.Interface(
139
  fn=unmask,
140
- inputs=[textbox, temp_slider],
141
- outputs=["label", textbox2],
142
- examples=[[example, 1.2]],
143
  )
144
 
145
  demo.launch()
 
6
  from transformers import AutoModelForMaskedLM
7
 
8
  # unmasker = pipeline("temp-scale", model="anferico/bert-for-patents")
9
+ examples = [['A crustless [MASK] made from two slices of baked bread.', 1.2],
10
+ ['The invention provides a method for altering or modifying [MASK] of one or more gene products.', 1.1],
11
+ ['The graphite [MASK] is composed of a two-dimensional hexagonal lattice of carbon atoms.', 1.4]]
12
 
13
  def add_mask(text, size=1):
14
  split_text = text.split()
15
+
16
+ # If the user supplies a mask, don't add more
17
+ if '[MASK]' in split_text:
18
+ return text
19
  idx = np.random.randint(len(split_text), size=size)
20
  for i in idx:
21
  split_text[i] = '[MASK]'
 
118
  )
119
  scrambler = pipeline("temp-scale", model="anferico/bert-for-patents")
120
 
121
+
122
+ def unmask(text, temp, rounds):
123
+ sampling = 'multi'
124
+
125
+ for _ in range(rounds):
126
+ text = add_mask(text, size=1)
127
+ split_text = text.split()
128
+ res = scrambler(text, temp=temp, top_k=10)
129
+ mask_pos = [i for i, t in enumerate(split_text) if 'MASK' in t][0]
130
+ out = {item["token_str"]: item["score"] for item in res}
131
+ score_to_str = {out[k]:k for k in out.keys()}
132
+ score_list = list(score_to_str.keys())
133
+ if sampling == 'multi':
134
+ idx = np.argmax(np.random.multinomial(1, score_list, 1))
135
+ else:
136
+ idx = np.random.randint(0, len(score_list))
137
+ score = score_list[idx]
138
+ new_token = score_to_str[score]
139
+ split_text[mask_pos] = new_token
140
+ text = ' '.join(split_text)
141
+ return text
142
 
143
  textbox = gr.Textbox(label="Type language here", lines=5)
144
+ textbox2 = gr.Textbox(placeholder="", lines=4)
145
+ temp_slider = gr.Slider(1.0, 2.0, value=1.0, label='Creativity')
146
+ edit_slider = gr.Slider(1, 50, step=1, value=1.0, label='Number of edits')
147
 
148
  demo = gr.Interface(
149
  fn=unmask,
150
+ inputs=[textbox, temp_slider, edit_slider],
151
+ outputs=[textbox2],
152
+ examples=examples,
153
  )
154
 
155
  demo.launch()