import random
import logging

from datasets import load_dataset, Dataset, DatasetDict
from sentence_transformers import (
    SentenceTransformer,
    SentenceTransformerTrainer,
    SentenceTransformerTrainingArguments,
    SentenceTransformerModelCardData,
)
from sentence_transformers.losses import MatryoshkaLoss, MultipleNegativesRankingLoss
from sentence_transformers.training_args import BatchSamplers, MultiDatasetBatchSamplers
from sentence_transformers.models.StaticEmbedding import StaticEmbedding
from transformers import AutoTokenizer

logging.basicConfig(
    format="%(asctime)s - %(message)s", datefmt="%Y-%m-%d %H:%M:%S", level=logging.INFO
)
random.seed(12)


def load_train_eval_datasets():
    """
    Either load the train and eval datasets from disk or load them from the datasets library & save them to disk.

    Upon saving to disk, we quit() to ensure that the datasets are not loaded into memory before training.
    """
    try:
        train_dataset = DatasetDict.load_from_disk("datasets/train_dataset")
        eval_dataset = DatasetDict.load_from_disk("datasets/eval_dataset")
        return train_dataset, eval_dataset
    except FileNotFoundError:
        print("Loading wikititles dataset...")
        wikititles_dataset = load_dataset("sentence-transformers/parallel-sentences-wikititles", split="train")
        wikititles_dataset_dict = wikititles_dataset.train_test_split(test_size=10_000, seed=12)
        wikititles_train_dataset: Dataset = wikititles_dataset_dict["train"]
        wikititles_eval_dataset: Dataset = wikititles_dataset_dict["test"]
        print("Loaded wikititles dataset.")

        print("Loading tatoeba dataset...")
        tatoeba_dataset = load_dataset("sentence-transformers/parallel-sentences-tatoeba", "all", split="train")
        tatoeba_dataset_dict = tatoeba_dataset.train_test_split(test_size=10_000, seed=12)
        tatoeba_train_dataset: Dataset = tatoeba_dataset_dict["train"]
        tatoeba_eval_dataset: Dataset = tatoeba_dataset_dict["test"]
        print("Loaded tatoeba dataset.")

        print("Loading talks dataset...")
        talks_dataset = load_dataset("sentence-transformers/parallel-sentences-talks", "all", split="train")
        talks_dataset_dict = talks_dataset.train_test_split(test_size=10_000, seed=12)
        talks_train_dataset: Dataset = talks_dataset_dict["train"]
        talks_eval_dataset: Dataset = talks_dataset_dict["test"]
        print("Loaded talks dataset.")

        print("Loading europarl dataset...")
        europarl_dataset = load_dataset("sentence-transformers/parallel-sentences-europarl", "all", split="train[:5000000]")
        europarl_dataset_dict = europarl_dataset.train_test_split(test_size=10_000, seed=12)
        europarl_train_dataset: Dataset = europarl_dataset_dict["train"]
        europarl_eval_dataset: Dataset = europarl_dataset_dict["test"]
        print("Loaded europarl dataset.")

        print("Loading global voices dataset...")
        global_voices_dataset = load_dataset("sentence-transformers/parallel-sentences-global-voices", "all", split="train")
        global_voices_dataset_dict = global_voices_dataset.train_test_split(test_size=10_000, seed=12)
        global_voices_train_dataset: Dataset = global_voices_dataset_dict["train"]
        global_voices_eval_dataset: Dataset = global_voices_dataset_dict["test"]
        print("Loaded global voices dataset.")

        print("Loading jw300 dataset...")
        jw300_dataset = load_dataset("sentence-transformers/parallel-sentences-jw300", "all", split="train")
        jw300_dataset_dict = jw300_dataset.train_test_split(test_size=10_000, seed=12)
        jw300_train_dataset: Dataset = jw300_dataset_dict["train"]
        jw300_eval_dataset: Dataset = jw300_dataset_dict["test"]
        print("Loaded jw300 dataset.")

        print("Loading muse dataset...")
        muse_dataset = load_dataset("sentence-transformers/parallel-sentences-muse", split="train")
        muse_dataset_dict = muse_dataset.train_test_split(test_size=10_000, seed=12)
        muse_train_dataset: Dataset = muse_dataset_dict["train"]
        muse_eval_dataset: Dataset = muse_dataset_dict["test"]
        print("Loaded muse dataset.")

        print("Loading wikimatrix dataset...")
        wikimatrix_dataset = load_dataset("sentence-transformers/parallel-sentences-wikimatrix", "all", split="train")
        wikimatrix_dataset_dict = wikimatrix_dataset.train_test_split(test_size=10_000, seed=12)
        wikimatrix_train_dataset: Dataset = wikimatrix_dataset_dict["train"]
        wikimatrix_eval_dataset: Dataset = wikimatrix_dataset_dict["test"]
        print("Loaded wikimatrix dataset.")

        print("Loading opensubtitles dataset...")
        opensubtitles_dataset = load_dataset("sentence-transformers/parallel-sentences-opensubtitles", "all", split="train[:5000000]")
        opensubtitles_dataset_dict = opensubtitles_dataset.train_test_split(test_size=10_000, seed=12)
        opensubtitles_train_dataset: Dataset = opensubtitles_dataset_dict["train"]
        opensubtitles_eval_dataset: Dataset = opensubtitles_dataset_dict["test"]
        print("Loaded opensubtitles dataset.")

        print("Loading stackexchange dataset...")
        stackexchange_dataset = load_dataset("sentence-transformers/stackexchange-duplicates", "post-post-pair", split="train")
        stackexchange_dataset_dict = stackexchange_dataset.train_test_split(test_size=10_000, seed=12)
        stackexchange_train_dataset: Dataset = stackexchange_dataset_dict["train"]
        stackexchange_eval_dataset: Dataset = stackexchange_dataset_dict["test"]
        print("Loaded stackexchange dataset.")

        print("Loading quora dataset...")
        quora_dataset = load_dataset("sentence-transformers/quora-duplicates", "triplet", split="train")
        quora_dataset_dict = quora_dataset.train_test_split(test_size=10_000, seed=12)
        quora_train_dataset: Dataset = quora_dataset_dict["train"]
        quora_eval_dataset: Dataset = quora_dataset_dict["test"]
        print("Loaded quora dataset.")

        print("Loading wikianswers duplicates dataset...")
        wikianswers_duplicates_dataset = load_dataset("sentence-transformers/wikianswers-duplicates", split="train[:10000000]")
        wikianswers_duplicates_dict = wikianswers_duplicates_dataset.train_test_split(test_size=10_000, seed=12)
        wikianswers_duplicates_train_dataset: Dataset = wikianswers_duplicates_dict["train"]
        wikianswers_duplicates_eval_dataset: Dataset = wikianswers_duplicates_dict["test"]
        print("Loaded wikianswers duplicates dataset.")

        print("Loading all nli dataset...")
        all_nli_train_dataset = load_dataset("sentence-transformers/all-nli", "triplet", split="train")
        all_nli_eval_dataset = load_dataset("sentence-transformers/all-nli", "triplet", split="dev")
        print("Loaded all nli dataset.")

        print("Loading simple wiki dataset...")
        simple_wiki_dataset = load_dataset("sentence-transformers/simple-wiki", split="train")
        simple_wiki_dataset_dict = simple_wiki_dataset.train_test_split(test_size=10_000, seed=12)
        simple_wiki_train_dataset: Dataset = simple_wiki_dataset_dict["train"]
        simple_wiki_eval_dataset: Dataset = simple_wiki_dataset_dict["test"]
        print("Loaded simple wiki dataset.")

        print("Loading altlex dataset...")
        altlex_dataset = load_dataset("sentence-transformers/altlex", split="train")
        altlex_dataset_dict = altlex_dataset.train_test_split(test_size=10_000, seed=12)
        altlex_train_dataset: Dataset = altlex_dataset_dict["train"]
        altlex_eval_dataset: Dataset = altlex_dataset_dict["test"]
        print("Loaded altlex dataset.")

        print("Loading flickr30k captions dataset...")
        flickr30k_captions_dataset = load_dataset("sentence-transformers/flickr30k-captions", split="train")
        flickr30k_captions_dataset_dict = flickr30k_captions_dataset.train_test_split(test_size=10_000, seed=12)
        flickr30k_captions_train_dataset: Dataset = flickr30k_captions_dataset_dict["train"]
        flickr30k_captions_eval_dataset: Dataset = flickr30k_captions_dataset_dict["test"]
        print("Loaded flickr30k captions dataset.")

        print("Loading coco captions dataset...")
        coco_captions_dataset = load_dataset("sentence-transformers/coco-captions", split="train")
        coco_captions_dataset_dict = coco_captions_dataset.train_test_split(test_size=10_000, seed=12)
        coco_captions_train_dataset: Dataset = coco_captions_dataset_dict["train"]
        coco_captions_eval_dataset: Dataset = coco_captions_dataset_dict["test"]
        print("Loaded coco captions dataset.")

        print("Loading nli for simcse dataset...")
        nli_for_simcse_dataset = load_dataset("sentence-transformers/nli-for-simcse", "triplet", split="train")
        nli_for_simcse_dataset_dict = nli_for_simcse_dataset.train_test_split(test_size=10_000, seed=12)
        nli_for_simcse_train_dataset: Dataset = nli_for_simcse_dataset_dict["train"]
        nli_for_simcse_eval_dataset: Dataset = nli_for_simcse_dataset_dict["test"]
        print("Loaded nli for simcse dataset.")

        print("Loading negation dataset...")
        negation_dataset = load_dataset("jinaai/negation-dataset", split="train")
        negation_dataset_dict = negation_dataset.train_test_split(test_size=100, seed=12)
        negation_train_dataset: Dataset = negation_dataset_dict["train"]
        negation_eval_dataset: Dataset = negation_dataset_dict["test"]
        print("Loaded negation dataset.")

        train_dataset = DatasetDict({
            "wikititles": wikititles_train_dataset,
            "tatoeba": tatoeba_train_dataset,
            "talks": talks_train_dataset,
            "europarl": europarl_train_dataset,
            "global_voices": global_voices_train_dataset,
            "jw300": jw300_train_dataset,
            "muse": muse_train_dataset,
            "wikimatrix": wikimatrix_train_dataset,
            "opensubtitles": opensubtitles_train_dataset,
            "stackexchange": stackexchange_train_dataset,
            "quora": quora_train_dataset,
            "wikianswers_duplicates": wikianswers_duplicates_train_dataset,
            "all_nli": all_nli_train_dataset,
            "simple_wiki": simple_wiki_train_dataset,
            "altlex": altlex_train_dataset,
            "flickr30k_captions": flickr30k_captions_train_dataset,
            "coco_captions": coco_captions_train_dataset,
            "nli_for_simcse": nli_for_simcse_train_dataset,
            "negation": negation_train_dataset,
        })
        eval_dataset = DatasetDict({
            "wikititles": wikititles_eval_dataset,
            "tatoeba": tatoeba_eval_dataset,
            "talks": talks_eval_dataset,
            "europarl": europarl_eval_dataset,
            "global_voices": global_voices_eval_dataset,
            "jw300": jw300_eval_dataset,
            "muse": muse_eval_dataset,
            "wikimatrix": wikimatrix_eval_dataset,
            "opensubtitles": opensubtitles_eval_dataset,
            "stackexchange": stackexchange_eval_dataset,
            "quora": quora_eval_dataset,
            "wikianswers_duplicates": wikianswers_duplicates_eval_dataset,
            "all_nli": all_nli_eval_dataset,
            "simple_wiki": simple_wiki_eval_dataset,
            "altlex": altlex_eval_dataset,
            "flickr30k_captions": flickr30k_captions_eval_dataset,
            "coco_captions": coco_captions_eval_dataset,
            "nli_for_simcse": nli_for_simcse_eval_dataset,
            "negation": negation_eval_dataset,
        })

        train_dataset.save_to_disk("datasets/train_dataset")
        eval_dataset.save_to_disk("datasets/eval_dataset")

        # The `train_test_split` calls have put a lot of the datasets in memory, while we want it to just be on disk.
        # So we're calling quit() here. Running the script again will load the datasets from disk.
        quit()
def main():
    # 1. Load a model to finetune with 2. (Optional) model card data
    static_embedding = StaticEmbedding(
        AutoTokenizer.from_pretrained("google-bert/bert-base-multilingual-uncased"),
        embedding_dim=1024,
    )
    model = SentenceTransformer(
        modules=[static_embedding],
        model_card_data=SentenceTransformerModelCardData(
            license="apache-2.0",
            model_name="Static Embeddings with BERT Multilingual uncased tokenizer finetuned on various datasets",
        ),
    )

    # 3. Set up training & evaluation datasets - each dataset is trained with MNRL (with MRL)
    train_dataset, eval_dataset = load_train_eval_datasets()
    print(train_dataset)

    # 4. Define a loss function
    loss = MultipleNegativesRankingLoss(model)
    loss = MatryoshkaLoss(model, loss, matryoshka_dims=[32, 64, 128, 256, 512, 1024])

    # 5. (Optional) Specify training arguments
    run_name = "static-similarity-mrl-multilingual-v1"
    args = SentenceTransformerTrainingArguments(
        # Required parameter:
        output_dir=f"models/{run_name}",
        # Optional training parameters:
        num_train_epochs=1,
        per_device_train_batch_size=2048,
        per_device_eval_batch_size=2048,
        learning_rate=2e-1,
        warmup_ratio=0.1,
        fp16=False,  # Set to False if you get an error that your GPU can't run on FP16
        bf16=True,  # Set to True if you have a GPU that supports BF16
        batch_sampler=BatchSamplers.NO_DUPLICATES,  # MultipleNegativesRankingLoss benefits from no duplicate samples in a batch
        multi_dataset_batch_sampler=MultiDatasetBatchSamplers.PROPORTIONAL,
        # Optional tracking/debugging parameters:
        eval_strategy="steps",
        eval_steps=1000,
        save_strategy="steps",
        save_steps=1000,
        save_total_limit=2,
        logging_steps=1000,
        logging_first_step=True,
        run_name=run_name,  # Will be used in W&B if `wandb` is installed
    )

    # 6. Create a trainer & train
    trainer = SentenceTransformerTrainer(
        model=model,
        args=args,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        loss=loss,
    )
    trainer.train()

    # 7. Save the trained model
    model.save_pretrained(f"models/{run_name}/final")

    # 8. (Optional) Push it to the Hugging Face Hub
    model.push_to_hub(run_name, private=True)


if __name__ == "__main__":
    main()
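

# Example usage (sketch, not part of the training run above): once training has finished,
# the saved model can be loaded like any other SentenceTransformer. The path assumes the
# `run_name` used above; `truncate_dim` picks one of the Matryoshka dimensions trained via
# MatryoshkaLoss. Exact similarity scores will depend on the trained weights.
#
# from sentence_transformers import SentenceTransformer
#
# model = SentenceTransformer("models/static-similarity-mrl-multilingual-v1/final", truncate_dim=256)
# embeddings = model.encode(["The weather is lovely today.", "Il fait beau aujourd'hui."])
# print(model.similarity(embeddings[0:1], embeddings[1:2]))  # similarity between the two sentences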