Tristan Thrush commited on
Commit
dbe89c5
Β·
0 Parent(s):

first commit

Browse files
Files changed (9) hide show
  1. .env.example +3 -0
  2. .gitignore +163 -0
  3. Makefile +10 -0
  4. README.md +70 -0
  5. app.py +236 -0
  6. collect.py +133 -0
  7. config.py.example +6 -0
  8. requirements.txt +5 -0
  9. utils.py +39 -0
.env.example ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ DATASET_REPO_URL="https://huggingface.co/datasets/{DATASET_ID}"
2
+ FORCE_PUSH="no"
3
+ HF_TOKEN="hf_xxx"
.gitignore ADDED
@@ -0,0 +1,163 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Byte-compiled / optimized / DLL files
2
+ __pycache__/
3
+ *.py[cod]
4
+ *$py.class
5
+
6
+ # C extensions
7
+ *.so
8
+
9
+ # Distribution / packaging
10
+ .Python
11
+ build/
12
+ develop-eggs/
13
+ dist/
14
+ downloads/
15
+ eggs/
16
+ .eggs/
17
+ lib/
18
+ lib64/
19
+ parts/
20
+ sdist/
21
+ var/
22
+ wheels/
23
+ share/python-wheels/
24
+ *.egg-info/
25
+ .installed.cfg
26
+ *.egg
27
+ MANIFEST
28
+
29
+ # PyInstaller
30
+ # Usually these files are written by a python script from a template
31
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
32
+ *.manifest
33
+ *.spec
34
+
35
+ # Installer logs
36
+ pip-log.txt
37
+ pip-delete-this-directory.txt
38
+
39
+ # Unit test / coverage reports
40
+ htmlcov/
41
+ .tox/
42
+ .nox/
43
+ .coverage
44
+ .coverage.*
45
+ .cache
46
+ nosetests.xml
47
+ coverage.xml
48
+ *.cover
49
+ *.py,cover
50
+ .hypothesis/
51
+ .pytest_cache/
52
+ cover/
53
+
54
+ # Translations
55
+ *.mo
56
+ *.pot
57
+
58
+ # Django stuff:
59
+ *.log
60
+ local_settings.py
61
+ db.sqlite3
62
+ db.sqlite3-journal
63
+
64
+ # Flask stuff:
65
+ instance/
66
+ .webassets-cache
67
+
68
+ # Scrapy stuff:
69
+ .scrapy
70
+
71
+ # Sphinx documentation
72
+ docs/_build/
73
+
74
+ # PyBuilder
75
+ .pybuilder/
76
+ target/
77
+
78
+ # Jupyter Notebook
79
+ .ipynb_checkpoints
80
+
81
+ # IPython
82
+ profile_default/
83
+ ipython_config.py
84
+
85
+ # pyenv
86
+ # For a library or package, you might want to ignore these files since the code is
87
+ # intended to run in multiple environments; otherwise, check them in:
88
+ # .python-version
89
+
90
+ # pipenv
91
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
93
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
94
+ # install all needed dependencies.
95
+ #Pipfile.lock
96
+
97
+ # poetry
98
+ # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
99
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
100
+ # commonly ignored for libraries.
101
+ # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
102
+ #poetry.lock
103
+
104
+ # pdm
105
+ # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
106
+ #pdm.lock
107
+ # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
108
+ # in version control.
109
+ # https://pdm.fming.dev/#use-with-ide
110
+ .pdm.toml
111
+
112
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
113
+ __pypackages__/
114
+
115
+ # Celery stuff
116
+ celerybeat-schedule
117
+ celerybeat.pid
118
+
119
+ # SageMath parsed files
120
+ *.sage.py
121
+
122
+ # Environments
123
+ .env
124
+ .venv
125
+ env/
126
+ venv/
127
+ ENV/
128
+ env.bak/
129
+ venv.bak/
130
+
131
+ # Spyder project settings
132
+ .spyderproject
133
+ .spyproject
134
+
135
+ # Rope project settings
136
+ .ropeproject
137
+
138
+ # mkdocs documentation
139
+ /site
140
+
141
+ # mypy
142
+ .mypy_cache/
143
+ .dmypy.json
144
+ dmypy.json
145
+
146
+ # Pyre type checker
147
+ .pyre/
148
+
149
+ # pytype static type analyzer
150
+ .pytype/
151
+
152
+ # Cython debug symbols
153
+ cython_debug/
154
+
155
+ # PyCharm
156
+ # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
157
+ # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
158
+ # and can be added to the global gitignore or merged into this file. For a more nuclear
159
+ # option (not recommended) you can uncomment the following to ignore the entire idea folder.
160
+ #.idea/
161
+
162
+ # Local development
163
+ data/
Makefile ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ .PHONY: style quality
2
+
3
+ style:
4
+ python -m black --line-length 119 --target-version py38 .
5
+ python -m isort .
6
+
7
+ quality:
8
+ python -m black --check --line-length 119 --target-version py38 .
9
+ python -m isort --check-only .
10
+ python -m flake8 --max-line-length 119 .
README.md ADDED
@@ -0,0 +1,70 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ title: RLHF
3
+ emoji: 🏒
4
+ colorFrom: red
5
+ colorTo: gray
6
+ sdk: gradio
7
+ sdk_version: 3.1
8
+ app_file: app.py
9
+ pinned: false
10
+ ---
11
+
12
+ An RLHF interface for data collection with [Amazon Mechanical Turk](https://www.mturk.com) and Gradio.
13
+
14
+ ## Instructions for someone to use for their own project
15
+
16
+ ### Install dependencies
17
+
18
+ First, create a Python virtual environment and install the project's dependencies as follows:
19
+
20
+ ```bash
21
+ python -m pip install -r requirements.txt
22
+ ```
23
+
24
+ ### Setting up the Space
25
+
26
+ 1. Clone this repo and deploy it on your own Hugging Face space.
27
+ 2. Add the following secrets to your space:
28
+ - `HF_TOKEN`: One of your Hugging Face tokens.
29
+ - `DATASET_REPO_URL`: The url to an empty dataset that you created the hub. It
30
+ can be a private or public dataset.
31
+ - `FORCE_PUSH`: "yes"
32
+ When you run this space on mturk and when people visit your space on
33
+ huggingface.co, the app will use your token to automatically store new HITs
34
+ in your dataset. Setting `FORCE_PUSH` to "yes" ensures that your repo will
35
+ force push changes to the dataset during data collection. Otherwise,
36
+ accidental manual changes to your dataset could result in your space getting
37
+ merge conflicts as it automatically tries to push the dataset to the hub. For
38
+ local development, add these three keys to a `.env` file, and consider setting
39
+ `FORCE_PUSH` to "no".
40
+
41
+ To launch the Space locally, run:
42
+
43
+ ```bash
44
+ python app.py
45
+ ```
46
+
47
+ The app will then be available at a local address, such as http://127.0.0.1:7860
48
+
49
+ ### Running data collection*
50
+
51
+ 1. On your local repo that you pulled, create a copy of `config.py.example`,
52
+ just called `config.py`. Now, put keys from your AWS account in `config.py`.
53
+ These keys should be for an AWS account that has the
54
+ AmazonMechanicalTurkFullAccess permission. You also need to
55
+ create an mturk requestor account associated with your AWS account.
56
+ 2. Run `python collect.py` locally.
57
+
58
+ ### Profit
59
+ Now, you should be watching hits come into your Hugging Face dataset
60
+ automatically!
61
+
62
+ ### Tips and tricks
63
+
64
+ - Use caution while doing local development of your space and
65
+ simultaneously running it on mturk. Consider setting `FORCE_PUSH` to "no" in
66
+ your local `.env` file.
67
+ - huggingface spaces have limited computational resources and memory. If you
68
+ run too many HITs and/or assignments at once, then you could encounter issues.
69
+ You could also encounter issues if you are trying to create a dataset that is
70
+ very large. Check the log of your space for any errors that could be happening.
app.py ADDED
@@ -0,0 +1,236 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Basic example for doing model-in-the-loop dynamic adversarial data collection
2
+ # using Gradio Blocks.
3
+ import json
4
+ import os
5
+ import threading
6
+ import uuid
7
+ from pathlib import Path
8
+ from urllib.parse import parse_qs
9
+ from datasets import load_dataset
10
+ import gradio as gr
11
+ from dotenv import load_dotenv
12
+ from huggingface_hub import Repository
13
+ import random
14
+
15
+ from utils import force_git_push
16
+
17
+
18
+ # These variables are for storing the MTurk HITs in a Hugging Face dataset.
19
+ if Path(".env").is_file():
20
+ load_dotenv(".env")
21
+ DATASET_REPO_URL = os.getenv("DATASET_REPO_URL")
22
+ FORCE_PUSH = os.getenv("FORCE_PUSH")
23
+ HF_TOKEN = os.getenv("HF_TOKEN")
24
+ PROMPT_TEMPLATES = Path("prompt_templates")
25
+
26
+ DATA_FILENAME = "data.jsonl"
27
+ DATA_FILE = os.path.join("data", DATA_FILENAME)
28
+ repo = Repository(local_dir="data", clone_from=DATASET_REPO_URL, use_auth_token=HF_TOKEN)
29
+ ds = load_dataset("HuggingFaceH4/instruction-pilot-outputs", split="train", use_auth_token=HF_TOKEN)
30
+
31
+ TOTAL_CNT = 10 # How many user inputs per HIT
32
+
33
+ # This function pushes the HIT data written in data.jsonl to our Hugging Face
34
+ # dataset every minute. Adjust the frequency to suit your needs.
35
+ PUSH_FREQUENCY = 60
36
+
37
+
38
+ def asynchronous_push(f_stop):
39
+ if repo.is_repo_clean():
40
+ print("Repo currently clean. Ignoring push_to_hub")
41
+ else:
42
+ repo.git_add(auto_lfs_track=True)
43
+ repo.git_commit("Auto commit by space")
44
+ if FORCE_PUSH == "yes":
45
+ force_git_push(repo)
46
+ else:
47
+ repo.git_push()
48
+ if not f_stop.is_set():
49
+ # call again in 60 seconds
50
+ threading.Timer(PUSH_FREQUENCY, asynchronous_push, [f_stop]).start()
51
+
52
+
53
+ f_stop = threading.Event()
54
+ asynchronous_push(f_stop)
55
+
56
+ demo = gr.Blocks()
57
+
58
+ def random_sample_with_least_annotated_examples_first():
59
+ annotations = open(DATA_FILE, "r").readlines()
60
+ id_to_count = {}
61
+ for line in annotations:
62
+ annotation = json.loads(line)
63
+ # Only include annotations by actual turkers in the count.
64
+ if annotation["assignmentId"] != "":
65
+ example_id = annotation["id"]
66
+ id_to_count[example_id] = id_to_count.get(example_id, 0) + 1
67
+ ds_with_annotation_counts = ds.map(lambda example: {"annotation_count": id_to_count.get(example["id"], 0)})
68
+ ds_with_annotation_counts = ds_with_annotation_counts.shuffle()
69
+ ds_with_annotation_counts = ds_with_annotation_counts.sort("annotation_count")
70
+ example = ds_with_annotation_counts.select([0])[0]
71
+ # We only want to give the annotator 2 choices, so we sample 2 outputs without replacement.
72
+ example["outputs"] = random.sample(example["outputs"], 2)
73
+ return example
74
+
75
+
76
+ with demo:
77
+ dummy = gr.Textbox(visible=False) # dummy for passing assignmentId
78
+
79
+ initial_sample = random_sample_with_least_annotated_examples_first()
80
+
81
+ # We keep track of state as a JSON
82
+ state_dict = {
83
+ "taskId": str(uuid.uuid4()),
84
+ "assignmentId": "",
85
+ "cnt": 0,
86
+ "data": [initial_sample],
87
+ }
88
+ state = gr.JSON(state_dict, visible=False)
89
+
90
+ gr.Markdown("# Choose the most helpful and honest response.")
91
+
92
+ state_display = gr.Markdown(f"Your messages: 0/{TOTAL_CNT}")
93
+
94
+ def _select_response(selected_response, state, dummy):
95
+ if selected_response == "":
96
+ # Don't do anything if the worker didn't select anything yet.
97
+ return (
98
+ gr.update(),
99
+ gr.update(),
100
+ gr.update(),
101
+ gr.update(),
102
+ gr.update(),
103
+ gr.update(),
104
+ state,
105
+ dummy,
106
+ )
107
+ state["cnt"] += 1
108
+ state_display = f"Your messages: {state['cnt']}/{TOTAL_CNT}"
109
+ done = state["cnt"] == TOTAL_CNT
110
+ state["data"][-1]["selected_response"] = selected_response
111
+ if state["cnt"] == TOTAL_CNT:
112
+ # Write the HIT data to our local dataset because the worker has
113
+ # submitted everything now.
114
+ with open(DATA_FILE, "a") as jsonlfile:
115
+ json_data_with_assignment_id = [
116
+ json.dumps(
117
+ dict(
118
+ {"assignmentId": state["assignmentId"], "taskId": state["taskId"]},
119
+ **datum,
120
+ )
121
+ )
122
+ for datum in state["data"]
123
+ ]
124
+ jsonlfile.write("\n".join(json_data_with_assignment_id) + "\n")
125
+ query = parse_qs(dummy[1:])
126
+ if "assignmentId" in query and query["assignmentId"][0] != "ASSIGNMENT_ID_NOT_AVAILABLE":
127
+ # It seems that someone is using this app on mturk. We need to
128
+ # store the assignmentId in the state before submit_hit_button
129
+ # is clicked. We can do this here in _predict. We need to save the
130
+ # assignmentId so that the turker can get credit for their HIT.
131
+ state["assignmentId"] = query["assignmentId"][0]
132
+ toggle_final_submit = gr.update(visible=done)
133
+ toggle_final_submit_preview = gr.update(visible=False)
134
+ else:
135
+ toggle_final_submit_preview = gr.update(visible=done)
136
+ toggle_final_submit = gr.update(visible=False)
137
+
138
+ toggle_select_response_button = gr.update(visible=not done)
139
+
140
+ new_sample = random_sample_with_least_annotated_examples_first()
141
+ new_outputs = [obj["output"] for obj in new_sample["outputs"]]
142
+ state["data"].append(new_sample)
143
+ past_conversation = gr.update(value=new_sample["prompt"])
144
+ select_response = gr.update(choices=new_outputs, value="")
145
+
146
+ return (
147
+ past_conversation,
148
+ select_response,
149
+ toggle_select_response_button,
150
+ toggle_final_submit,
151
+ toggle_final_submit_preview,
152
+ state_display,
153
+ state,
154
+ dummy,
155
+ )
156
+
157
+ # Input fields
158
+ past_conversation = gr.Markdown(value=initial_sample["prompt"])
159
+ initial_outputs = [obj["output"] for obj in initial_sample["outputs"]]
160
+ select_response = gr.Radio(
161
+ choices=initial_outputs, label="Choose the most helpful and honest response"
162
+ )
163
+ select_response_button = gr.Button("Submit Response")
164
+ submit_hit_button = gr.Button("Submit HIT", visible=False)
165
+ submit_hit_button_preview = gr.Button(
166
+ "Submit Work (preview mode; no MTurk HIT credit, but your examples will still be stored)",
167
+ visible=False,
168
+ )
169
+
170
+ # Button event handlers
171
+ get_window_location_search_js = """
172
+ function(select_response, state, dummy) {
173
+ return [select_response, state, window.location.search];
174
+ }
175
+ """
176
+
177
+ select_response_button.click(
178
+ _select_response,
179
+ inputs=[select_response, state, dummy],
180
+ outputs=[
181
+ past_conversation,
182
+ select_response,
183
+ select_response_button,
184
+ submit_hit_button,
185
+ submit_hit_button_preview,
186
+ state_display,
187
+ state,
188
+ dummy,
189
+ ],
190
+ _js=get_window_location_search_js,
191
+ )
192
+
193
+ post_hit_js = """
194
+ function(state) {
195
+ // If there is an assignmentId, then the submitter is on mturk
196
+ // and has accepted the HIT. So, we need to submit their HIT.
197
+ const form = document.createElement('form');
198
+ form.action = 'https://workersandbox.mturk.com/mturk/externalSubmit';
199
+ form.method = 'post';
200
+ for (const key in state) {
201
+ const hiddenField = document.createElement('input');
202
+ hiddenField.type = 'hidden';
203
+ hiddenField.name = key;
204
+ hiddenField.value = state[key];
205
+ form.appendChild(hiddenField);
206
+ };
207
+ document.body.appendChild(form);
208
+ form.submit();
209
+ return state;
210
+ }
211
+ """
212
+
213
+ submit_hit_button.click(
214
+ lambda state: state,
215
+ inputs=[state],
216
+ outputs=[state],
217
+ _js=post_hit_js,
218
+ )
219
+
220
+ refresh_app_js = """
221
+ function(state) {
222
+ // The following line here loads the app again so the user can
223
+ // enter in another preview-mode "HIT".
224
+ window.location.href = window.location.href;
225
+ return state;
226
+ }
227
+ """
228
+
229
+ submit_hit_button_preview.click(
230
+ lambda state: state,
231
+ inputs=[state],
232
+ outputs=[state],
233
+ _js=refresh_app_js,
234
+ )
235
+
236
+ demo.launch()
collect.py ADDED
@@ -0,0 +1,133 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ from os import path
3
+
4
+ import boto3
5
+ from boto.mturk.question import ExternalQuestion
6
+ from config import MTURK_KEY, MTURK_SECRET
7
+
8
+ parser = argparse.ArgumentParser()
9
+ parser.add_argument("--mturk_region", default="us-east-1", help="The region for mturk (default: us-east-1)")
10
+ parser.add_argument(
11
+ "--space_name",
12
+ default="Tristan/static-rlhf-interface",
13
+ help="Name of the accompanying Hugging Face space (default: huggingface/rlhf-interface)",
14
+ )
15
+ parser.add_argument("--num_hits", type=int, default=5, help="The number of HITs.")
16
+ parser.add_argument(
17
+ "--num_assignments", type=int, default=1, help="The number of times that the HIT can be accepted and completed."
18
+ )
19
+ parser.add_argument(
20
+ "--live_mode",
21
+ action="store_true",
22
+ help="""
23
+ Whether to run in live mode with real turkers. This will charge your account money.
24
+ If you don't use this flag, the HITs will be deployed on the sandbox version of mturk,
25
+ which will not charge your account money.
26
+ """,
27
+ )
28
+ parser.add_argument(
29
+ "--refresh_qualification_test",
30
+ action="store_true",
31
+ help="""
32
+ Whether to refresh the qualification test. If you've made edits to the test
33
+ xml files, it is necessary to do this.
34
+ """,
35
+ )
36
+ parser.add_argument(
37
+ "--custom_qualification_test",
38
+ action="store_true",
39
+ help="""
40
+ Whether to require the custom qualification test.
41
+ """,
42
+ )
43
+ parser.add_argument(
44
+ "--master_turkers",
45
+ action="store_true",
46
+ help="""
47
+ Whether to only use turkers with the master qualification.
48
+ """,
49
+ )
50
+ parser.add_argument(
51
+ "--us_turkers",
52
+ action="store_true",
53
+ help="""
54
+ Whether to only use US-based turkers.
55
+ """,
56
+ )
57
+
58
+ args = parser.parse_args()
59
+
60
+ MTURK_URL = f"https://mturk-requester{'' if args.live_mode else '-sandbox'}.{args.mturk_region}.amazonaws.com"
61
+
62
+ mturk = boto3.client(
63
+ "mturk",
64
+ aws_access_key_id=MTURK_KEY,
65
+ aws_secret_access_key=MTURK_SECRET,
66
+ region_name=args.mturk_region,
67
+ endpoint_url=MTURK_URL,
68
+ )
69
+
70
+ # This is the URL that makes the space embeddable in an mturk iframe
71
+ question = ExternalQuestion(f"https://hf.space/embed/{args.space_name}/+?__theme=light", frame_height=600)
72
+
73
+ qualification_requirements=[]
74
+
75
+ if args.master_turkers:
76
+ qualification_requirements.append({
77
+ QualificationTypeId: '2F1QJWKUDD8XADTFD2Q0G6UTO95ALH',
78
+ Comparator: 'Exists'
79
+ })
80
+
81
+ if args.us_turkers:
82
+ qualification_requirements.append({
83
+ QualificationTypeId: '00000000000000000071',
84
+ Comparator: 'In',
85
+ LocaleValues: [
86
+ { Country: "US" },
87
+ ]
88
+ })
89
+
90
+ if args.custom_qualification_test:
91
+ qualification_type_id = (
92
+ open("qualification_type_id.txt", "r").read() if path.exists("qualification_type_id.txt") else None
93
+ )
94
+ if args.refresh_qualification_test or qualification_type_id is None:
95
+ if qualification_type_id is not None:
96
+ mturk.delete_qualification_type(QualificationTypeId=qualification_type_id)
97
+ response = mturk.create_qualification_type(
98
+ Name="rlhf--qualification",
99
+ Keywords="RLHF qualification",
100
+ Description="Qualification test for RLHF task.",
101
+ QualificationTypeStatus="Active",
102
+ Test=open("qualification_questions.xml", mode="r").read(),
103
+ AnswerKey=open("qualification_answers.xml", mode="r").read(),
104
+ TestDurationInSeconds=3600,
105
+ AutoGranted=False,
106
+ )
107
+ qualification_type_id = response["QualificationType"]["QualificationTypeId"]
108
+ open("qualification_type_id.txt", "w+").write(qualification_type_id)
109
+ qualification_requirements.append({
110
+ "QualificationTypeId": qualification_type_id,
111
+ "Comparator": "Exists",
112
+ "RequiredToPreview": False,
113
+ "ActionsGuarded": "Accept",
114
+ })
115
+
116
+ for i in range(args.num_hits):
117
+ new_hit = mturk.create_hit(
118
+ Title="RLHF HIT",
119
+ Description="Interact with an AI",
120
+ Keywords="chatbot",
121
+ Reward="0.25",
122
+ MaxAssignments=args.num_assignments,
123
+ LifetimeInSeconds=172800,
124
+ AssignmentDurationInSeconds=600,
125
+ AutoApprovalDelayInSeconds=14400,
126
+ Question=question.get_as_xml(),
127
+ QualificationRequirements=qualification_requirements,
128
+ )
129
+
130
+ print(
131
+ f"HIT Group Link: https://worker{'' if args.live_mode else 'sandbox'}.mturk.com/mturk/preview?groupId="
132
+ + new_hit["HIT"]["HITGroupId"]
133
+ )
config.py.example ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ # Fill in the information and rename this file config.py
2
+ # You can obtain the key and secret in the AWS Identity
3
+ # and Access Management (IAM) panel.
4
+
5
+ MTURK_KEY = ''
6
+ MTURK_SECRET = ''
requirements.txt ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ boto3==1.24.32
2
+ boto==2.49.0
3
+ huggingface_hub==0.8.1
4
+ python-dotenv==0.20.0
5
+ datasets==2.9.0
utils.py ADDED
@@ -0,0 +1,39 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import subprocess
2
+
3
+ from huggingface_hub.repository import _lfs_log_progress
4
+
5
+
6
+ def force_git_push(
7
+ repo,
8
+ ):
9
+ """
10
+ force a simple git push
11
+ Blocking. Will return url to commit on remote
12
+ repo.
13
+ """
14
+ command = "git push --force"
15
+
16
+ try:
17
+ with _lfs_log_progress():
18
+ process = subprocess.Popen(
19
+ command.split(),
20
+ stderr=subprocess.PIPE,
21
+ stdout=subprocess.PIPE,
22
+ encoding="utf-8",
23
+ cwd=repo.local_dir,
24
+ )
25
+
26
+ stdout, stderr = process.communicate()
27
+ return_code = process.poll()
28
+ process.kill()
29
+
30
+ if len(stderr):
31
+ print(stderr)
32
+
33
+ if return_code:
34
+ raise subprocess.CalledProcessError(return_code, process.args, output=stdout, stderr=stderr)
35
+
36
+ except subprocess.CalledProcessError as exc:
37
+ raise EnvironmentError(exc.stderr)
38
+
39
+ return repo.git_head_commit_url()