tomaarsen (HF staff) committed
Commit a0fb398 · verified · 1 Parent(s): bae9c8b

Create train.py

Files changed (1):
  train.py +256 -0
train.py ADDED
@@ -0,0 +1,256 @@
import random
import logging
from datasets import load_dataset, Dataset, DatasetDict
from sentence_transformers import (
    SentenceTransformer,
    SentenceTransformerTrainer,
    SentenceTransformerTrainingArguments,
    SentenceTransformerModelCardData,
)
from sentence_transformers.losses import MatryoshkaLoss, MultipleNegativesRankingLoss
from sentence_transformers.training_args import BatchSamplers, MultiDatasetBatchSamplers
from sentence_transformers.models.StaticEmbedding import StaticEmbedding

from transformers import AutoTokenizer

logging.basicConfig(
    format="%(asctime)s - %(message)s", datefmt="%Y-%m-%d %H:%M:%S", level=logging.INFO
)
random.seed(12)


def main():
    # 1. Load a model to finetune with 2. (Optional) model card data
    static_embedding = StaticEmbedding(AutoTokenizer.from_pretrained("google-bert/bert-base-multilingual-uncased"), embedding_dim=1024)
    model = SentenceTransformer(
        modules=[static_embedding],
        model_card_data=SentenceTransformerModelCardData(
            license="apache-2.0",
            model_name="Static Embeddings with BERT Multilingual uncased tokenizer finetuned on various datasets",
        ),
    )

    # 3. Set up training & evaluation datasets - each dataset is trained with MNRL (with MRL)
    print("Loading wikititles dataset...")
    wikititles_dataset = load_dataset("sentence-transformers/parallel-sentences-wikititles", split="train")
    wikititles_dataset_dict = wikititles_dataset.train_test_split(test_size=10_000, seed=12)
    wikititles_train_dataset: Dataset = wikititles_dataset_dict["train"]
    wikititles_eval_dataset: Dataset = wikititles_dataset_dict["test"]
    print("Loaded wikititles dataset.")

    print("Loading tatoeba dataset...")
    tatoeba_dataset = load_dataset("sentence-transformers/parallel-sentences-tatoeba", "all", split="train")
    tatoeba_dataset_dict = tatoeba_dataset.train_test_split(test_size=10_000, seed=12)
    tatoeba_train_dataset: Dataset = tatoeba_dataset_dict["train"]
    tatoeba_eval_dataset: Dataset = tatoeba_dataset_dict["test"]
    print("Loaded tatoeba dataset.")

    print("Loading talks dataset...")
    talks_dataset = load_dataset("sentence-transformers/parallel-sentences-talks", "all", split="train")
    talks_dataset_dict = talks_dataset.train_test_split(test_size=10_000, seed=12)
    talks_train_dataset: Dataset = talks_dataset_dict["train"]
    talks_eval_dataset: Dataset = talks_dataset_dict["test"]
    print("Loaded talks dataset.")

    print("Loading europarl dataset...")
    europarl_dataset = load_dataset("sentence-transformers/parallel-sentences-europarl", "all", split="train[:5000000]")
    europarl_dataset_dict = europarl_dataset.train_test_split(test_size=10_000, seed=12)
    europarl_train_dataset: Dataset = europarl_dataset_dict["train"]
    europarl_eval_dataset: Dataset = europarl_dataset_dict["test"]
    print("Loaded europarl dataset.")

    print("Loading global voices dataset...")
    global_voices_dataset = load_dataset("sentence-transformers/parallel-sentences-global-voices", "all", split="train")
    global_voices_dataset_dict = global_voices_dataset.train_test_split(test_size=10_000, seed=12)
    global_voices_train_dataset: Dataset = global_voices_dataset_dict["train"]
    global_voices_eval_dataset: Dataset = global_voices_dataset_dict["test"]
    print("Loaded global voices dataset.")

    print("Loading jw300 dataset...")
    jw300_dataset = load_dataset("sentence-transformers/parallel-sentences-jw300", "all", split="train")
    jw300_dataset_dict = jw300_dataset.train_test_split(test_size=10_000, seed=12)
    jw300_train_dataset: Dataset = jw300_dataset_dict["train"]
    jw300_eval_dataset: Dataset = jw300_dataset_dict["test"]
    print("Loaded jw300 dataset.")

    print("Loading muse dataset...")
    muse_dataset = load_dataset("sentence-transformers/parallel-sentences-muse", split="train")
    muse_dataset_dict = muse_dataset.train_test_split(test_size=10_000, seed=12)
    muse_train_dataset: Dataset = muse_dataset_dict["train"]
    muse_eval_dataset: Dataset = muse_dataset_dict["test"]
    print("Loaded muse dataset.")

    print("Loading wikimatrix dataset...")
    wikimatrix_dataset = load_dataset("sentence-transformers/parallel-sentences-wikimatrix", "all", split="train")
    wikimatrix_dataset_dict = wikimatrix_dataset.train_test_split(test_size=10_000, seed=12)
    wikimatrix_train_dataset: Dataset = wikimatrix_dataset_dict["train"]
    wikimatrix_eval_dataset: Dataset = wikimatrix_dataset_dict["test"]
    print("Loaded wikimatrix dataset.")

    print("Loading opensubtitles dataset...")
    opensubtitles_dataset = load_dataset("sentence-transformers/parallel-sentences-opensubtitles", "all", split="train[:5000000]")
    opensubtitles_dataset_dict = opensubtitles_dataset.train_test_split(test_size=10_000, seed=12)
    opensubtitles_train_dataset: Dataset = opensubtitles_dataset_dict["train"]
    opensubtitles_eval_dataset: Dataset = opensubtitles_dataset_dict["test"]
    print("Loaded opensubtitles dataset.")

    print("Loading stackexchange dataset...")
    stackexchange_dataset = load_dataset("sentence-transformers/stackexchange-duplicates", "post-post-pair", split="train")
    stackexchange_dataset_dict = stackexchange_dataset.train_test_split(test_size=10_000, seed=12)
    stackexchange_train_dataset: Dataset = stackexchange_dataset_dict["train"]
    stackexchange_eval_dataset: Dataset = stackexchange_dataset_dict["test"]
    print("Loaded stackexchange dataset.")

    print("Loading quora dataset...")
    quora_dataset = load_dataset("sentence-transformers/quora-duplicates", "triplet", split="train")
    quora_dataset_dict = quora_dataset.train_test_split(test_size=10_000, seed=12)
    quora_train_dataset: Dataset = quora_dataset_dict["train"]
    quora_eval_dataset: Dataset = quora_dataset_dict["test"]
    print("Loaded quora dataset.")

    print("Loading wikianswers duplicates dataset...")
    wikianswers_duplicates_dataset = load_dataset("sentence-transformers/wikianswers-duplicates", split="train[:10000000]")
    wikianswers_duplicates_dict = wikianswers_duplicates_dataset.train_test_split(test_size=10_000, seed=12)
    wikianswers_duplicates_train_dataset: Dataset = wikianswers_duplicates_dict["train"]
    wikianswers_duplicates_eval_dataset: Dataset = wikianswers_duplicates_dict["test"]
    print("Loaded wikianswers duplicates dataset.")

    print("Loading all nli dataset...")
    all_nli_train_dataset = load_dataset("sentence-transformers/all-nli", "triplet", split="train")
    all_nli_eval_dataset = load_dataset("sentence-transformers/all-nli", "triplet", split="dev")
    print("Loaded all nli dataset.")

    print("Loading simple wiki dataset...")
    simple_wiki_dataset = load_dataset("sentence-transformers/simple-wiki", split="train")
    simple_wiki_dataset_dict = simple_wiki_dataset.train_test_split(test_size=10_000, seed=12)
    simple_wiki_train_dataset: Dataset = simple_wiki_dataset_dict["train"]
    simple_wiki_eval_dataset: Dataset = simple_wiki_dataset_dict["test"]
    print("Loaded simple wiki dataset.")

    print("Loading altlex dataset...")
    altlex_dataset = load_dataset("sentence-transformers/altlex", split="train")
    altlex_dataset_dict = altlex_dataset.train_test_split(test_size=10_000, seed=12)
    altlex_train_dataset: Dataset = altlex_dataset_dict["train"]
    altlex_eval_dataset: Dataset = altlex_dataset_dict["test"]
    print("Loaded altlex dataset.")

    print("Loading flickr30k captions dataset...")
    flickr30k_captions_dataset = load_dataset("sentence-transformers/flickr30k-captions", split="train")
    flickr30k_captions_dataset_dict = flickr30k_captions_dataset.train_test_split(test_size=10_000, seed=12)
    flickr30k_captions_train_dataset: Dataset = flickr30k_captions_dataset_dict["train"]
    flickr30k_captions_eval_dataset: Dataset = flickr30k_captions_dataset_dict["test"]
    print("Loaded flickr30k captions dataset.")

    print("Loading coco captions dataset...")
    coco_captions_dataset = load_dataset("sentence-transformers/coco-captions", split="train")
    coco_captions_dataset_dict = coco_captions_dataset.train_test_split(test_size=10_000, seed=12)
    coco_captions_train_dataset: Dataset = coco_captions_dataset_dict["train"]
    coco_captions_eval_dataset: Dataset = coco_captions_dataset_dict["test"]
    print("Loaded coco captions dataset.")

    print("Loading nli for simcse dataset...")
    nli_for_simcse_dataset = load_dataset("sentence-transformers/nli-for-simcse", "triplet", split="train")
    nli_for_simcse_dataset_dict = nli_for_simcse_dataset.train_test_split(test_size=10_000, seed=12)
    nli_for_simcse_train_dataset: Dataset = nli_for_simcse_dataset_dict["train"]
    nli_for_simcse_eval_dataset: Dataset = nli_for_simcse_dataset_dict["test"]
    print("Loaded nli for simcse dataset.")

    print("Loading negation dataset...")
    negation_dataset = load_dataset("jinaai/negation-dataset", split="train")
    negation_dataset_dict = negation_dataset.train_test_split(test_size=100, seed=12)
    negation_train_dataset: Dataset = negation_dataset_dict["train"]
    negation_eval_dataset: Dataset = negation_dataset_dict["test"]
    print("Loaded negation dataset.")

    train_dataset = DatasetDict({
        "wikititles": wikititles_train_dataset,
        "tatoeba": tatoeba_train_dataset,
        "talks": talks_train_dataset,
        "europarl": europarl_train_dataset,
        "global_voices": global_voices_train_dataset,
        "jw300": jw300_train_dataset,
        "muse": muse_train_dataset,
        "wikimatrix": wikimatrix_train_dataset,
        "opensubtitles": opensubtitles_train_dataset,
        "stackexchange": stackexchange_train_dataset,
        "quora": quora_train_dataset,
        "wikianswers_duplicates": wikianswers_duplicates_train_dataset,
        "all_nli": all_nli_train_dataset,
        "simple_wiki": simple_wiki_train_dataset,
        "altlex": altlex_train_dataset,
        "flickr30k_captions": flickr30k_captions_train_dataset,
        "coco_captions": coco_captions_train_dataset,
        "nli_for_simcse": nli_for_simcse_train_dataset,
        "negation": negation_train_dataset,
    })
    eval_dataset = DatasetDict({
        "wikititles": wikititles_eval_dataset,
        "tatoeba": tatoeba_eval_dataset,
        "talks": talks_eval_dataset,
        "europarl": europarl_eval_dataset,
        "global_voices": global_voices_eval_dataset,
        "jw300": jw300_eval_dataset,
        "muse": muse_eval_dataset,
        "wikimatrix": wikimatrix_eval_dataset,
        "opensubtitles": opensubtitles_eval_dataset,
        "stackexchange": stackexchange_eval_dataset,
        "quora": quora_eval_dataset,
        "wikianswers_duplicates": wikianswers_duplicates_eval_dataset,
        "all_nli": all_nli_eval_dataset,
        "simple_wiki": simple_wiki_eval_dataset,
        "altlex": altlex_eval_dataset,
        "flickr30k_captions": flickr30k_captions_eval_dataset,
        "coco_captions": coco_captions_eval_dataset,
        "nli_for_simcse": nli_for_simcse_eval_dataset,
        "negation": negation_eval_dataset,
    })
    print(train_dataset)

    # 4. Define a loss function
    loss = MultipleNegativesRankingLoss(model)
    loss = MatryoshkaLoss(model, loss, matryoshka_dims=[32, 64, 128, 256, 512, 1024])

    # 5. (Optional) Specify training arguments
    run_name = "static-similarity-mrl-multilingual-v1"
    args = SentenceTransformerTrainingArguments(
        # Required parameter:
        output_dir=f"models/{run_name}",
        # Optional training parameters:
        num_train_epochs=1,
        per_device_train_batch_size=2048,
        per_device_eval_batch_size=2048,
        learning_rate=2e-1,
        warmup_ratio=0.1,
        fp16=False,  # Set to False if you get an error that your GPU can't run on FP16
        bf16=True,  # Set to True if you have a GPU that supports BF16
        batch_sampler=BatchSamplers.NO_DUPLICATES,  # MultipleNegativesRankingLoss benefits from no duplicate samples in a batch
        multi_dataset_batch_sampler=MultiDatasetBatchSamplers.PROPORTIONAL,
        # Optional tracking/debugging parameters:
        eval_strategy="steps",
        eval_steps=1000,
        save_strategy="steps",
        save_steps=1000,
        save_total_limit=2,
        logging_steps=1000,
        logging_first_step=True,
        run_name=run_name,  # Will be used in W&B if `wandb` is installed
    )

    # 6. Create a trainer & train
    trainer = SentenceTransformerTrainer(
        model=model,
        args=args,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        loss=loss,
    )
    trainer.train()

    # 7. Save the trained model
    model.save_pretrained(f"models/{run_name}/final")

    # 8. (Optional) Push it to the Hugging Face Hub
    model.push_to_hub(run_name, private=True)


if __name__ == "__main__":
    main()
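
Usage note (not part of this commit): because the script wraps MultipleNegativesRankingLoss in MatryoshkaLoss with matryoshka_dims=[32, 64, 128, 256, 512, 1024], the resulting static embedding model can be loaded with a truncated dimensionality at inference time. The sketch below assumes the model has been pushed under the repository id "tomaarsen/static-similarity-mrl-multilingual-v1" (derived from `run_name`; substitute your own repo) and picks 256 dimensions purely for illustration.

from sentence_transformers import SentenceTransformer

# Load the trained static similarity model, keeping only the first 256
# embedding dimensions; any of the Matryoshka dims trained above would work.
# The repository id is a hypothetical example based on `run_name`.
model = SentenceTransformer(
    "tomaarsen/static-similarity-mrl-multilingual-v1", truncate_dim=256
)

sentences = [
    "The weather is lovely today.",
    "It's so sunny outside!",
    "Das Wetter ist heute schön.",  # German paraphrase of the first sentence
]
embeddings = model.encode(sentences)
print(embeddings.shape)  # (3, 256)

# Pairwise similarity matrix (cosine similarity by default)
similarities = model.similarity(embeddings, embeddings)
print(similarities)

Smaller truncations trade a little accuracy for proportionally smaller embedding storage and faster similarity search, which is the main reason for training with MatryoshkaLoss here.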