Tom Aarsen committed
Commit a4605b3 · 1 Parent(s): f9c9b72

Update training script to separate dataset loading & training

Files changed (1)
  1. train.py +192 -164
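
The new load_train_eval_datasets() helper in this commit follows a build-once, cache-on-disk pattern: it first tries to load previously prepared DatasetDicts with load_from_disk, and only if that raises FileNotFoundError does it download the sources, carve out held-out splits, save everything with save_to_disk, and then quit() so the memory-heavy train_test_split results are not carried into training; the next run picks the datasets up from disk as memory-mapped Arrow files. A condensed sketch of that pattern, using a single source dataset for brevity (the actual function below prepares nineteen sources):

from datasets import DatasetDict, load_dataset


def load_train_eval_datasets():
    """Load cached train/eval DatasetDicts from disk, or build them once, save them, and exit."""
    try:
        # Fast path: reuse the Arrow files written by a previous run (memory-mapped, not read into RAM)
        train_dataset = DatasetDict.load_from_disk("datasets/train_dataset")
        eval_dataset = DatasetDict.load_from_disk("datasets/eval_dataset")
        return train_dataset, eval_dataset
    except FileNotFoundError:
        # Slow path: download one source and split off a held-out evaluation set
        dataset = load_dataset("sentence-transformers/parallel-sentences-tatoeba", "all", split="train")
        splits = dataset.train_test_split(test_size=10_000, seed=12)

        train_dataset = DatasetDict({"tatoeba": splits["train"]})
        eval_dataset = DatasetDict({"tatoeba": splits["test"]})

        train_dataset.save_to_disk("datasets/train_dataset")
        eval_dataset.save_to_disk("datasets/eval_dataset")

        # Exit so the in-memory split objects are dropped; rerun the script to train from the on-disk copies
        quit()
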
train.py CHANGED
@@ -19,6 +19,197 @@ logging.basicConfig(
 random.seed(12)
 
 
+def load_train_eval_datasets():
+    """
+    Either load the train and eval datasets from disk or load them from the datasets library & save them to disk.
+
+    Upon saving to disk, we quit() to ensure that the datasets are not loaded into memory before training.
+    """
+    try:
+        train_dataset = DatasetDict.load_from_disk("datasets/train_dataset")
+        eval_dataset = DatasetDict.load_from_disk("datasets/eval_dataset")
+        return train_dataset, eval_dataset
+    except FileNotFoundError:
+        print("Loading wikititles dataset...")
+        wikititles_dataset = load_dataset("sentence-transformers/parallel-sentences-wikititles", split="train")
+        wikititles_dataset_dict = wikititles_dataset.train_test_split(test_size=10_000, seed=12)
+        wikititles_train_dataset: Dataset = wikititles_dataset_dict["train"]
+        wikititles_eval_dataset: Dataset = wikititles_dataset_dict["test"]
+        print("Loaded wikititles dataset.")
+
+        print("Loading tatoeba dataset...")
+        tatoeba_dataset = load_dataset("sentence-transformers/parallel-sentences-tatoeba", "all", split="train")
+        tatoeba_dataset_dict = tatoeba_dataset.train_test_split(test_size=10_000, seed=12)
+        tatoeba_train_dataset: Dataset = tatoeba_dataset_dict["train"]
+        tatoeba_eval_dataset: Dataset = tatoeba_dataset_dict["test"]
+        print("Loaded tatoeba dataset.")
+
+        print("Loading talks dataset...")
+        talks_dataset = load_dataset("sentence-transformers/parallel-sentences-talks", "all", split="train")
+        talks_dataset_dict = talks_dataset.train_test_split(test_size=10_000, seed=12)
+        talks_train_dataset: Dataset = talks_dataset_dict["train"]
+        talks_eval_dataset: Dataset = talks_dataset_dict["test"]
+        print("Loaded talks dataset.")
+
+        print("Loading europarl dataset...")
+        europarl_dataset = load_dataset("sentence-transformers/parallel-sentences-europarl", "all", split="train[:5000000]")
+        europarl_dataset_dict = europarl_dataset.train_test_split(test_size=10_000, seed=12)
+        europarl_train_dataset: Dataset = europarl_dataset_dict["train"]
+        europarl_eval_dataset: Dataset = europarl_dataset_dict["test"]
+        print("Loaded europarl dataset.")
+
+        print("Loading global voices dataset...")
+        global_voices_dataset = load_dataset("sentence-transformers/parallel-sentences-global-voices", "all", split="train")
+        global_voices_dataset_dict = global_voices_dataset.train_test_split(test_size=10_000, seed=12)
+        global_voices_train_dataset: Dataset = global_voices_dataset_dict["train"]
+        global_voices_eval_dataset: Dataset = global_voices_dataset_dict["test"]
+        print("Loaded global voices dataset.")
+
+        print("Loading jw300 dataset...")
+        jw300_dataset = load_dataset("sentence-transformers/parallel-sentences-jw300", "all", split="train")
+        jw300_dataset_dict = jw300_dataset.train_test_split(test_size=10_000, seed=12)
+        jw300_train_dataset: Dataset = jw300_dataset_dict["train"]
+        jw300_eval_dataset: Dataset = jw300_dataset_dict["test"]
+        print("Loaded jw300 dataset.")
+
+        print("Loading muse dataset...")
+        muse_dataset = load_dataset("sentence-transformers/parallel-sentences-muse", split="train")
+        muse_dataset_dict = muse_dataset.train_test_split(test_size=10_000, seed=12)
+        muse_train_dataset: Dataset = muse_dataset_dict["train"]
+        muse_eval_dataset: Dataset = muse_dataset_dict["test"]
+        print("Loaded muse dataset.")
+
+        print("Loading wikimatrix dataset...")
+        wikimatrix_dataset = load_dataset("sentence-transformers/parallel-sentences-wikimatrix", "all", split="train")
+        wikimatrix_dataset_dict = wikimatrix_dataset.train_test_split(test_size=10_000, seed=12)
+        wikimatrix_train_dataset: Dataset = wikimatrix_dataset_dict["train"]
+        wikimatrix_eval_dataset: Dataset = wikimatrix_dataset_dict["test"]
+        print("Loaded wikimatrix dataset.")
+
+        print("Loading opensubtitles dataset...")
+        opensubtitles_dataset = load_dataset("sentence-transformers/parallel-sentences-opensubtitles", "all", split="train[:5000000]")
+        opensubtitles_dataset_dict = opensubtitles_dataset.train_test_split(test_size=10_000, seed=12)
+        opensubtitles_train_dataset: Dataset = opensubtitles_dataset_dict["train"]
+        opensubtitles_eval_dataset: Dataset = opensubtitles_dataset_dict["test"]
+        print("Loaded opensubtitles dataset.")
+
+        print("Loading stackexchange dataset...")
+        stackexchange_dataset = load_dataset("sentence-transformers/stackexchange-duplicates", "post-post-pair", split="train")
+        stackexchange_dataset_dict = stackexchange_dataset.train_test_split(test_size=10_000, seed=12)
+        stackexchange_train_dataset: Dataset = stackexchange_dataset_dict["train"]
+        stackexchange_eval_dataset: Dataset = stackexchange_dataset_dict["test"]
+        print("Loaded stackexchange dataset.")
+
+        print("Loading quora dataset...")
+        quora_dataset = load_dataset("sentence-transformers/quora-duplicates", "triplet", split="train")
+        quora_dataset_dict = quora_dataset.train_test_split(test_size=10_000, seed=12)
+        quora_train_dataset: Dataset = quora_dataset_dict["train"]
+        quora_eval_dataset: Dataset = quora_dataset_dict["test"]
+        print("Loaded quora dataset.")
+
+        print("Loading wikianswers duplicates dataset...")
+        wikianswers_duplicates_dataset = load_dataset("sentence-transformers/wikianswers-duplicates", split="train[:10000000]")
+        wikianswers_duplicates_dict = wikianswers_duplicates_dataset.train_test_split(test_size=10_000, seed=12)
+        wikianswers_duplicates_train_dataset: Dataset = wikianswers_duplicates_dict["train"]
+        wikianswers_duplicates_eval_dataset: Dataset = wikianswers_duplicates_dict["test"]
+        print("Loaded wikianswers duplicates dataset.")
+
+        print("Loading all nli dataset...")
+        all_nli_train_dataset = load_dataset("sentence-transformers/all-nli", "triplet", split="train")
+        all_nli_eval_dataset = load_dataset("sentence-transformers/all-nli", "triplet", split="dev")
+        print("Loaded all nli dataset.")
+
+        print("Loading simple wiki dataset...")
+        simple_wiki_dataset = load_dataset("sentence-transformers/simple-wiki", split="train")
+        simple_wiki_dataset_dict = simple_wiki_dataset.train_test_split(test_size=10_000, seed=12)
+        simple_wiki_train_dataset: Dataset = simple_wiki_dataset_dict["train"]
+        simple_wiki_eval_dataset: Dataset = simple_wiki_dataset_dict["test"]
+        print("Loaded simple wiki dataset.")
+
+        print("Loading altlex dataset...")
+        altlex_dataset = load_dataset("sentence-transformers/altlex", split="train")
+        altlex_dataset_dict = altlex_dataset.train_test_split(test_size=10_000, seed=12)
+        altlex_train_dataset: Dataset = altlex_dataset_dict["train"]
+        altlex_eval_dataset: Dataset = altlex_dataset_dict["test"]
+        print("Loaded altlex dataset.")
+
+        print("Loading flickr30k captions dataset...")
+        flickr30k_captions_dataset = load_dataset("sentence-transformers/flickr30k-captions", split="train")
+        flickr30k_captions_dataset_dict = flickr30k_captions_dataset.train_test_split(test_size=10_000, seed=12)
+        flickr30k_captions_train_dataset: Dataset = flickr30k_captions_dataset_dict["train"]
+        flickr30k_captions_eval_dataset: Dataset = flickr30k_captions_dataset_dict["test"]
+        print("Loaded flickr30k captions dataset.")
+
+        print("Loading coco captions dataset...")
+        coco_captions_dataset = load_dataset("sentence-transformers/coco-captions", split="train")
+        coco_captions_dataset_dict = coco_captions_dataset.train_test_split(test_size=10_000, seed=12)
+        coco_captions_train_dataset: Dataset = coco_captions_dataset_dict["train"]
+        coco_captions_eval_dataset: Dataset = coco_captions_dataset_dict["test"]
+        print("Loaded coco captions dataset.")
+
+        print("Loading nli for simcse dataset...")
+        nli_for_simcse_dataset = load_dataset("sentence-transformers/nli-for-simcse", "triplet", split="train")
+        nli_for_simcse_dataset_dict = nli_for_simcse_dataset.train_test_split(test_size=10_000, seed=12)
+        nli_for_simcse_train_dataset: Dataset = nli_for_simcse_dataset_dict["train"]
+        nli_for_simcse_eval_dataset: Dataset = nli_for_simcse_dataset_dict["test"]
+        print("Loaded nli for simcse dataset.")
+
+        print("Loading negation dataset...")
+        negation_dataset = load_dataset("jinaai/negation-dataset", split="train")
+        negation_dataset_dict = negation_dataset.train_test_split(test_size=100, seed=12)
+        negation_train_dataset: Dataset = negation_dataset_dict["train"]
+        negation_eval_dataset: Dataset = negation_dataset_dict["test"]
+        print("Loaded negation dataset.")
+
+        train_dataset = DatasetDict({
+            "wikititles": wikititles_train_dataset,
+            "tatoeba": tatoeba_train_dataset,
+            "talks": talks_train_dataset,
+            "europarl": europarl_train_dataset,
+            "global_voices": global_voices_train_dataset,
+            "jw300": jw300_train_dataset,
+            "muse": muse_train_dataset,
+            "wikimatrix": wikimatrix_train_dataset,
+            "opensubtitles": opensubtitles_train_dataset,
+            "stackexchange": stackexchange_train_dataset,
+            "quora": quora_train_dataset,
+            "wikianswers_duplicates": wikianswers_duplicates_train_dataset,
+            "all_nli": all_nli_train_dataset,
+            "simple_wiki": simple_wiki_train_dataset,
+            "altlex": altlex_train_dataset,
+            "flickr30k_captions": flickr30k_captions_train_dataset,
+            "coco_captions": coco_captions_train_dataset,
+            "nli_for_simcse": nli_for_simcse_train_dataset,
+            "negation": negation_train_dataset,
+        })
+        eval_dataset = DatasetDict({
+            "wikititles": wikititles_eval_dataset,
+            "tatoeba": tatoeba_eval_dataset,
+            "talks": talks_eval_dataset,
+            "europarl": europarl_eval_dataset,
+            "global_voices": global_voices_eval_dataset,
+            "jw300": jw300_eval_dataset,
+            "muse": muse_eval_dataset,
+            "wikimatrix": wikimatrix_eval_dataset,
+            "opensubtitles": opensubtitles_eval_dataset,
+            "stackexchange": stackexchange_eval_dataset,
+            "quora": quora_eval_dataset,
+            "wikianswers_duplicates": wikianswers_duplicates_eval_dataset,
+            "all_nli": all_nli_eval_dataset,
+            "simple_wiki": simple_wiki_eval_dataset,
+            "altlex": altlex_eval_dataset,
+            "flickr30k_captions": flickr30k_captions_eval_dataset,
+            "coco_captions": coco_captions_eval_dataset,
+            "nli_for_simcse": nli_for_simcse_eval_dataset,
+            "negation": negation_eval_dataset,
+        })
+
+        train_dataset.save_to_disk("datasets/train_dataset")
+        eval_dataset.save_to_disk("datasets/eval_dataset")
+
+        # The `train_test_split` calls have put a lot of the datasets in memory, while we want it to just be on disk
+        quit()
+
 def main():
     # 1. Load a model to finetune with 2. (Optional) model card data
     static_embedding = StaticEmbedding(AutoTokenizer.from_pretrained("google-bert/bert-base-multilingual-uncased"), embedding_dim=1024)
@@ -31,170 +222,7 @@ def main():
     )
 
     # 3. Set up training & evaluation datasets - each dataset is trained with MNRL (with MRL)
-    print("Loading wikititles dataset...")
-    wikititles_dataset = load_dataset("sentence-transformers/parallel-sentences-wikititles", split="train")
-    wikititles_dataset_dict = wikititles_dataset.train_test_split(test_size=10_000, seed=12)
-    wikititles_train_dataset: Dataset = wikititles_dataset_dict["train"]
-    wikititles_eval_dataset: Dataset = wikititles_dataset_dict["test"]
-    print("Loaded wikititles dataset.")
-
-    print("Loading tatoeba dataset...")
-    tatoeba_dataset = load_dataset("sentence-transformers/parallel-sentences-tatoeba", "all", split="train")
-    tatoeba_dataset_dict = tatoeba_dataset.train_test_split(test_size=10_000, seed=12)
-    tatoeba_train_dataset: Dataset = tatoeba_dataset_dict["train"]
-    tatoeba_eval_dataset: Dataset = tatoeba_dataset_dict["test"]
-    print("Loaded tatoeba dataset.")
-
-    print("Loading talks dataset...")
-    talks_dataset = load_dataset("sentence-transformers/parallel-sentences-talks", "all", split="train")
-    talks_dataset_dict = talks_dataset.train_test_split(test_size=10_000, seed=12)
-    talks_train_dataset: Dataset = talks_dataset_dict["train"]
-    talks_eval_dataset: Dataset = talks_dataset_dict["test"]
-    print("Loaded talks dataset.")
-
-    print("Loading europarl dataset...")
-    europarl_dataset = load_dataset("sentence-transformers/parallel-sentences-europarl", "all", split="train[:5000000]")
-    europarl_dataset_dict = europarl_dataset.train_test_split(test_size=10_000, seed=12)
-    europarl_train_dataset: Dataset = europarl_dataset_dict["train"]
-    europarl_eval_dataset: Dataset = europarl_dataset_dict["test"]
-    print("Loaded europarl dataset.")
-
-    print("Loading global voices dataset...")
-    global_voices_dataset = load_dataset("sentence-transformers/parallel-sentences-global-voices", "all", split="train")
-    global_voices_dataset_dict = global_voices_dataset.train_test_split(test_size=10_000, seed=12)
-    global_voices_train_dataset: Dataset = global_voices_dataset_dict["train"]
-    global_voices_eval_dataset: Dataset = global_voices_dataset_dict["test"]
-    print("Loaded global voices dataset.")
-
-    print("Loading muse dataset...")
-    muse_dataset = load_dataset("sentence-transformers/parallel-sentences-muse", split="train")
-    muse_dataset_dict = muse_dataset.train_test_split(test_size=10_000, seed=12)
-    muse_train_dataset: Dataset = muse_dataset_dict["train"]
-    muse_eval_dataset: Dataset = muse_dataset_dict["test"]
-    print("Loaded muse dataset.")
-
-    print("Loading wikimatrix dataset...")
-    wikimatrix_dataset = load_dataset("sentence-transformers/parallel-sentences-wikimatrix", "all", split="train")
-    wikimatrix_dataset_dict = wikimatrix_dataset.train_test_split(test_size=10_000, seed=12)
-    wikimatrix_train_dataset: Dataset = wikimatrix_dataset_dict["train"]
-    wikimatrix_eval_dataset: Dataset = wikimatrix_dataset_dict["test"]
-    print("Loaded wikimatrix dataset.")
-
-    print("Loading opensubtitles dataset...")
-    opensubtitles_dataset = load_dataset("sentence-transformers/parallel-sentences-opensubtitles", "all", split="train[:5000000]")
-    opensubtitles_dataset_dict = opensubtitles_dataset.train_test_split(test_size=10_000, seed=12)
-    opensubtitles_train_dataset: Dataset = opensubtitles_dataset_dict["train"]
-    opensubtitles_eval_dataset: Dataset = opensubtitles_dataset_dict["test"]
-    print("Loaded opensubtitles dataset.")
-
-    print("Loading stackexchange dataset...")
-    stackexchange_dataset = load_dataset("sentence-transformers/stackexchange-duplicates", "post-post-pair", split="train")
-    stackexchange_dataset_dict = stackexchange_dataset.train_test_split(test_size=10_000, seed=12)
-    stackexchange_train_dataset: Dataset = stackexchange_dataset_dict["train"]
-    stackexchange_eval_dataset: Dataset = stackexchange_dataset_dict["test"]
-    print("Loaded stackexchange dataset.")
-
-    print("Loading quora dataset...")
-    quora_dataset = load_dataset("sentence-transformers/quora-duplicates", "triplet", split="train")
-    quora_dataset_dict = quora_dataset.train_test_split(test_size=10_000, seed=12)
-    quora_train_dataset: Dataset = quora_dataset_dict["train"]
-    quora_eval_dataset: Dataset = quora_dataset_dict["test"]
-    print("Loaded quora dataset.")
-
-    print("Loading wikianswers duplicates dataset...")
-    wikianswers_duplicates_dataset = load_dataset("sentence-transformers/wikianswers-duplicates", split="train[:10000000]")
-    wikianswers_duplicates_dict = wikianswers_duplicates_dataset.train_test_split(test_size=10_000, seed=12)
-    wikianswers_duplicates_train_dataset: Dataset = wikianswers_duplicates_dict["train"]
-    wikianswers_duplicates_eval_dataset: Dataset = wikianswers_duplicates_dict["test"]
-    print("Loaded wikianswers duplicates dataset.")
-
-    print("Loading all nli dataset...")
-    all_nli_train_dataset = load_dataset("sentence-transformers/all-nli", "triplet", split="train")
-    all_nli_eval_dataset = load_dataset("sentence-transformers/all-nli", "triplet", split="dev")
-    print("Loaded all nli dataset.")
-
-    print("Loading simple wiki dataset...")
-    simple_wiki_dataset = load_dataset("sentence-transformers/simple-wiki", split="train")
-    simple_wiki_dataset_dict = simple_wiki_dataset.train_test_split(test_size=10_000, seed=12)
-    simple_wiki_train_dataset: Dataset = simple_wiki_dataset_dict["train"]
-    simple_wiki_eval_dataset: Dataset = simple_wiki_dataset_dict["test"]
-    print("Loaded simple wiki dataset.")
-
-    print("Loading altlex dataset...")
-    altlex_dataset = load_dataset("sentence-transformers/altlex", split="train")
-    altlex_dataset_dict = altlex_dataset.train_test_split(test_size=10_000, seed=12)
-    altlex_train_dataset: Dataset = altlex_dataset_dict["train"]
-    altlex_eval_dataset: Dataset = altlex_dataset_dict["test"]
-    print("Loaded altlex dataset.")
-
-    print("Loading flickr30k captions dataset...")
-    flickr30k_captions_dataset = load_dataset("sentence-transformers/flickr30k-captions", split="train")
-    flickr30k_captions_dataset_dict = flickr30k_captions_dataset.train_test_split(test_size=10_000, seed=12)
-    flickr30k_captions_train_dataset: Dataset = flickr30k_captions_dataset_dict["train"]
-    flickr30k_captions_eval_dataset: Dataset = flickr30k_captions_dataset_dict["test"]
-    print("Loaded flickr30k captions dataset.")
-
-    print("Loading coco captions dataset...")
-    coco_captions_dataset = load_dataset("sentence-transformers/coco-captions", split="train")
-    coco_captions_dataset_dict = coco_captions_dataset.train_test_split(test_size=10_000, seed=12)
-    coco_captions_train_dataset: Dataset = coco_captions_dataset_dict["train"]
-    coco_captions_eval_dataset: Dataset = coco_captions_dataset_dict["test"]
-    print("Loaded coco captions dataset.")
-
-    print("Loading nli for simcse dataset...")
-    nli_for_simcse_dataset = load_dataset("sentence-transformers/nli-for-simcse", "triplet", split="train")
-    nli_for_simcse_dataset_dict = nli_for_simcse_dataset.train_test_split(test_size=10_000, seed=12)
-    nli_for_simcse_train_dataset: Dataset = nli_for_simcse_dataset_dict["train"]
-    nli_for_simcse_eval_dataset: Dataset = nli_for_simcse_dataset_dict["test"]
-    print("Loaded nli for simcse dataset.")
-
-    print("Loading negation dataset...")
-    negation_dataset = load_dataset("jinaai/negation-dataset", split="train")
-    negation_dataset_dict = negation_dataset.train_test_split(test_size=100, seed=12)
-    negation_train_dataset: Dataset = negation_dataset_dict["train"]
-    negation_eval_dataset: Dataset = negation_dataset_dict["test"]
-    print("Loaded negation dataset.")
-
-    train_dataset = DatasetDict({
-        "wikititles": wikititles_train_dataset,
-        "tatoeba": tatoeba_train_dataset,
-        "talks": talks_train_dataset,
-        "europarl": europarl_train_dataset,
-        "global_voices": global_voices_train_dataset,
-        "muse": muse_train_dataset,
-        "wikimatrix": wikimatrix_train_dataset,
-        "opensubtitles": opensubtitles_train_dataset,
-        "stackexchange": stackexchange_train_dataset,
-        "quora": quora_train_dataset,
-        "wikianswers_duplicates": wikianswers_duplicates_train_dataset,
-        "all_nli": all_nli_train_dataset,
-        "simple_wiki": simple_wiki_train_dataset,
-        "altlex": altlex_train_dataset,
-        "flickr30k_captions": flickr30k_captions_train_dataset,
-        "coco_captions": coco_captions_train_dataset,
-        "nli_for_simcse": nli_for_simcse_train_dataset,
-        "negation": negation_train_dataset,
-    })
-    eval_dataset = DatasetDict({
-        "wikititles": wikititles_eval_dataset,
-        "tatoeba": tatoeba_eval_dataset,
-        "talks": talks_eval_dataset,
-        "europarl": europarl_eval_dataset,
-        "global_voices": global_voices_eval_dataset,
-        "muse": muse_eval_dataset,
-        "wikimatrix": wikimatrix_eval_dataset,
-        "opensubtitles": opensubtitles_eval_dataset,
-        "stackexchange": stackexchange_eval_dataset,
-        "quora": quora_eval_dataset,
-        "wikianswers_duplicates": wikianswers_duplicates_eval_dataset,
-        "all_nli": all_nli_eval_dataset,
-        "simple_wiki": simple_wiki_eval_dataset,
-        "altlex": altlex_eval_dataset,
-        "flickr30k_captions": flickr30k_captions_eval_dataset,
-        "coco_captions": coco_captions_eval_dataset,
-        "nli_for_simcse": nli_for_simcse_eval_dataset,
-        "negation": negation_eval_dataset,
-    })
+    train_dataset, eval_dataset = load_train_eval_datasets()
 
     print(train_dataset)
 
     # 4. Define a loss function
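
Because the builder branch ends in quit() after save_to_disk, the script is effectively run twice: once to materialize datasets/train_dataset and datasets/eval_dataset on disk, and a second time to train from those cached copies. After the first run, the cache can also be inspected on its own; a minimal sketch, assuming the default paths used in the script:

from datasets import DatasetDict

# Both directories are written by load_train_eval_datasets() on the first run
train_dataset = DatasetDict.load_from_disk("datasets/train_dataset")
eval_dataset = DatasetDict.load_from_disk("datasets/eval_dataset")

print(train_dataset)                         # one split per source: "wikititles", "tatoeba", ..., "negation"
print(eval_dataset["wikititles"].num_rows)   # 10,000 held-out rows (test_size=10_000, seed=12)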