cointegrated commited on
Commit
832c0d0
·
1 Parent(s): b943c3c

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +61 -0
README.md CHANGED
@@ -11,3 +11,64 @@ fine-tuned for translating between Tyvan and Russian languages using the dataset
11
 
12
  Here is [a post](https://cointegrated.medium.com/a37fc706b865) about how it was trained.
13
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
11
 
12
  Here is [a post](https://cointegrated.medium.com/a37fc706b865) about how it was trained.
13
 
14
+ How to use the model:
15
+
16
+ ```Python
17
+ # the version of transformers is important!
18
+ !pip install sentencepiece transformers==4.33
19
+ import torch
20
+ from transformers import NllbTokenizer, AutoModelForSeq2SeqLM
21
+
22
+ def fix_tokenizer(tokenizer, new_lang='tyv_Cyrl'):
23
+ """ Add a new language token to the tokenizer vocabulary (this should be done each time after its initialization) """
24
+ old_len = len(tokenizer) - int(new_lang in tokenizer.added_tokens_encoder)
25
+ tokenizer.lang_code_to_id[new_lang] = old_len-1
26
+ tokenizer.id_to_lang_code[old_len-1] = new_lang
27
+ # always move "mask" to the last position
28
+ tokenizer.fairseq_tokens_to_ids["<mask>"] = len(tokenizer.sp_model) + len(tokenizer.lang_code_to_id) + tokenizer.fairseq_offset
29
+
30
+ tokenizer.fairseq_tokens_to_ids.update(tokenizer.lang_code_to_id)
31
+ tokenizer.fairseq_ids_to_tokens = {v: k for k, v in tokenizer.fairseq_tokens_to_ids.items()}
32
+ if new_lang not in tokenizer._additional_special_tokens:
33
+ tokenizer._additional_special_tokens.append(new_lang)
34
+ # clear the added token encoder; otherwise a new token may end up there by mistake
35
+ tokenizer.added_tokens_encoder = {}
36
+ tokenizer.added_tokens_decoder = {}
37
+
38
+ MODEL_URL = "slone/nllb-rus-tyv-v1"
39
+ model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_URL)
40
+ tokenizer = NllbTokenizer.from_pretrained(MODEL_URL)
41
+ fix_tokenizer(tokenizer)
42
+
43
+ def translate(
44
+ text,
45
+ model,
46
+ tokenizer,
47
+ src_lang='rus_Cyrl',
48
+ tgt_lang='tyv_Cyrl',
49
+ max_length='auto',
50
+ num_beams=4,
51
+ n_out=None,
52
+ **kwargs
53
+ ):
54
+ tokenizer.src_lang = src_lang
55
+ encoded = tokenizer(text, return_tensors="pt", truncation=True, max_length=512)
56
+ if max_length == 'auto':
57
+ max_length = int(32 + 2.0 * encoded.input_ids.shape[1])
58
+ model.eval()
59
+ generated_tokens = model.generate(
60
+ **encoded.to(model.device),
61
+ forced_bos_token_id=tokenizer.lang_code_to_id[tgt_lang],
62
+ max_length=max_length,
63
+ num_beams=num_beams,
64
+ num_return_sequences=n_out or 1,
65
+ **kwargs
66
+ )
67
+ out = tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)
68
+ if isinstance(text, str) and n_out is None:
69
+ return out[0]
70
+ return
71
+
72
+ translate("красная птица", model=model, tokenizer=tokenizer)
73
+ # 'кызыл куш'
74
+ ```