href / app.py
Shane
updated citations
c96dbc6
raw
history blame
8.95 kB
import os
import re

import gradio as gr
import numpy as np
from apscheduler.schedulers.background import BackgroundScheduler
from datasets import load_dataset
from huggingface_hub import HfApi, snapshot_download

from src.css import custom_css
from src.md import ABOUT_TEXT, TOP_TEXT
from src.utils import load_all_data, prep_df, sort_by_category
# Hub credentials and client. The token comes from the Space environment
# (None when unset) and is passed explicitly to each Hub call below.
COLLAB_TOKEN = os.getenv("COLLAB_TOKEN")
api = HfApi()

# Hub dataset repos this app reads, and the local mirror of the results repo.
evals_repo = "alrope/href_results"  # per-model evaluation results
eval_set_repo = "allenai/href_validation"  # evaluation prompts shown in the viewer
local_result_dir = "./results/"
def restart_space():
    """Restart the ``allenai/href`` Space through the Hugging Face Hub API.

    Called periodically by the background scheduler at the bottom of this
    script; presumably a restart re-runs start-up, which re-downloads the
    latest evaluation results.
    """
    space_repo_id = "allenai/href"
    api.restart_space(repo_id=space_repo_id, token=COLLAB_TOKEN)
# Mirror the evaluation-results dataset into local_result_dir at start-up;
# the leaderboard tables below are built from these local files.
print("Pulling evaluation results")
repo = snapshot_download(
    local_dir=local_result_dir,
    ignore_patterns=[],
    repo_id=evals_repo,
    # `token` is the current name of the deprecated `use_auth_token` argument.
    token=COLLAB_TOKEN,
    tqdm_class=None,
    etag_timeout=30,
    repo_type="dataset",
)
# Leaderboard tables for the two decoding temperatures, read from separate
# sub-directories of the local results mirror (greedy first, then sampled).
href_data_greedy, href_data_nongreedy = (
    prep_df(load_all_data(local_result_dir, subdir=temperature_dir))
    for temperature_dir in ("temperature=0.0", "temperature=1.0")
)

# Gradio column datatypes: a leading "number" column, a "markdown" column
# (presumably Rank and Model -- confirm against prep_df), then numeric columns.
# NOTE(review): for n columns the visible list has 2 + (n - 1) // 2 entries and
# the hidden one has n + 1 entries; confirm these lengths are intended.
col_types_href = ["number", "markdown"] + ["number"] * ((len(href_data_greedy.columns) - 1) // 2)
col_types_href_hidden = ["number", "markdown"] + ["number"] * (len(href_data_greedy.columns) - 1)

# Choices for the "Sorted By" dropdown on the leaderboard tab.
categories = [
    'Average', 'Brainstorm', 'Open QA', 'Closed QA', 'Extract', 'Generation',
    'Rewrite', 'Summarize', 'Classify', "Reasoning Over Numerical Data",
    "Multi-Document Synthesis", "Fact Checking or Attributed QA",
]
# Earlier, shorter category list (first nine entries only):
# categories = ['Average', 'Brainstorm', 'Open QA', 'Closed QA', 'Extract', 'Generation', 'Rewrite', 'Summarize', 'Classify']
# Dev split of the HREF validation set; random_sample draws the examples shown
# in the "Dataset Viewer" tab from it. Authenticate with `token`: the older
# `use_auth_token` keyword was deprecated in `datasets` 2.14 and later removed.
eval_set = load_dataset(eval_set_repo, token=COLLAB_TOKEN, split="dev")
def random_sample(r: gr.Request, category):
    """Render one randomly chosen dev-split example as Markdown.

    Args:
        r: Gradio request object. Unused here; the click handler only passes
            the category input, so this is presumably injected by Gradio via
            the ``gr.Request`` annotation.
        category: ``None`` or an empty selection to sample from the whole dev
            split; otherwise a category name, or a list of category names, that
            the example's ``category`` field must match.

    Returns:
        Markdown containing a bold ``**field**:`` heading followed by the value
        for every field of the sampled example, or a short notice when no
        example belongs to the selected categories.
    """
    if not category:
        candidates = eval_set
    else:
        # The dropdown is multiselect, but also accept a single category name.
        wanted = [category] if isinstance(category, str) else list(category)
        candidates = eval_set.filter(lambda x: x["category"] in wanted)

    if len(candidates) == 0:
        return "No samples found for the selected category."

    # np.random.randint excludes its upper bound: passing len(...) rather than
    # len(...) - 1 makes the last example reachable and stops single-example
    # selections from raising ValueError.
    sample = candidates[int(np.random.randint(0, len(candidates)))]
    return '\n\n'.join(f"**{key}**:\n\n{value}" for key, value in sample.items())
# Distinct category labels of the dev split; these are the choices offered by
# the "Category" dropdown in the Dataset Viewer tab.
subsets = eval_set.unique(column="category")
def regex_table(dataframe, regex, selected_category, style=True):
    """Keep only leaderboard rows whose model matches a search, ordered by a category.

    Args:
        dataframe: Leaderboard table containing a ``Model`` column.
        regex: Comma-separated search terms. Each term is a case-insensitive
            regular expression, and a row is kept when its ``Model`` value
            matches any term. Blank terms are ignored, and an empty search
            keeps every row. If the terms do not form a valid regular
            expression (e.g. a half-typed ``(``), they are all matched as
            literal text instead of raising.
        selected_category: Category passed to ``sort_by_category`` to order
            the table.
        style: When True, return a pandas ``Styler`` with numeric formatting
            for display; otherwise return the filtered ``DataFrame``.

    Returns:
        The filtered table with a fresh 0..n-1 index, styled or plain.
    """
    dataframe = sort_by_category(dataframe, selected_category)

    # Split on commas and drop blank pieces: a trailing comma ("llama,") would
    # otherwise add an empty alternative that matches every model.
    terms = [term.strip() for term in (regex or "").split(",")]
    terms = [term for term in terms if term]
    combined_regex = "|".join(terms)  # "" matches every row
    try:
        re.compile(combined_regex)
    except re.error:
        # Incomplete/invalid pattern (typically mid-typing): match literally
        # rather than surfacing a regex error in the UI on every keystroke.
        combined_regex = "|".join(re.escape(term) for term in terms)

    mask = dataframe["Model"].str.contains(combined_regex, case=False, na=False)
    # Rebind instead of resetting the index in place on a filtered slice,
    # which pandas may flag with SettingWithCopyWarning.
    data = dataframe[mask].reset_index(drop=True)

    if style:
        # One decimal for per-category scores, two for the overall average.
        format_dict = {col: "{:.1f}" for col in data.columns if col not in ['Average', 'Model', 'Rank', '95% CI']}
        format_dict['Average'] = "{:.2f}"
        data = data.style.format(format_dict, na_rep='').set_properties(**{'text-align': 'right'})
    return data
# Number of models on the greedy leaderboard (empty search, sorted by Average);
# interpolated into the page header text.
total_models = regex_table(href_data_greedy.copy(), "", "Average", style=False).shape[0]
with gr.Blocks(css=custom_css) as app:
    # Page layout: a header row, three tabs (leaderboard, about, dataset viewer),
    # the event wiring for search/sort, and a collapsible citation box.
    with gr.Row():
        # Header: intro text (with the model count) on the left, logo on the right.
        with gr.Column(scale=8):
            gr.Markdown(TOP_TEXT.format(str(total_models)))
        with gr.Column(scale=2):
            # search = gr.Textbox(label="Model Search (delimit with , )", placeholder="Regex search for a model")
            # filter_button = gr.Checkbox(label="Include AI2 training runs (or type ai2 above).", interactive=True)
            # img = gr.Image(value="https://private-user-images.githubusercontent.com/10695622/310698241-24ed272a-0844-451f-b414-fde57478703e.png", width=500)
            # Logo file under src/; presumably servable because app.launch()
            # passes allowed_paths=['src/'].
            gr.Markdown("""
<img src="file/src/logo.png" height="130">
""")
    with gr.Tabs(elem_classes="tab-buttons") as tabs:
        with gr.TabItem("🏆 HREF Leaderboard"):
            # Search box and sort-category selector that drive the table below.
            with gr.Row():
                search_1 = gr.Textbox(label="Model Search (delimit with , )",
                                      # placeholder="Model Search (delimit with , )",
                                      show_label=True)
                category_selector_1 = gr.Dropdown(categories, label="Sorted By", value="Average", multiselect=False, show_label=True, elem_id="category_selector", elem_classes="category_selector_class")
            with gr.Row():
                # reference data: invisible, unfiltered copy of the greedy results;
                # the search/sort handlers always re-filter from this copy.
                rewardbench_table_hidden = gr.Dataframe(
                    href_data_greedy.values,
                    datatype=col_types_href_hidden,
                    headers=href_data_greedy.columns.tolist(),
                    visible=False,
                )
                # Visible table: initially the full greedy leaderboard sorted by Average.
                rewardbench_table = gr.Dataframe(
                    regex_table(href_data_greedy.copy(), "", "Average"),
                    datatype=col_types_href,
                    headers=href_data_greedy.columns.tolist(),
                    elem_id="href_data_greedy",
                    interactive=False,
                    height=1000,
                )
        # Non-greedy (temperature=1.0) tab, currently disabled; href_data_nongreedy
        # is still loaded above.
        # with gr.TabItem("Non-Greedy"):
        #     with gr.Row():
        #         search_2 = gr.Textbox(label="Model Search (delimit with , )",
        #                               # placeholder="Model Search (delimit with , )",
        #                               show_label=True)
        #         category_selector_2 = gr.Dropdown(categories, label="Sorted By", value="Average",
        #                                           multiselect=False, show_label=True, elem_id="category_selector")
        #     with gr.Row():
        #         # reference data
        #         rewardbench_table_hidden_nongreedy = gr.Dataframe(
        #             href_data_nongreedy.values,
        #             datatype=col_types_href_hidden,
        #             headers=href_data_nongreedy.columns.tolist(),
        #             visible=False,
        #         )
        #         rewardbench_table_nongreedy = gr.Dataframe(
        #             regex_table(href_data_nongreedy.copy(), "", "Average"),
        #             datatype=col_types_href,
        #             headers=href_data_nongreedy.columns.tolist(),
        #             elem_id="href_data_nongreedy",
        #             interactive=False,
        #             height=1000,
        #         )
        with gr.TabItem("About"):
            with gr.Row():
                gr.Markdown(ABOUT_TEXT)
        with gr.TabItem("Dataset Viewer"):
            with gr.Row():
                # loads one sample
                gr.Markdown("""## Random Dataset Sample Viewer""")
                subset_selector = gr.Dropdown(subsets, label="Category", value=None, multiselect=True)
                button = gr.Button("Show Random Sample")
            with gr.Row():
                sample_display = gr.Markdown("{sampled data loads here}")
            # Show a random example, optionally restricted to the selected categories.
            button.click(fn=random_sample, inputs=[subset_selector], outputs=[sample_display])
    # Rebuild the visible table from the hidden full copy whenever the search
    # text or the sort category changes.
    search_1.change(regex_table, inputs=[rewardbench_table_hidden, search_1, category_selector_1], outputs=rewardbench_table)
    category_selector_1.change(regex_table, inputs=[rewardbench_table_hidden, search_1, category_selector_1], outputs=rewardbench_table)
    # search_2.change(regex_table, inputs=[rewardbench_table_hidden_nongreedy, search_2, category_selector_2], outputs=rewardbench_table_nongreedy)
    # category_selector_2.change(regex_table, inputs=[rewardbench_table_hidden_nongreedy, search_2, category_selector_2], outputs=rewardbench_table_nongreedy)
    with gr.Row():
        # Collapsible BibTeX entry with a copy button.
        with gr.Accordion("📚 Citation", open=False):
            citation_button = gr.Textbox(
                value=r"""@article{lyu2024href,
title={HREF: Human Response-Guided Evaluation of Instruction Following in Language Models},
author={Xinxi Lyu and Yizhong Wang and Hannaneh Hajishirzi and Pradeep Dasigi},
journal={arXiv preprint arXiv:2412.15524},
year={2024}
}""",
                lines=7,
                label="Copy the following to cite these results.",
                elem_id="citation-button",
                show_copy_button=True,
            )
# Restart the Space every three hours (10800 s) in a background thread.
scheduler = BackgroundScheduler()
scheduler.add_job(restart_space, "interval", hours=3)
scheduler.start()

# Serve the UI. allowed_paths exposes src/, presumably for the header logo.
# NOTE(review): `.queue()` used to be chained before launch(); confirm whether
# the installed Gradio version still needs it.
app.launch(allowed_paths=['src/'])