Spaces:
Runtime error
Runtime error
Tristan Thrush
commited on
Commit
Β·
dbe89c5
0
Parent(s):
first commit
Browse files- .env.example +3 -0
- .gitignore +163 -0
- Makefile +10 -0
- README.md +70 -0
- app.py +236 -0
- collect.py +133 -0
- config.py.example +6 -0
- requirements.txt +5 -0
- utils.py +39 -0
.env.example
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
DATASET_REPO_URL="https://huggingface.co/datasets/{DATASET_ID}"
|
2 |
+
FORCE_PUSH="no"
|
3 |
+
HF_TOKEN="hf_xxx"
|
.gitignore
ADDED
@@ -0,0 +1,163 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Byte-compiled / optimized / DLL files
|
2 |
+
__pycache__/
|
3 |
+
*.py[cod]
|
4 |
+
*$py.class
|
5 |
+
|
6 |
+
# C extensions
|
7 |
+
*.so
|
8 |
+
|
9 |
+
# Distribution / packaging
|
10 |
+
.Python
|
11 |
+
build/
|
12 |
+
develop-eggs/
|
13 |
+
dist/
|
14 |
+
downloads/
|
15 |
+
eggs/
|
16 |
+
.eggs/
|
17 |
+
lib/
|
18 |
+
lib64/
|
19 |
+
parts/
|
20 |
+
sdist/
|
21 |
+
var/
|
22 |
+
wheels/
|
23 |
+
share/python-wheels/
|
24 |
+
*.egg-info/
|
25 |
+
.installed.cfg
|
26 |
+
*.egg
|
27 |
+
MANIFEST
|
28 |
+
|
29 |
+
# PyInstaller
|
30 |
+
# Usually these files are written by a python script from a template
|
31 |
+
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
32 |
+
*.manifest
|
33 |
+
*.spec
|
34 |
+
|
35 |
+
# Installer logs
|
36 |
+
pip-log.txt
|
37 |
+
pip-delete-this-directory.txt
|
38 |
+
|
39 |
+
# Unit test / coverage reports
|
40 |
+
htmlcov/
|
41 |
+
.tox/
|
42 |
+
.nox/
|
43 |
+
.coverage
|
44 |
+
.coverage.*
|
45 |
+
.cache
|
46 |
+
nosetests.xml
|
47 |
+
coverage.xml
|
48 |
+
*.cover
|
49 |
+
*.py,cover
|
50 |
+
.hypothesis/
|
51 |
+
.pytest_cache/
|
52 |
+
cover/
|
53 |
+
|
54 |
+
# Translations
|
55 |
+
*.mo
|
56 |
+
*.pot
|
57 |
+
|
58 |
+
# Django stuff:
|
59 |
+
*.log
|
60 |
+
local_settings.py
|
61 |
+
db.sqlite3
|
62 |
+
db.sqlite3-journal
|
63 |
+
|
64 |
+
# Flask stuff:
|
65 |
+
instance/
|
66 |
+
.webassets-cache
|
67 |
+
|
68 |
+
# Scrapy stuff:
|
69 |
+
.scrapy
|
70 |
+
|
71 |
+
# Sphinx documentation
|
72 |
+
docs/_build/
|
73 |
+
|
74 |
+
# PyBuilder
|
75 |
+
.pybuilder/
|
76 |
+
target/
|
77 |
+
|
78 |
+
# Jupyter Notebook
|
79 |
+
.ipynb_checkpoints
|
80 |
+
|
81 |
+
# IPython
|
82 |
+
profile_default/
|
83 |
+
ipython_config.py
|
84 |
+
|
85 |
+
# pyenv
|
86 |
+
# For a library or package, you might want to ignore these files since the code is
|
87 |
+
# intended to run in multiple environments; otherwise, check them in:
|
88 |
+
# .python-version
|
89 |
+
|
90 |
+
# pipenv
|
91 |
+
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
|
92 |
+
# However, in case of collaboration, if having platform-specific dependencies or dependencies
|
93 |
+
# having no cross-platform support, pipenv may install dependencies that don't work, or not
|
94 |
+
# install all needed dependencies.
|
95 |
+
#Pipfile.lock
|
96 |
+
|
97 |
+
# poetry
|
98 |
+
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
|
99 |
+
# This is especially recommended for binary packages to ensure reproducibility, and is more
|
100 |
+
# commonly ignored for libraries.
|
101 |
+
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
|
102 |
+
#poetry.lock
|
103 |
+
|
104 |
+
# pdm
|
105 |
+
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
|
106 |
+
#pdm.lock
|
107 |
+
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
|
108 |
+
# in version control.
|
109 |
+
# https://pdm.fming.dev/#use-with-ide
|
110 |
+
.pdm.toml
|
111 |
+
|
112 |
+
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
|
113 |
+
__pypackages__/
|
114 |
+
|
115 |
+
# Celery stuff
|
116 |
+
celerybeat-schedule
|
117 |
+
celerybeat.pid
|
118 |
+
|
119 |
+
# SageMath parsed files
|
120 |
+
*.sage.py
|
121 |
+
|
122 |
+
# Environments
|
123 |
+
.env
|
124 |
+
.venv
|
125 |
+
env/
|
126 |
+
venv/
|
127 |
+
ENV/
|
128 |
+
env.bak/
|
129 |
+
venv.bak/
|
130 |
+
|
131 |
+
# Spyder project settings
|
132 |
+
.spyderproject
|
133 |
+
.spyproject
|
134 |
+
|
135 |
+
# Rope project settings
|
136 |
+
.ropeproject
|
137 |
+
|
138 |
+
# mkdocs documentation
|
139 |
+
/site
|
140 |
+
|
141 |
+
# mypy
|
142 |
+
.mypy_cache/
|
143 |
+
.dmypy.json
|
144 |
+
dmypy.json
|
145 |
+
|
146 |
+
# Pyre type checker
|
147 |
+
.pyre/
|
148 |
+
|
149 |
+
# pytype static type analyzer
|
150 |
+
.pytype/
|
151 |
+
|
152 |
+
# Cython debug symbols
|
153 |
+
cython_debug/
|
154 |
+
|
155 |
+
# PyCharm
|
156 |
+
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
|
157 |
+
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
|
158 |
+
# and can be added to the global gitignore or merged into this file. For a more nuclear
|
159 |
+
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
|
160 |
+
#.idea/
|
161 |
+
|
162 |
+
# Local development
|
163 |
+
data/
|
Makefile
ADDED
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
.PHONY: style quality
|
2 |
+
|
3 |
+
style:
|
4 |
+
python -m black --line-length 119 --target-version py38 .
|
5 |
+
python -m isort .
|
6 |
+
|
7 |
+
quality:
|
8 |
+
python -m black --check --line-length 119 --target-version py38 .
|
9 |
+
python -m isort --check-only .
|
10 |
+
python -m flake8 --max-line-length 119 .
|
README.md
ADDED
@@ -0,0 +1,70 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
---
|
2 |
+
title: RLHF
|
3 |
+
emoji: π’
|
4 |
+
colorFrom: red
|
5 |
+
colorTo: gray
|
6 |
+
sdk: gradio
|
7 |
+
sdk_version: 3.1
|
8 |
+
app_file: app.py
|
9 |
+
pinned: false
|
10 |
+
---
|
11 |
+
|
12 |
+
An RLHF interface for data collection with [Amazon Mechanical Turk](https://www.mturk.com) and Gradio.
|
13 |
+
|
14 |
+
## Instructions for someone to use for their own project
|
15 |
+
|
16 |
+
### Install dependencies
|
17 |
+
|
18 |
+
First, create a Python virtual environment and install the project's dependencies as follows:
|
19 |
+
|
20 |
+
```bash
|
21 |
+
python -m pip install -r requirements.txt
|
22 |
+
```
|
23 |
+
|
24 |
+
### Setting up the Space
|
25 |
+
|
26 |
+
1. Clone this repo and deploy it on your own Hugging Face space.
|
27 |
+
2. Add the following secrets to your space:
|
28 |
+
- `HF_TOKEN`: One of your Hugging Face tokens.
|
29 |
+
- `DATASET_REPO_URL`: The url to an empty dataset that you created the hub. It
|
30 |
+
can be a private or public dataset.
|
31 |
+
- `FORCE_PUSH`: "yes"
|
32 |
+
When you run this space on mturk and when people visit your space on
|
33 |
+
huggingface.co, the app will use your token to automatically store new HITs
|
34 |
+
in your dataset. Setting `FORCE_PUSH` to "yes" ensures that your repo will
|
35 |
+
force push changes to the dataset during data collection. Otherwise,
|
36 |
+
accidental manual changes to your dataset could result in your space getting
|
37 |
+
merge conflicts as it automatically tries to push the dataset to the hub. For
|
38 |
+
local development, add these three keys to a `.env` file, and consider setting
|
39 |
+
`FORCE_PUSH` to "no".
|
40 |
+
|
41 |
+
To launch the Space locally, run:
|
42 |
+
|
43 |
+
```bash
|
44 |
+
python app.py
|
45 |
+
```
|
46 |
+
|
47 |
+
The app will then be available at a local address, such as http://127.0.0.1:7860
|
48 |
+
|
49 |
+
### Running data collection*
|
50 |
+
|
51 |
+
1. On your local repo that you pulled, create a copy of `config.py.example`,
|
52 |
+
just called `config.py`. Now, put keys from your AWS account in `config.py`.
|
53 |
+
These keys should be for an AWS account that has the
|
54 |
+
AmazonMechanicalTurkFullAccess permission. You also need to
|
55 |
+
create an mturk requestor account associated with your AWS account.
|
56 |
+
2. Run `python collect.py` locally.
|
57 |
+
|
58 |
+
### Profit
|
59 |
+
Now, you should be watching hits come into your Hugging Face dataset
|
60 |
+
automatically!
|
61 |
+
|
62 |
+
### Tips and tricks
|
63 |
+
|
64 |
+
- Use caution while doing local development of your space and
|
65 |
+
simultaneously running it on mturk. Consider setting `FORCE_PUSH` to "no" in
|
66 |
+
your local `.env` file.
|
67 |
+
- huggingface spaces have limited computational resources and memory. If you
|
68 |
+
run too many HITs and/or assignments at once, then you could encounter issues.
|
69 |
+
You could also encounter issues if you are trying to create a dataset that is
|
70 |
+
very large. Check the log of your space for any errors that could be happening.
|
app.py
ADDED
@@ -0,0 +1,236 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Basic example for doing model-in-the-loop dynamic adversarial data collection
|
2 |
+
# using Gradio Blocks.
|
3 |
+
import json
|
4 |
+
import os
|
5 |
+
import threading
|
6 |
+
import uuid
|
7 |
+
from pathlib import Path
|
8 |
+
from urllib.parse import parse_qs
|
9 |
+
from datasets import load_dataset
|
10 |
+
import gradio as gr
|
11 |
+
from dotenv import load_dotenv
|
12 |
+
from huggingface_hub import Repository
|
13 |
+
import random
|
14 |
+
|
15 |
+
from utils import force_git_push
|
16 |
+
|
17 |
+
|
18 |
+
# These variables are for storing the MTurk HITs in a Hugging Face dataset.
|
19 |
+
if Path(".env").is_file():
|
20 |
+
load_dotenv(".env")
|
21 |
+
DATASET_REPO_URL = os.getenv("DATASET_REPO_URL")
|
22 |
+
FORCE_PUSH = os.getenv("FORCE_PUSH")
|
23 |
+
HF_TOKEN = os.getenv("HF_TOKEN")
|
24 |
+
PROMPT_TEMPLATES = Path("prompt_templates")
|
25 |
+
|
26 |
+
DATA_FILENAME = "data.jsonl"
|
27 |
+
DATA_FILE = os.path.join("data", DATA_FILENAME)
|
28 |
+
repo = Repository(local_dir="data", clone_from=DATASET_REPO_URL, use_auth_token=HF_TOKEN)
|
29 |
+
ds = load_dataset("HuggingFaceH4/instruction-pilot-outputs", split="train", use_auth_token=HF_TOKEN)
|
30 |
+
|
31 |
+
TOTAL_CNT = 10 # How many user inputs per HIT
|
32 |
+
|
33 |
+
# This function pushes the HIT data written in data.jsonl to our Hugging Face
|
34 |
+
# dataset every minute. Adjust the frequency to suit your needs.
|
35 |
+
PUSH_FREQUENCY = 60
|
36 |
+
|
37 |
+
|
38 |
+
def asynchronous_push(f_stop):
|
39 |
+
if repo.is_repo_clean():
|
40 |
+
print("Repo currently clean. Ignoring push_to_hub")
|
41 |
+
else:
|
42 |
+
repo.git_add(auto_lfs_track=True)
|
43 |
+
repo.git_commit("Auto commit by space")
|
44 |
+
if FORCE_PUSH == "yes":
|
45 |
+
force_git_push(repo)
|
46 |
+
else:
|
47 |
+
repo.git_push()
|
48 |
+
if not f_stop.is_set():
|
49 |
+
# call again in 60 seconds
|
50 |
+
threading.Timer(PUSH_FREQUENCY, asynchronous_push, [f_stop]).start()
|
51 |
+
|
52 |
+
|
53 |
+
f_stop = threading.Event()
|
54 |
+
asynchronous_push(f_stop)
|
55 |
+
|
56 |
+
demo = gr.Blocks()
|
57 |
+
|
58 |
+
def random_sample_with_least_annotated_examples_first():
|
59 |
+
annotations = open(DATA_FILE, "r").readlines()
|
60 |
+
id_to_count = {}
|
61 |
+
for line in annotations:
|
62 |
+
annotation = json.loads(line)
|
63 |
+
# Only include annotations by actual turkers in the count.
|
64 |
+
if annotation["assignmentId"] != "":
|
65 |
+
example_id = annotation["id"]
|
66 |
+
id_to_count[example_id] = id_to_count.get(example_id, 0) + 1
|
67 |
+
ds_with_annotation_counts = ds.map(lambda example: {"annotation_count": id_to_count.get(example["id"], 0)})
|
68 |
+
ds_with_annotation_counts = ds_with_annotation_counts.shuffle()
|
69 |
+
ds_with_annotation_counts = ds_with_annotation_counts.sort("annotation_count")
|
70 |
+
example = ds_with_annotation_counts.select([0])[0]
|
71 |
+
# We only want to give the annotator 2 choices, so we sample 2 outputs without replacement.
|
72 |
+
example["outputs"] = random.sample(example["outputs"], 2)
|
73 |
+
return example
|
74 |
+
|
75 |
+
|
76 |
+
with demo:
|
77 |
+
dummy = gr.Textbox(visible=False) # dummy for passing assignmentId
|
78 |
+
|
79 |
+
initial_sample = random_sample_with_least_annotated_examples_first()
|
80 |
+
|
81 |
+
# We keep track of state as a JSON
|
82 |
+
state_dict = {
|
83 |
+
"taskId": str(uuid.uuid4()),
|
84 |
+
"assignmentId": "",
|
85 |
+
"cnt": 0,
|
86 |
+
"data": [initial_sample],
|
87 |
+
}
|
88 |
+
state = gr.JSON(state_dict, visible=False)
|
89 |
+
|
90 |
+
gr.Markdown("# Choose the most helpful and honest response.")
|
91 |
+
|
92 |
+
state_display = gr.Markdown(f"Your messages: 0/{TOTAL_CNT}")
|
93 |
+
|
94 |
+
def _select_response(selected_response, state, dummy):
|
95 |
+
if selected_response == "":
|
96 |
+
# Don't do anything if the worker didn't select anything yet.
|
97 |
+
return (
|
98 |
+
gr.update(),
|
99 |
+
gr.update(),
|
100 |
+
gr.update(),
|
101 |
+
gr.update(),
|
102 |
+
gr.update(),
|
103 |
+
gr.update(),
|
104 |
+
state,
|
105 |
+
dummy,
|
106 |
+
)
|
107 |
+
state["cnt"] += 1
|
108 |
+
state_display = f"Your messages: {state['cnt']}/{TOTAL_CNT}"
|
109 |
+
done = state["cnt"] == TOTAL_CNT
|
110 |
+
state["data"][-1]["selected_response"] = selected_response
|
111 |
+
if state["cnt"] == TOTAL_CNT:
|
112 |
+
# Write the HIT data to our local dataset because the worker has
|
113 |
+
# submitted everything now.
|
114 |
+
with open(DATA_FILE, "a") as jsonlfile:
|
115 |
+
json_data_with_assignment_id = [
|
116 |
+
json.dumps(
|
117 |
+
dict(
|
118 |
+
{"assignmentId": state["assignmentId"], "taskId": state["taskId"]},
|
119 |
+
**datum,
|
120 |
+
)
|
121 |
+
)
|
122 |
+
for datum in state["data"]
|
123 |
+
]
|
124 |
+
jsonlfile.write("\n".join(json_data_with_assignment_id) + "\n")
|
125 |
+
query = parse_qs(dummy[1:])
|
126 |
+
if "assignmentId" in query and query["assignmentId"][0] != "ASSIGNMENT_ID_NOT_AVAILABLE":
|
127 |
+
# It seems that someone is using this app on mturk. We need to
|
128 |
+
# store the assignmentId in the state before submit_hit_button
|
129 |
+
# is clicked. We can do this here in _predict. We need to save the
|
130 |
+
# assignmentId so that the turker can get credit for their HIT.
|
131 |
+
state["assignmentId"] = query["assignmentId"][0]
|
132 |
+
toggle_final_submit = gr.update(visible=done)
|
133 |
+
toggle_final_submit_preview = gr.update(visible=False)
|
134 |
+
else:
|
135 |
+
toggle_final_submit_preview = gr.update(visible=done)
|
136 |
+
toggle_final_submit = gr.update(visible=False)
|
137 |
+
|
138 |
+
toggle_select_response_button = gr.update(visible=not done)
|
139 |
+
|
140 |
+
new_sample = random_sample_with_least_annotated_examples_first()
|
141 |
+
new_outputs = [obj["output"] for obj in new_sample["outputs"]]
|
142 |
+
state["data"].append(new_sample)
|
143 |
+
past_conversation = gr.update(value=new_sample["prompt"])
|
144 |
+
select_response = gr.update(choices=new_outputs, value="")
|
145 |
+
|
146 |
+
return (
|
147 |
+
past_conversation,
|
148 |
+
select_response,
|
149 |
+
toggle_select_response_button,
|
150 |
+
toggle_final_submit,
|
151 |
+
toggle_final_submit_preview,
|
152 |
+
state_display,
|
153 |
+
state,
|
154 |
+
dummy,
|
155 |
+
)
|
156 |
+
|
157 |
+
# Input fields
|
158 |
+
past_conversation = gr.Markdown(value=initial_sample["prompt"])
|
159 |
+
initial_outputs = [obj["output"] for obj in initial_sample["outputs"]]
|
160 |
+
select_response = gr.Radio(
|
161 |
+
choices=initial_outputs, label="Choose the most helpful and honest response"
|
162 |
+
)
|
163 |
+
select_response_button = gr.Button("Submit Response")
|
164 |
+
submit_hit_button = gr.Button("Submit HIT", visible=False)
|
165 |
+
submit_hit_button_preview = gr.Button(
|
166 |
+
"Submit Work (preview mode; no MTurk HIT credit, but your examples will still be stored)",
|
167 |
+
visible=False,
|
168 |
+
)
|
169 |
+
|
170 |
+
# Button event handlers
|
171 |
+
get_window_location_search_js = """
|
172 |
+
function(select_response, state, dummy) {
|
173 |
+
return [select_response, state, window.location.search];
|
174 |
+
}
|
175 |
+
"""
|
176 |
+
|
177 |
+
select_response_button.click(
|
178 |
+
_select_response,
|
179 |
+
inputs=[select_response, state, dummy],
|
180 |
+
outputs=[
|
181 |
+
past_conversation,
|
182 |
+
select_response,
|
183 |
+
select_response_button,
|
184 |
+
submit_hit_button,
|
185 |
+
submit_hit_button_preview,
|
186 |
+
state_display,
|
187 |
+
state,
|
188 |
+
dummy,
|
189 |
+
],
|
190 |
+
_js=get_window_location_search_js,
|
191 |
+
)
|
192 |
+
|
193 |
+
post_hit_js = """
|
194 |
+
function(state) {
|
195 |
+
// If there is an assignmentId, then the submitter is on mturk
|
196 |
+
// and has accepted the HIT. So, we need to submit their HIT.
|
197 |
+
const form = document.createElement('form');
|
198 |
+
form.action = 'https://workersandbox.mturk.com/mturk/externalSubmit';
|
199 |
+
form.method = 'post';
|
200 |
+
for (const key in state) {
|
201 |
+
const hiddenField = document.createElement('input');
|
202 |
+
hiddenField.type = 'hidden';
|
203 |
+
hiddenField.name = key;
|
204 |
+
hiddenField.value = state[key];
|
205 |
+
form.appendChild(hiddenField);
|
206 |
+
};
|
207 |
+
document.body.appendChild(form);
|
208 |
+
form.submit();
|
209 |
+
return state;
|
210 |
+
}
|
211 |
+
"""
|
212 |
+
|
213 |
+
submit_hit_button.click(
|
214 |
+
lambda state: state,
|
215 |
+
inputs=[state],
|
216 |
+
outputs=[state],
|
217 |
+
_js=post_hit_js,
|
218 |
+
)
|
219 |
+
|
220 |
+
refresh_app_js = """
|
221 |
+
function(state) {
|
222 |
+
// The following line here loads the app again so the user can
|
223 |
+
// enter in another preview-mode "HIT".
|
224 |
+
window.location.href = window.location.href;
|
225 |
+
return state;
|
226 |
+
}
|
227 |
+
"""
|
228 |
+
|
229 |
+
submit_hit_button_preview.click(
|
230 |
+
lambda state: state,
|
231 |
+
inputs=[state],
|
232 |
+
outputs=[state],
|
233 |
+
_js=refresh_app_js,
|
234 |
+
)
|
235 |
+
|
236 |
+
demo.launch()
|
collect.py
ADDED
@@ -0,0 +1,133 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
from os import path
|
3 |
+
|
4 |
+
import boto3
|
5 |
+
from boto.mturk.question import ExternalQuestion
|
6 |
+
from config import MTURK_KEY, MTURK_SECRET
|
7 |
+
|
8 |
+
parser = argparse.ArgumentParser()
|
9 |
+
parser.add_argument("--mturk_region", default="us-east-1", help="The region for mturk (default: us-east-1)")
|
10 |
+
parser.add_argument(
|
11 |
+
"--space_name",
|
12 |
+
default="Tristan/static-rlhf-interface",
|
13 |
+
help="Name of the accompanying Hugging Face space (default: huggingface/rlhf-interface)",
|
14 |
+
)
|
15 |
+
parser.add_argument("--num_hits", type=int, default=5, help="The number of HITs.")
|
16 |
+
parser.add_argument(
|
17 |
+
"--num_assignments", type=int, default=1, help="The number of times that the HIT can be accepted and completed."
|
18 |
+
)
|
19 |
+
parser.add_argument(
|
20 |
+
"--live_mode",
|
21 |
+
action="store_true",
|
22 |
+
help="""
|
23 |
+
Whether to run in live mode with real turkers. This will charge your account money.
|
24 |
+
If you don't use this flag, the HITs will be deployed on the sandbox version of mturk,
|
25 |
+
which will not charge your account money.
|
26 |
+
""",
|
27 |
+
)
|
28 |
+
parser.add_argument(
|
29 |
+
"--refresh_qualification_test",
|
30 |
+
action="store_true",
|
31 |
+
help="""
|
32 |
+
Whether to refresh the qualification test. If you've made edits to the test
|
33 |
+
xml files, it is necessary to do this.
|
34 |
+
""",
|
35 |
+
)
|
36 |
+
parser.add_argument(
|
37 |
+
"--custom_qualification_test",
|
38 |
+
action="store_true",
|
39 |
+
help="""
|
40 |
+
Whether to require the custom qualification test.
|
41 |
+
""",
|
42 |
+
)
|
43 |
+
parser.add_argument(
|
44 |
+
"--master_turkers",
|
45 |
+
action="store_true",
|
46 |
+
help="""
|
47 |
+
Whether to only use turkers with the master qualification.
|
48 |
+
""",
|
49 |
+
)
|
50 |
+
parser.add_argument(
|
51 |
+
"--us_turkers",
|
52 |
+
action="store_true",
|
53 |
+
help="""
|
54 |
+
Whether to only use US-based turkers.
|
55 |
+
""",
|
56 |
+
)
|
57 |
+
|
58 |
+
args = parser.parse_args()
|
59 |
+
|
60 |
+
MTURK_URL = f"https://mturk-requester{'' if args.live_mode else '-sandbox'}.{args.mturk_region}.amazonaws.com"
|
61 |
+
|
62 |
+
mturk = boto3.client(
|
63 |
+
"mturk",
|
64 |
+
aws_access_key_id=MTURK_KEY,
|
65 |
+
aws_secret_access_key=MTURK_SECRET,
|
66 |
+
region_name=args.mturk_region,
|
67 |
+
endpoint_url=MTURK_URL,
|
68 |
+
)
|
69 |
+
|
70 |
+
# This is the URL that makes the space embeddable in an mturk iframe
|
71 |
+
question = ExternalQuestion(f"https://hf.space/embed/{args.space_name}/+?__theme=light", frame_height=600)
|
72 |
+
|
73 |
+
qualification_requirements=[]
|
74 |
+
|
75 |
+
if args.master_turkers:
|
76 |
+
qualification_requirements.append({
|
77 |
+
QualificationTypeId: '2F1QJWKUDD8XADTFD2Q0G6UTO95ALH',
|
78 |
+
Comparator: 'Exists'
|
79 |
+
})
|
80 |
+
|
81 |
+
if args.us_turkers:
|
82 |
+
qualification_requirements.append({
|
83 |
+
QualificationTypeId: '00000000000000000071',
|
84 |
+
Comparator: 'In',
|
85 |
+
LocaleValues: [
|
86 |
+
{ Country: "US" },
|
87 |
+
]
|
88 |
+
})
|
89 |
+
|
90 |
+
if args.custom_qualification_test:
|
91 |
+
qualification_type_id = (
|
92 |
+
open("qualification_type_id.txt", "r").read() if path.exists("qualification_type_id.txt") else None
|
93 |
+
)
|
94 |
+
if args.refresh_qualification_test or qualification_type_id is None:
|
95 |
+
if qualification_type_id is not None:
|
96 |
+
mturk.delete_qualification_type(QualificationTypeId=qualification_type_id)
|
97 |
+
response = mturk.create_qualification_type(
|
98 |
+
Name="rlhf--qualification",
|
99 |
+
Keywords="RLHF qualification",
|
100 |
+
Description="Qualification test for RLHF task.",
|
101 |
+
QualificationTypeStatus="Active",
|
102 |
+
Test=open("qualification_questions.xml", mode="r").read(),
|
103 |
+
AnswerKey=open("qualification_answers.xml", mode="r").read(),
|
104 |
+
TestDurationInSeconds=3600,
|
105 |
+
AutoGranted=False,
|
106 |
+
)
|
107 |
+
qualification_type_id = response["QualificationType"]["QualificationTypeId"]
|
108 |
+
open("qualification_type_id.txt", "w+").write(qualification_type_id)
|
109 |
+
qualification_requirements.append({
|
110 |
+
"QualificationTypeId": qualification_type_id,
|
111 |
+
"Comparator": "Exists",
|
112 |
+
"RequiredToPreview": False,
|
113 |
+
"ActionsGuarded": "Accept",
|
114 |
+
})
|
115 |
+
|
116 |
+
for i in range(args.num_hits):
|
117 |
+
new_hit = mturk.create_hit(
|
118 |
+
Title="RLHF HIT",
|
119 |
+
Description="Interact with an AI",
|
120 |
+
Keywords="chatbot",
|
121 |
+
Reward="0.25",
|
122 |
+
MaxAssignments=args.num_assignments,
|
123 |
+
LifetimeInSeconds=172800,
|
124 |
+
AssignmentDurationInSeconds=600,
|
125 |
+
AutoApprovalDelayInSeconds=14400,
|
126 |
+
Question=question.get_as_xml(),
|
127 |
+
QualificationRequirements=qualification_requirements,
|
128 |
+
)
|
129 |
+
|
130 |
+
print(
|
131 |
+
f"HIT Group Link: https://worker{'' if args.live_mode else 'sandbox'}.mturk.com/mturk/preview?groupId="
|
132 |
+
+ new_hit["HIT"]["HITGroupId"]
|
133 |
+
)
|
config.py.example
ADDED
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Fill in the information and rename this file config.py
|
2 |
+
# You can obtain the key and secret in the AWS Identity
|
3 |
+
# and Access Management (IAM) panel.
|
4 |
+
|
5 |
+
MTURK_KEY = ''
|
6 |
+
MTURK_SECRET = ''
|
requirements.txt
ADDED
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
boto3==1.24.32
|
2 |
+
boto==2.49.0
|
3 |
+
huggingface_hub==0.8.1
|
4 |
+
python-dotenv==0.20.0
|
5 |
+
datasets==2.9.0
|
utils.py
ADDED
@@ -0,0 +1,39 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import subprocess
|
2 |
+
|
3 |
+
from huggingface_hub.repository import _lfs_log_progress
|
4 |
+
|
5 |
+
|
6 |
+
def force_git_push(
|
7 |
+
repo,
|
8 |
+
):
|
9 |
+
"""
|
10 |
+
force a simple git push
|
11 |
+
Blocking. Will return url to commit on remote
|
12 |
+
repo.
|
13 |
+
"""
|
14 |
+
command = "git push --force"
|
15 |
+
|
16 |
+
try:
|
17 |
+
with _lfs_log_progress():
|
18 |
+
process = subprocess.Popen(
|
19 |
+
command.split(),
|
20 |
+
stderr=subprocess.PIPE,
|
21 |
+
stdout=subprocess.PIPE,
|
22 |
+
encoding="utf-8",
|
23 |
+
cwd=repo.local_dir,
|
24 |
+
)
|
25 |
+
|
26 |
+
stdout, stderr = process.communicate()
|
27 |
+
return_code = process.poll()
|
28 |
+
process.kill()
|
29 |
+
|
30 |
+
if len(stderr):
|
31 |
+
print(stderr)
|
32 |
+
|
33 |
+
if return_code:
|
34 |
+
raise subprocess.CalledProcessError(return_code, process.args, output=stdout, stderr=stderr)
|
35 |
+
|
36 |
+
except subprocess.CalledProcessError as exc:
|
37 |
+
raise EnvironmentError(exc.stderr)
|
38 |
+
|
39 |
+
return repo.git_head_commit_url()
|