k1ngtai commited on
Commit
13de1af
·
verified ·
1 Parent(s): 852ce7e

Create nllb.py

Browse files
Files changed (1) hide show
  1. nllb.py +62 -0
nllb.py ADDED
@@ -0,0 +1,62 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
2
+ from flores200_codes import flores_codes
3
+
4
+ model_dict = {}
5
+
6
+
7
+ def load_models(model_name: str):
8
+ # build model and tokenizer
9
+ model_name_dict = {
10
+ "nllb-1.3B": "facebook/nllb-200-1.3B",
11
+ "nllb-distilled-1.3B": "facebook/nllb-200-distilled-1.3B",
12
+ "nllb-3.3B": "facebook/nllb-200-3.3B",
13
+ }[model_name]
14
+
15
+ print("\tLoading model: %s" % model_name)
16
+ model = AutoModelForSeq2SeqLM.from_pretrained(model_name_dict)
17
+ tokenizer = AutoTokenizer.from_pretrained(model_name_dict)
18
+ model_dict[model_name + "_model"] = model
19
+ model_dict[model_name + "_tokenizer"] = tokenizer
20
+
21
+ return model_dict
22
+
23
+
24
+ def translation(model_name: str, source, target, text: str):
25
+
26
+ model_dict = load_models(model_name)
27
+
28
+ source = flores_codes[source]
29
+ target = flores_codes[target]
30
+
31
+ model = model_dict[model_name + "_model"]
32
+ tokenizer = model_dict[model_name + "_tokenizer"]
33
+
34
+ translator = pipeline(
35
+ "translation",
36
+ model=model,
37
+ tokenizer=tokenizer,
38
+ src_lang=source,
39
+ tgt_lang=target,
40
+ )
41
+ output = translator(text, max_length=400)
42
+
43
+ output = output[0]["translation_text"]
44
+ result = {
45
+ "source": source,
46
+ "target": target,
47
+ "result": output,
48
+ }
49
+
50
+ return result
51
+
52
+
53
+ NLLB_EXAMPLES = [
54
+ ["nllb-distilled-1.3B", "English", "Shan", "Hello, how are you today?"],
55
+ ["nllb-distilled-1.3B", "Shan", "English", "မႂ်ႇသုင်ၶႃႈ ယူႇလီယူႇၶႃႈၼေႃႈ"],
56
+ [
57
+ "nllb-distilled-1.3B",
58
+ "English",
59
+ "Shan",
60
+ "Forming Myanmar’s New Political System Will Remain an Ideal but Never in Practicality",
61
+ ],
62
+ ]