import random
import logging
from datasets import load_dataset, Dataset, DatasetDict
from sentence_transformers import (
    SentenceTransformer,
    SentenceTransformerTrainer,
    SentenceTransformerTrainingArguments,
    SentenceTransformerModelCardData,
)
from sentence_transformers.losses import MatryoshkaLoss, MultipleNegativesRankingLoss
from sentence_transformers.training_args import BatchSamplers, MultiDatasetBatchSamplers
from sentence_transformers.models.StaticEmbedding import StaticEmbedding
from transformers import AutoTokenizer

logging.basicConfig(
    format="%(asctime)s - %(message)s", datefmt="%Y-%m-%d %H:%M:%S", level=logging.INFO
)
random.seed(12)


def load_train_eval_datasets():
    """
    Either load the train and eval datasets from disk, or download them with the `datasets` library and save them to disk.
    Upon saving to disk, we call quit() so that a fresh run reloads the datasets from disk instead of keeping the
    in-memory copies produced by `train_test_split` around during training.
    """
    try:
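        # `load_from_disk` memory-maps the saved Arrow files, so even these large datasets
        # are read lazily from disk rather than held in RAM.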
        train_dataset = DatasetDict.load_from_disk("datasets/train_dataset")
        eval_dataset = DatasetDict.load_from_disk("datasets/eval_dataset")
        return train_dataset, eval_dataset
    except FileNotFoundError:
        print("Loading wikititles dataset...")
        wikititles_dataset = load_dataset("sentence-transformers/parallel-sentences-wikititles", split="train")
        wikititles_dataset_dict = wikititles_dataset.train_test_split(test_size=10_000, seed=12)
        wikititles_train_dataset: Dataset = wikititles_dataset_dict["train"]
        wikititles_eval_dataset: Dataset = wikititles_dataset_dict["test"]
        print("Loaded wikititles dataset.")

        print("Loading tatoeba dataset...")
        tatoeba_dataset = load_dataset("sentence-transformers/parallel-sentences-tatoeba", "all", split="train")
        tatoeba_dataset_dict = tatoeba_dataset.train_test_split(test_size=10_000, seed=12)
        tatoeba_train_dataset: Dataset = tatoeba_dataset_dict["train"]
        tatoeba_eval_dataset: Dataset = tatoeba_dataset_dict["test"]
        print("Loaded tatoeba dataset.")

        print("Loading talks dataset...")
        talks_dataset = load_dataset("sentence-transformers/parallel-sentences-talks", "all", split="train")
        talks_dataset_dict = talks_dataset.train_test_split(test_size=10_000, seed=12)
        talks_train_dataset: Dataset = talks_dataset_dict["train"]
        talks_eval_dataset: Dataset = talks_dataset_dict["test"]
        print("Loaded talks dataset.")

        print("Loading europarl dataset...")
        europarl_dataset = load_dataset("sentence-transformers/parallel-sentences-europarl", "all", split="train[:5000000]")
        europarl_dataset_dict = europarl_dataset.train_test_split(test_size=10_000, seed=12)
        europarl_train_dataset: Dataset = europarl_dataset_dict["train"]
        europarl_eval_dataset: Dataset = europarl_dataset_dict["test"]
        print("Loaded europarl dataset.")

        print("Loading global voices dataset...")
        global_voices_dataset = load_dataset("sentence-transformers/parallel-sentences-global-voices", "all", split="train")
        global_voices_dataset_dict = global_voices_dataset.train_test_split(test_size=10_000, seed=12)
        global_voices_train_dataset: Dataset = global_voices_dataset_dict["train"]
        global_voices_eval_dataset: Dataset = global_voices_dataset_dict["test"]
        print("Loaded global voices dataset.")

        print("Loading jw300 dataset...")
        jw300_dataset = load_dataset("sentence-transformers/parallel-sentences-jw300", "all", split="train")
        jw300_dataset_dict = jw300_dataset.train_test_split(test_size=10_000, seed=12)
        jw300_train_dataset: Dataset = jw300_dataset_dict["train"]
        jw300_eval_dataset: Dataset = jw300_dataset_dict["test"]
        print("Loaded jw300 dataset.")

        print("Loading muse dataset...")
        muse_dataset = load_dataset("sentence-transformers/parallel-sentences-muse", split="train")
        muse_dataset_dict = muse_dataset.train_test_split(test_size=10_000, seed=12)
        muse_train_dataset: Dataset = muse_dataset_dict["train"]
        muse_eval_dataset: Dataset = muse_dataset_dict["test"]
        print("Loaded muse dataset.")

        print("Loading wikimatrix dataset...")
        wikimatrix_dataset = load_dataset("sentence-transformers/parallel-sentences-wikimatrix", "all", split="train")
        wikimatrix_dataset_dict = wikimatrix_dataset.train_test_split(test_size=10_000, seed=12)
        wikimatrix_train_dataset: Dataset = wikimatrix_dataset_dict["train"]
        wikimatrix_eval_dataset: Dataset = wikimatrix_dataset_dict["test"]
        print("Loaded wikimatrix dataset.")

        print("Loading opensubtitles dataset...")
        opensubtitles_dataset = load_dataset("sentence-transformers/parallel-sentences-opensubtitles", "all", split="train[:5000000]")
        opensubtitles_dataset_dict = opensubtitles_dataset.train_test_split(test_size=10_000, seed=12)
        opensubtitles_train_dataset: Dataset = opensubtitles_dataset_dict["train"]
        opensubtitles_eval_dataset: Dataset = opensubtitles_dataset_dict["test"]
        print("Loaded opensubtitles dataset.")

        print("Loading stackexchange dataset...")
        stackexchange_dataset = load_dataset("sentence-transformers/stackexchange-duplicates", "post-post-pair", split="train")
        stackexchange_dataset_dict = stackexchange_dataset.train_test_split(test_size=10_000, seed=12)
        stackexchange_train_dataset: Dataset = stackexchange_dataset_dict["train"]
        stackexchange_eval_dataset: Dataset = stackexchange_dataset_dict["test"]
        print("Loaded stackexchange dataset.")

        print("Loading quora dataset...")
        quora_dataset = load_dataset("sentence-transformers/quora-duplicates", "triplet", split="train")
        quora_dataset_dict = quora_dataset.train_test_split(test_size=10_000, seed=12)
        quora_train_dataset: Dataset = quora_dataset_dict["train"]
        quora_eval_dataset: Dataset = quora_dataset_dict["test"]
        print("Loaded quora dataset.")

        print("Loading wikianswers duplicates dataset...")
        wikianswers_duplicates_dataset = load_dataset("sentence-transformers/wikianswers-duplicates", split="train[:10000000]")
        wikianswers_duplicates_dict = wikianswers_duplicates_dataset.train_test_split(test_size=10_000, seed=12)
        wikianswers_duplicates_train_dataset: Dataset = wikianswers_duplicates_dict["train"]
        wikianswers_duplicates_eval_dataset: Dataset = wikianswers_duplicates_dict["test"]
        print("Loaded wikianswers duplicates dataset.")

        print("Loading all nli dataset...")
        all_nli_train_dataset = load_dataset("sentence-transformers/all-nli", "triplet", split="train")
        all_nli_eval_dataset = load_dataset("sentence-transformers/all-nli", "triplet", split="dev")
        print("Loaded all nli dataset.")

        print("Loading simple wiki dataset...")
        simple_wiki_dataset = load_dataset("sentence-transformers/simple-wiki", split="train")
        simple_wiki_dataset_dict = simple_wiki_dataset.train_test_split(test_size=10_000, seed=12)
        simple_wiki_train_dataset: Dataset = simple_wiki_dataset_dict["train"]
        simple_wiki_eval_dataset: Dataset = simple_wiki_dataset_dict["test"]
        print("Loaded simple wiki dataset.")

        print("Loading altlex dataset...")
        altlex_dataset = load_dataset("sentence-transformers/altlex", split="train")
        altlex_dataset_dict = altlex_dataset.train_test_split(test_size=10_000, seed=12)
        altlex_train_dataset: Dataset = altlex_dataset_dict["train"]
        altlex_eval_dataset: Dataset = altlex_dataset_dict["test"]
        print("Loaded altlex dataset.")

        print("Loading flickr30k captions dataset...")
        flickr30k_captions_dataset = load_dataset("sentence-transformers/flickr30k-captions", split="train")
        flickr30k_captions_dataset_dict = flickr30k_captions_dataset.train_test_split(test_size=10_000, seed=12)
        flickr30k_captions_train_dataset: Dataset = flickr30k_captions_dataset_dict["train"]
        flickr30k_captions_eval_dataset: Dataset = flickr30k_captions_dataset_dict["test"]
        print("Loaded flickr30k captions dataset.")

        print("Loading coco captions dataset...")
        coco_captions_dataset = load_dataset("sentence-transformers/coco-captions", split="train")
        coco_captions_dataset_dict = coco_captions_dataset.train_test_split(test_size=10_000, seed=12)
        coco_captions_train_dataset: Dataset = coco_captions_dataset_dict["train"]
        coco_captions_eval_dataset: Dataset = coco_captions_dataset_dict["test"]
        print("Loaded coco captions dataset.")

        print("Loading nli for simcse dataset...")
        nli_for_simcse_dataset = load_dataset("sentence-transformers/nli-for-simcse", "triplet", split="train")
        nli_for_simcse_dataset_dict = nli_for_simcse_dataset.train_test_split(test_size=10_000, seed=12)
        nli_for_simcse_train_dataset: Dataset = nli_for_simcse_dataset_dict["train"]
        nli_for_simcse_eval_dataset: Dataset = nli_for_simcse_dataset_dict["test"]
        print("Loaded nli for simcse dataset.")

        print("Loading negation dataset...")
        negation_dataset = load_dataset("jinaai/negation-dataset", split="train")
        negation_dataset_dict = negation_dataset.train_test_split(test_size=100, seed=12)
        negation_train_dataset: Dataset = negation_dataset_dict["train"]
        negation_eval_dataset: Dataset = negation_dataset_dict["test"]
        print("Loaded negation dataset.")

        train_dataset = DatasetDict({
            "wikititles": wikititles_train_dataset,
            "tatoeba": tatoeba_train_dataset,
            "talks": talks_train_dataset,
            "europarl": europarl_train_dataset,
            "global_voices": global_voices_train_dataset,
            "jw300": jw300_train_dataset,
            "muse": muse_train_dataset,
            "wikimatrix": wikimatrix_train_dataset,
            "opensubtitles": opensubtitles_train_dataset,
            "stackexchange": stackexchange_train_dataset,
            "quora": quora_train_dataset,
            "wikianswers_duplicates": wikianswers_duplicates_train_dataset,
            "all_nli": all_nli_train_dataset,
            "simple_wiki": simple_wiki_train_dataset,
            "altlex": altlex_train_dataset,
            "flickr30k_captions": flickr30k_captions_train_dataset,
            "coco_captions": coco_captions_train_dataset,
            "nli_for_simcse": nli_for_simcse_train_dataset,
            "negation": negation_train_dataset,
        })
        eval_dataset = DatasetDict({
            "wikititles": wikititles_eval_dataset,
            "tatoeba": tatoeba_eval_dataset,
            "talks": talks_eval_dataset,
            "europarl": europarl_eval_dataset,
            "global_voices": global_voices_eval_dataset,
            "jw300": jw300_eval_dataset,
            "muse": muse_eval_dataset,
            "wikimatrix": wikimatrix_eval_dataset,
            "opensubtitles": opensubtitles_eval_dataset,
            "stackexchange": stackexchange_eval_dataset,
            "quora": quora_eval_dataset,
            "wikianswers_duplicates": wikianswers_duplicates_eval_dataset,
            "all_nli": all_nli_eval_dataset,
            "simple_wiki": simple_wiki_eval_dataset,
            "altlex": altlex_eval_dataset,
            "flickr30k_captions": flickr30k_captions_eval_dataset,
            "coco_captions": coco_captions_eval_dataset,
            "nli_for_simcse": nli_for_simcse_eval_dataset,
            "negation": negation_eval_dataset,
        })

        train_dataset.save_to_disk("datasets/train_dataset")
        eval_dataset.save_to_disk("datasets/eval_dataset")
        # The `train_test_split` calls have loaded much of the data into memory, while we want the datasets
        # to be memory-mapped from disk. So we call quit() here; running the script again will load them from disk.
        quit()


def main():
    # 1. Load a model to finetune with 2. (Optional) model card data
    static_embedding = StaticEmbedding(
        AutoTokenizer.from_pretrained("google-bert/bert-base-multilingual-uncased"),
        embedding_dim=1024,
    )
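    # With a tokenizer and an `embedding_dim`, StaticEmbedding creates a randomly initialized
    # token embedding matrix (vocab_size x 1024) and mean-pools the token vectors into a sentence
    # embedding (an EmbeddingBag under the hood), so training only has to learn this matrix.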
    model = SentenceTransformer(
        modules=[static_embedding],
        model_card_data=SentenceTransformerModelCardData(
            license="apache-2.0",
            model_name="Static Embeddings with BERT Multilingual uncased tokenizer finetuned on various datasets",
        ),
    )

    # 3. Set up training & evaluation datasets - each dataset is trained with MNRL (with MRL)
    train_dataset, eval_dataset = load_train_eval_datasets()
    print(train_dataset)

    # 4. Define a loss function
    loss = MultipleNegativesRankingLoss(model)
    loss = MatryoshkaLoss(model, loss, matryoshka_dims=[32, 64, 128, 256, 512, 1024])
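    # MatryoshkaLoss additionally applies the MNRL objective to truncated prefixes of each embedding
    # (the first 32, 64, ... dimensions), so the embeddings can later be shortened at inference time
    # with only a small loss in quality.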

    # 5. (Optional) Specify training arguments
    run_name = "static-similarity-mrl-multilingual-v1"
    args = SentenceTransformerTrainingArguments(
        # Required parameter:
        output_dir=f"models/{run_name}",
        # Optional training parameters:
        num_train_epochs=1,
        per_device_train_batch_size=2048,
        per_device_eval_batch_size=2048,
        learning_rate=2e-1,
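        # Note: 2e-1 is far higher than the ~2e-5 common for Transformer finetuning; since only a
        # static embedding matrix is being trained, much larger update steps are viable.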
        warmup_ratio=0.1,
        fp16=False,  # Set to True only if your GPU supports FP16 but not BF16
        bf16=True,  # Set to False if your GPU does not support BF16
        batch_sampler=BatchSamplers.NO_DUPLICATES,  # MultipleNegativesRankingLoss benefits from no duplicate samples in a batch
        multi_dataset_batch_sampler=MultiDatasetBatchSamplers.PROPORTIONAL,
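        # PROPORTIONAL samples batches from each dataset in proportion to its size, so the larger
        # datasets contribute correspondingly more training batches than the smaller ones.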
        # Optional tracking/debugging parameters:
        eval_strategy="steps",
        eval_steps=1000,
        save_strategy="steps",
        save_steps=1000,
        save_total_limit=2,
        logging_steps=1000,
        logging_first_step=True,
        run_name=run_name,  # Will be used in W&B if `wandb` is installed
    )

    # 6. Create a trainer & train
    trainer = SentenceTransformerTrainer(
        model=model,
        args=args,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        loss=loss,
    )
    trainer.train()

    # 7. Save the trained model
    model.save_pretrained(f"models/{run_name}/final")

    # 8. (Optional) Push it to the Hugging Face Hub
    model.push_to_hub(run_name, private=True)


if __name__ == "__main__":
    main()
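
# Example usage after training (a sketch; assumes the final model was saved to the path used above):
#
#   from sentence_transformers import SentenceTransformer
#
#   # `truncate_dim` keeps only the first Matryoshka dimensions of every embedding
#   model = SentenceTransformer("models/static-similarity-mrl-multilingual-v1/final", truncate_dim=256)
#   embeddings = model.encode(["The weather is lovely today.", "Il fait très beau aujourd'hui."])
#   print(model.similarity(embeddings[0], embeddings[1]))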