href / app.py
Shane
made changes to UI
ca662db
raw
history blame
8.97 kB
import gradio as gr
import os
from huggingface_hub import HfApi, snapshot_download
from apscheduler.schedulers.background import BackgroundScheduler
from datasets import load_dataset
from src.utils import load_all_data, prep_df, sort_by_category
from src.md import ABOUT_TEXT, TOP_TEXT
from src.css import custom_css
import numpy as np
# Credentials and Hub client used for downloads and for restarting the Space.
COLLAB_TOKEN = os.getenv("COLLAB_TOKEN")
api = HfApi()

# Hub repositories this app reads from, and where the results repo is mirrored.
evals_repo = "alrope/href_results"
eval_set_repo = "allenai/href_validation"
local_result_dir = "./results/"
def restart_space():
    """Ask the Hub to restart the ``allenai/href`` Space.

    Restarting re-executes this script from the top, including the
    module-level ``snapshot_download`` of the results repository.
    """
    space_id = "allenai/href"
    api.restart_space(token=COLLAB_TOKEN, repo_id=space_id)
print("Pulling evaluation results")
# Mirror the results dataset repository into local_result_dir.
# `token=` is used instead of the deprecated `use_auth_token=` keyword, which
# newer `datasets` releases no longer accept.
repo = snapshot_download(
    local_dir=local_result_dir,
    ignore_patterns=[],
    repo_id=evals_repo,
    token=COLLAB_TOKEN,
    tqdm_class=None,
    etag_timeout=30,
    repo_type="dataset",
)

# One leaderboard table per decoding setting, read from the per-temperature
# subdirectories of the downloaded results.
href_data_greedy = prep_df(load_all_data(local_result_dir, subdir="temperature=0.0"))
href_data_nongreedy = prep_df(load_all_data(local_result_dir, subdir="temperature=1.0"))

# Gradio datatypes for the visible and hidden tables: a number, a markdown
# column, then numbers.
# NOTE(review): the visible list has 2 + (n_cols - 1) // 2 entries and the
# hidden list has n_cols + 1, neither equals the column count; confirm intended.
col_types_href = ["number"] + ["markdown"] + ["number"] * ((len(href_data_greedy.columns) - 1) // 2)
col_types_href_hidden = ["number"] + ["markdown"] + ["number"] * (len(href_data_greedy.columns) - 1)

# Choices for the "Sorted By" dropdowns.
categories = [
    'Average', 'Brainstorm', 'Open QA', 'Closed QA', 'Extract', 'Generation', 'Rewrite', 'Summarize',
    'Classify', "Reasoning Over Numerical Data", "Multi-Document Synthesis", "Fact Checking or Attributed QA",
]

# Validation split backing the "Dataset Viewer" tab's random samples.
eval_set = load_dataset(eval_set_repo, token=COLLAB_TOKEN, split="dev")
def random_sample(r: gr.Request, category):
    """Return one random example from the eval set, rendered as Markdown.

    Args:
        r: Gradio request object (unused).
        category: ``None`` or an empty selection samples from the whole eval
            set; otherwise a category name, or a list of names, restricting
            the pool to examples whose ``category`` field is one of them.

    Returns:
        A Markdown string with a bold ``**field**:`` heading followed by the
        value for every field of the sampled example, or a short notice when
        no example matches the selected categories.
    """
    if not category:
        pool = eval_set
    else:
        if isinstance(category, str):
            category = [category]
        wanted = set(category)
        pool = eval_set.filter(lambda x: x["category"] in wanted)

    if len(pool) == 0:
        return "No samples found for the selected category."

    # randint's upper bound is exclusive: len(pool) keeps the last row
    # reachable and makes single-row pools work.
    sample = pool[int(np.random.randint(0, len(pool)))]
    return '\n\n'.join(f"**{key}**:\n\n{value}" for key, value in sample.items())
# Distinct category labels of the eval set, offered in the Dataset Viewer dropdown.
subsets = eval_set.unique(column="category")
def regex_table(dataframe, regex, selected_category, style=True):
    """Keep only the leaderboard rows whose model name matches a search query.

    Args:
        dataframe: leaderboard table containing a ``Model`` column.
        regex: comma-separated regular expressions. A row is kept when its
            ``Model`` value matches any of them (case-insensitive). Blank
            entries are ignored, so an empty query keeps every row with a
            model name. If the combined query is not a valid regular
            expression, the entries are matched as literal substrings.
        selected_category: column to sort by, passed to ``sort_by_category``.
        style: when true, return a pandas ``Styler`` with number formatting;
            otherwise return the plain filtered ``DataFrame``.

    Returns:
        The sorted, filtered (and optionally styled) table with a fresh
        0-based index.
    """
    import re

    dataframe = sort_by_category(dataframe, selected_category)

    # Split on commas and drop blank entries: a trailing comma ("gpt,") would
    # otherwise add an empty alternative that matches every model.
    patterns = [p.strip() for p in (regex or "").split(",") if p.strip()]
    combined_regex = "|".join(patterns)
    try:
        re.compile(combined_regex)
    except re.error:
        # Half-typed input such as "llama(" is not a valid regex; match it
        # literally instead of raising inside the UI callback.
        combined_regex = "|".join(re.escape(p) for p in patterns)

    # na=False treats missing model names as non-matches.
    mask = dataframe["Model"].str.contains(combined_regex, case=False, na=False)
    data = dataframe[mask].reset_index(drop=True)

    if style:
        # One decimal for per-category columns, two for the overall average;
        # the listed identifier columns keep their default rendering.
        format_dict = {col: "{:.1f}" for col in data.columns if col not in ['Average', 'Model', 'Rank', '95% CI']}
        format_dict['Average'] = "{:.2f}"
        data = data.style.format(format_dict, na_rep='').set_properties(**{'text-align': 'right'})
    return data
# Model count shown in the header: rows returned by an empty search of the greedy table.
total_models = regex_table(href_data_greedy.copy(), "", "Average", style=False).shape[0]
def _leaderboard_tab(data, table_elem_id, selector_elem_id=None):
    """Create a leaderboard tab body in the current Gradio layout context.

    Builds a row holding the model-search box and the "Sorted By" dropdown,
    then a row holding a hidden copy of the full table and the visible,
    formatted table.

    Returns:
        ``(search_box, sort_selector, full_table, shown_table)``.
    """
    with gr.Row():
        search_box = gr.Textbox(label="Model Search (delimit with , )",
                                show_label=True)
        sort_selector = gr.Dropdown(categories, label="Sorted By", value="Average",
                                    multiselect=False, show_label=True,
                                    elem_id=selector_elem_id)
    with gr.Row():
        # Hidden copy of the unfiltered data; the search/sort callbacks read
        # from it so each query starts from the complete table.
        full_table = gr.Dataframe(
            data.values,
            datatype=col_types_href_hidden,
            headers=data.columns.tolist(),
            visible=False,
        )
        shown_table = gr.Dataframe(
            regex_table(data.copy(), "", "Average"),
            datatype=col_types_href,
            headers=data.columns.tolist(),
            elem_id=table_elem_id,
            interactive=False,
            height=1000,
        )
    return search_box, sort_selector, full_table, shown_table


def _bind_search_and_sort(full_table, search_box, sort_selector, shown_table):
    """Re-filter ``shown_table`` whenever the search text or sort column changes."""
    for trigger in (search_box, sort_selector):
        trigger.change(regex_table,
                       inputs=[full_table, search_box, sort_selector],
                       outputs=shown_table)


with gr.Blocks(css=custom_css) as app:
    # Header: summary text on the left, logo on the right.
    with gr.Row():
        with gr.Column(scale=8):
            gr.Markdown(TOP_TEXT.format(str(total_models)))
        with gr.Column(scale=2):
            gr.Markdown("""
<img src="file/src/logo.png" height="130">
""")

    with gr.Tabs(elem_classes="tab-buttons") as tabs:
        with gr.TabItem("🏆 HREF Leaderboard"):
            (search_1, category_selector_1,
             rewardbench_table_hidden, rewardbench_table) = _leaderboard_tab(
                href_data_greedy, "href_data_greedy")

        with gr.TabItem("Non-Greedy"):
            (search_2, category_selector_2,
             rewardbench_table_hidden_nongreedy, rewardbench_table_nongreedy) = _leaderboard_tab(
                href_data_nongreedy, "href_data_nongreedy",
                selector_elem_id="category_selector")

        with gr.TabItem("About"):
            with gr.Row():
                gr.Markdown(ABOUT_TEXT)

        with gr.TabItem("Dataset Viewer"):
            with gr.Row():
                # Category picker plus a button that draws one random example.
                gr.Markdown("""## Random Dataset Sample Viewer""")
                subset_selector = gr.Dropdown(subsets, label="Category", value=None, multiselect=True)
                button = gr.Button("Show Random Sample")
            with gr.Row():
                sample_display = gr.Markdown("{sampled data loads here}")

    # Event wiring (registered in the same order as before).
    button.click(fn=random_sample, inputs=[subset_selector], outputs=[sample_display])
    _bind_search_and_sort(rewardbench_table_hidden, search_1,
                          category_selector_1, rewardbench_table)
    _bind_search_and_sort(rewardbench_table_hidden_nongreedy, search_2,
                          category_selector_2, rewardbench_table_nongreedy)

    with gr.Row():
        with gr.Accordion("📚 Citation", open=False):
            # NOTE(review): this BibTeX entry is RewardBench's, not HREF's, and
            # its howpublished field is missing a closing brace; replace it
            # with the HREF citation.
            citation_button = gr.Textbox(
                value=r"""@misc{RewardBench,
title={RewardBench: Evaluating Reward Models for Language Modeling},
author={Lambert, Nathan and Pyatkin, Valentina and Morrison, Jacob and Miranda, LJ and Lin, Bill Yuchen and Chandu, Khyathi and Dziri, Nouha and Kumar, Sachin and Zick, Tom and Choi, Yejin and Smith, Noah A. and Hajishirzi, Hannaneh},
year={2024},
howpublished={\url{https://huggingface.co/spaces/allenai/reward-bench}
}""",
                lines=7,
                label="Copy the following to cite these results.",
                elem_id="citation-button",
                show_copy_button=True,
            )
# Restart the Space on a fixed 3-hour interval; each restart re-runs this
# script, presumably so freshly uploaded results get picked up.
scheduler = BackgroundScheduler()
scheduler.add_job(restart_space, "interval", hours=3)
scheduler.start()

# allowed_paths lets Gradio serve files under src/ (e.g. the header logo).
app.launch(allowed_paths=['src/'])