natolambert committed on
Commit
06fd8bd
·
1 Parent(s): faa2dab
Files changed (2) hide show
  1. app.py +43 -35
  2. src/utils.py +17 -13
app.py CHANGED
@@ -130,7 +130,7 @@ def random_sample(r: gr.Request, subset):
130
 
131
  subsets = eval_set.unique("subset")
132
 
133
- def regex_table(dataframe, regex):
134
  """
135
  Takes a model name as a regex, then returns only the rows that has that in it.
136
  """
@@ -138,6 +138,9 @@ def regex_table(dataframe, regex):
138
  regex_list = [x.strip() for x in regex.split(",")]
139
  # Join the list into a single regex pattern with '|' acting as OR
140
  combined_regex = '|'.join(regex_list)
 
 
 
141
  # Filter the dataframe such that 'model' contains any of the regex patterns
142
  return dataframe[dataframe["model"].str.contains(combined_regex, case=False, na=False)]
143
 
@@ -145,50 +148,47 @@ def regex_table(dataframe, regex):
145
  with gr.Blocks() as app:
146
  # create tabs for the app, moving the current table to one titled "HERM" and the benchmark_text to a tab called "About"
147
  with gr.Row():
148
- gr.Markdown(TOP_TEXT)
149
- search = gr.Textbox(label="Model Search (delimit with , )", placeholder="Regex search for a model")
 
 
 
150
  with gr.Tabs(elem_classes="tab-buttons") as tabs:
151
  with gr.TabItem("HERM Eval Set - Overview"):
152
  with gr.Row():
153
- herm_table = gr.Dataframe(
 
154
  herm_data_avg.values,
155
  datatype=col_types_herm_avg,
156
  headers=herm_data_avg.columns.tolist(),
157
- elem_id="herm_dataframe_avg",
158
- height=1000,
159
  )
160
- # backup reference data
161
- herm_table_hidden = gr.Dataframe(
162
- herm_data_avg.values,
163
  datatype=col_types_herm_avg,
164
  headers=herm_data_avg.columns.tolist(),
165
- visible=False,
 
166
  )
 
167
  with gr.TabItem("HERM Eval Set - Detailed"):
168
  with gr.Row():
169
- herm_table_detailed = gr.Dataframe(
 
170
  herm_data.values,
171
  datatype=col_types_herm,
172
  headers=herm_data.columns.tolist(),
173
- elem_id="herm_dataframe",
174
- height=1000,
175
  )
176
- # backup
177
- herm_table_detailed_hidden = gr.Dataframe(
178
- herm_data.values,
179
  datatype=col_types_herm,
180
  headers=herm_data.columns.tolist(),
181
- visible=False,
 
182
  )
183
  with gr.TabItem("HERM Eval Set - Length Bias"):
184
  with gr.Row():
185
- herm_table_len = gr.Dataframe(
186
- herm_data_length.values,
187
- datatype=cols_herm_data_length,
188
- headers=herm_data_length.columns.tolist(),
189
- elem_id="herm_dataframe_length",
190
- height=1000,
191
- )
192
  # backup
193
  herm_table_len_hidden = gr.Dataframe(
194
  herm_data_length.values,
@@ -196,6 +196,13 @@ with gr.Blocks() as app:
196
  headers=herm_data_length.columns.tolist(),
197
  visible=False,
198
  )
 
 
 
 
 
 
 
199
  with gr.TabItem("Known Pref. Sets"):
200
  with gr.Row():
201
  PREF_SET_TEXT = """
@@ -203,13 +210,6 @@ with gr.Blocks() as app:
203
  """
204
  gr.Markdown(PREF_SET_TEXT)
205
  with gr.Row():
206
- pref_sets_table = gr.Dataframe(
207
- prefs_data.values,
208
- datatype=col_types_prefs,
209
- headers=prefs_data.columns.tolist(),
210
- elem_id="prefs_dataframe",
211
- height=1000,
212
- )
213
  # backup
214
  pref_sets_table_hidden = gr.Dataframe(
215
  prefs_data.values,
@@ -217,6 +217,14 @@ with gr.Blocks() as app:
217
  headers=prefs_data.columns.tolist(),
218
  visible=False,
219
  )
 
 
 
 
 
 
 
 
220
 
221
  with gr.TabItem("About"):
222
  with gr.Row():
@@ -239,10 +247,10 @@ with gr.Blocks() as app:
239
  # plot = plot_avg_correlation(herm_data_avg, prefs_data)
240
  # gr.Plot(plot)
241
 
242
- search.change(regex_table, inputs=[herm_table_hidden, search], outputs=herm_table)
243
- search.change(regex_table, inputs=[herm_table_detailed_hidden, search], outputs=herm_table_detailed)
244
- search.change(regex_table, inputs=[herm_table_len_hidden, search], outputs=herm_table_len)
245
- search.change(regex_table, inputs=[pref_sets_table_hidden, search], outputs=pref_sets_table)
246
 
247
  # Load data when app starts, TODO make this used somewhere...
248
  # def load_data_on_start():
 
130
 
131
  subsets = eval_set.unique("subset")
132
 
133
def regex_table(dataframe, regex, filter_button):
    """
    Filter a leaderboard table by model name.

    Args:
        dataframe: pandas DataFrame that has a "model" column.
        regex: comma-delimited search string typed by the user. Each piece is
            treated as a regular expression and the pieces are OR-ed together;
            matching is case-insensitive. Empty pieces (e.g. from a trailing
            comma) are ignored, and an empty search keeps every row.
        filter_button: state of the "Include AI2 training runs" checkbox.
            When falsy, rows whose model name contains "ai2" are hidden unless
            the search string itself mentions "ai2" (in any letter case).

    Returns:
        The rows of ``dataframe`` whose "model" value matches the search.
    """
    import re

    search_text = regex or ""

    # Split on commas and drop empty pieces: an empty alternative in the joined
    # pattern ("llama|") would match every row and defeat the search.
    regex_list = [piece.strip() for piece in search_text.split(",")]
    regex_list = [piece for piece in regex_list if piece]

    # Join the list into a single regex pattern with '|' acting as OR
    combined_regex = "|".join(regex_list)
    try:
        re.compile(combined_regex)
    except re.error:
        # The user is typing free text, so half-finished patterns such as "("
        # are expected; fall back to a literal match instead of raising.
        combined_regex = "|".join(re.escape(piece) for piece in regex_list)

    # Hide AI2 training runs unless the checkbox is ticked or the user searched
    # for "ai2". The check is case-insensitive to agree with the row filter.
    if (not filter_button) and ("ai2" not in search_text.lower()):
        dataframe = dataframe[~dataframe["model"].str.contains("ai2", case=False, na=False)]

    # Filter the dataframe such that 'model' contains any of the regex patterns
    return dataframe[dataframe["model"].str.contains(combined_regex, case=False, na=False)]
146
 
 
148
  with gr.Blocks() as app:
149
  # create tabs for the app, moving the current table to one titled "HERM" and the benchmark_text to a tab called "About"
150
  with gr.Row():
151
+ with gr.Column(scale=3):
152
+ gr.Markdown(TOP_TEXT)
153
+ with gr.Column(scale=2):
154
+ search = gr.Textbox(label="Model Search (delimit with , )", placeholder="Regex search for a model")
155
+ filter_button = gr.Checkbox(label="Include AI2 training runs (or type ai2 above).", interactive=True)
156
  with gr.Tabs(elem_classes="tab-buttons") as tabs:
157
  with gr.TabItem("HERM Eval Set - Overview"):
158
  with gr.Row():
159
+ # reference data
160
+ herm_table_hidden = gr.Dataframe(
161
  herm_data_avg.values,
162
  datatype=col_types_herm_avg,
163
  headers=herm_data_avg.columns.tolist(),
164
+ visible=False,
 
165
  )
166
+ herm_table = gr.Dataframe(
167
+ regex_table(herm_data_avg.copy(), "", False).values,
 
168
  datatype=col_types_herm_avg,
169
  headers=herm_data_avg.columns.tolist(),
170
+ elem_id="herm_dataframe_avg",
171
+ height=1000,
172
  )
173
+
174
  with gr.TabItem("HERM Eval Set - Detailed"):
175
  with gr.Row():
176
+ # ref data
177
+ herm_table_detailed_hidden = gr.Dataframe(
178
  herm_data.values,
179
  datatype=col_types_herm,
180
  headers=herm_data.columns.tolist(),
181
+ visible=False,
 
182
  )
183
+ herm_table_detailed = gr.Dataframe(
184
+ regex_table(herm_data.copy(), "", False).values,
 
185
  datatype=col_types_herm,
186
  headers=herm_data.columns.tolist(),
187
+ elem_id="herm_dataframe",
188
+ height=1000,
189
  )
190
  with gr.TabItem("HERM Eval Set - Length Bias"):
191
  with gr.Row():
 
 
 
 
 
 
 
192
  # backup
193
  herm_table_len_hidden = gr.Dataframe(
194
  herm_data_length.values,
 
196
  headers=herm_data_length.columns.tolist(),
197
  visible=False,
198
  )
199
+ herm_table_len = gr.Dataframe(
200
+ regex_table(herm_data_length.copy(), "", False).values,
201
+ datatype=cols_herm_data_length,
202
+ headers=herm_data_length.columns.tolist(),
203
+ elem_id="herm_dataframe_length",
204
+ height=1000,
205
+ )
206
  with gr.TabItem("Known Pref. Sets"):
207
  with gr.Row():
208
  PREF_SET_TEXT = """
 
210
  """
211
  gr.Markdown(PREF_SET_TEXT)
212
  with gr.Row():
 
 
 
 
 
 
 
213
  # backup
214
  pref_sets_table_hidden = gr.Dataframe(
215
  prefs_data.values,
 
217
  headers=prefs_data.columns.tolist(),
218
  visible=False,
219
  )
220
+ pref_sets_table = gr.Dataframe(
221
+ regex_table(prefs_data.copy(), "", False).values,
222
+ datatype=col_types_prefs,
223
+ headers=prefs_data.columns.tolist(),
224
+ elem_id="prefs_dataframe",
225
+ height=1000,
226
+ )
227
+
228
 
229
  with gr.TabItem("About"):
230
  with gr.Row():
 
247
  # plot = plot_avg_correlation(herm_data_avg, prefs_data)
248
  # gr.Plot(plot)
249
 
250
+ search.change(regex_table, inputs=[herm_table_hidden, search, filter_button], outputs=herm_table)
251
+ search.change(regex_table, inputs=[herm_table_detailed_hidden, search, filter_button], outputs=herm_table_detailed)
252
+ search.change(regex_table, inputs=[herm_table_len_hidden, search, filter_button], outputs=herm_table_len)
253
+ search.change(regex_table, inputs=[pref_sets_table_hidden, search, filter_button], outputs=pref_sets_table)
254
 
255
  # Load data when app starts, TODO make this used somewhere...
256
  # def load_data_on_start():
src/utils.py CHANGED
@@ -72,6 +72,23 @@ def load_all_data(data_repo, subdir:str, subsubsets=False): # use HF api to p
72
  cols.remove("model_beaker")
73
  df = df.drop(columns=["model_beaker"])
74
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
75
  # round
76
  df[cols] = df[cols].round(2)
77
  avg = np.nanmean(df[cols].values,axis=1).round(2)
@@ -92,17 +109,4 @@ def load_all_data(data_repo, subdir:str, subsubsets=False): # use HF api to p
92
  cols.insert(1, cols.pop(cols.index('model_type')))
93
  df = df.loc[:, cols]
94
 
95
- # remove column xstest (outdated data)
96
- # if xstest is a column
97
- if "xstest" in df.columns:
98
- df = df.drop(columns=["xstest"])
99
-
100
- if "ref_model" in df.columns:
101
- df = df.drop(columns=["ref_model"])
102
-
103
- # remove column anthropic and summarize_prompted (outdated data)
104
- if "anthropic" in df.columns:
105
- df = df.drop(columns=["anthropic"])
106
- if "summarize_prompted" in df.columns:
107
- df = df.drop(columns=["summarize_prompted"])
108
  return df
 
72
  cols.remove("model_beaker")
73
  df = df.drop(columns=["model_beaker"])
74
 
75
+ # remove column xstest (outdated data)
76
+ # if xstest is a column
77
+ if "xstest" in cols:
78
+ df = df.drop(columns=["xstest"])
79
+ cols.remove("xstest")
80
+
81
+ if "ref_model" in df.columns:
82
+ df = df.drop(columns=["ref_model"])
83
+
84
+ # remove column anthropic and summarize_prompted (outdated data)
85
+ if "anthropic" in cols:
86
+ df = df.drop(columns=["anthropic"])
87
+ cols.remove("anthropic")
88
+ if "summarize_prompted" in cols:
89
+ df = df.drop(columns=["summarize_prompted"])
90
+ cols.remove("summarize_prompted")
91
+
92
  # round
93
  df[cols] = df[cols].round(2)
94
  avg = np.nanmean(df[cols].values,axis=1).round(2)
 
109
  cols.insert(1, cols.pop(cols.index('model_type')))
110
  df = df.loc[:, cols]
111
 
 
 
 
 
 
 
 
 
 
 
 
 
 
112
  return df