k1ngtai commited on
Commit
852ce7e
·
verified ·
1 Parent(s): ca36bdf

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +34 -81
app.py CHANGED
@@ -1,84 +1,37 @@
1
- import os
2
- import torch
3
  import gradio as gr
4
- import time
5
- from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
6
  from flores200_codes import flores_codes
7
 
8
-
9
def load_models(model_name_dict=None):
    """Load NLLB translation model(s) and tokenizer(s).

    Args:
        model_name_dict: optional mapping of short call names to Hugging Face
            model ids. Defaults to the single distilled-1.3B checkpoint this
            demo serves (generalized from the previous hard-coded table).

    Returns:
        dict mapping '<call_name>_model' to the loaded model and
        '<call_name>_tokenizer' to its tokenizer.
    """
    if model_name_dict is None:
        model_name_dict = {
            'nllb-distilled-1.3B': 'facebook/nllb-200-distilled-1.3B',
        }

    model_dict = {}
    for call_name, real_name in model_name_dict.items():
        print('\tLoading model: %s' % call_name)
        # Downloads weights on first use; this is the slow, network-bound step.
        model = AutoModelForSeq2SeqLM.from_pretrained(real_name)
        tokenizer = AutoTokenizer.from_pretrained(real_name)
        model_dict[call_name + '_model'] = model
        model_dict[call_name + '_tokenizer'] = tokenizer

    return model_dict
28
-
29
-
30
def translation(source, target, text):
    """Translate `text` from `source` to `target` (human-readable FLORES names).

    Args:
        source: source language name, a key of `flores_codes`.
        target: target language name, a key of `flores_codes`.
        text: the text to translate.

    Returns:
        dict with 'inference_time' (seconds), 'source'/'target' (FLORES-200
        codes) and 'result' (the translated string).
    """
    # BUG FIX: the old code hard-coded model_name = 'nllb-distilled-600M'
    # (guarded by `if len(model_dict) == 2`), but load_models() only registers
    # 'nllb-distilled-1.3B', so the lookup raised KeyError — and model_name was
    # undefined whenever the guard failed. Derive the name from what is loaded.
    loaded = [k[:-len('_model')] for k in model_dict if k.endswith('_model')]
    model_name = loaded[0]

    start_time = time.time()
    source = flores_codes[source]
    target = flores_codes[target]

    model = model_dict[model_name + '_model']
    tokenizer = model_dict[model_name + '_tokenizer']

    translator = pipeline('translation', model=model, tokenizer=tokenizer,
                          src_lang=source, tgt_lang=target)
    output = translator(text, max_length=400)

    end_time = time.time()

    return {'inference_time': end_time - start_time,
            'source': source,
            'target': target,
            'result': output[0]['translation_text']}
52
-
53
-
54
if __name__ == '__main__':
    print('\tinit models')

    global model_dict

    model_dict = load_models()

    # Build the Gradio demo.
    lang_codes = list(flores_codes.keys())
    inputs = [
        # Gradio >= 3 renamed the Dropdown `default=` kwarg to `value=`;
        # the old `default=` raises TypeError on current Gradio.
        gr.Dropdown(lang_codes, value='English', label='Source'),
        gr.Dropdown(lang_codes, value='Shan', label='Target'),
        gr.Textbox(lines=5, label="Input text"),
    ]

    # `gr.outputs.JSON()` was removed from Gradio; the string shortcut is the
    # supported spelling for a JSON output component.
    outputs = 'json'

    title = "NLLB distilled 600M demo for shan"

    demo_status = "Demo is running on CPU & Thanks Geonmo"
    description = f"Details: https://github.com/facebookresearch/fairseq/tree/nllb. {demo_status}"
    examples = [
        ['English', 'Shan', 'Hi. nice to meet you']
    ]

    gr.Interface(translation,
                 inputs,
                 outputs,
                 title=title,
                 description=description,
                 examples=examples,  # was built but never wired into the UI
                 ).launch()
 
 
 
1
  import gradio as gr
2
+ from nllb import translation, NLLB_EXAMPLES
 
3
  from flores200_codes import flores_codes
4
 
5
# Human-readable language names exposed in the UI; `translation` (from nllb)
# is expected to map them back to FLORES-200 codes.
lang_codes = list(flores_codes.keys())

nllb_translate = gr.Interface(
    fn=translation,
    inputs=[
        gr.Dropdown(
            ["nllb-1.3B", "nllb-distilled-1.3B", "nllb-3.3B"],
            label="Model",
            value="nllb-distilled-1.3B",
        ),
        gr.Dropdown(
            lang_codes,
            label="Source language",
            value="English",
        ),
        gr.Dropdown(
            lang_codes,
            label="Target language",
            value="Shan",
        ),
        gr.Textbox(lines=5, label="Input text"),
    ],
    outputs="json",
    examples=NLLB_EXAMPLES,
    title="NLLB Translation Demo",
    description="Translate text from one language to another.",
    allow_flagging="never",
)

# Wrap the Interface in a Blocks app so further components can be added later.
with gr.Blocks() as demo:
    nllb_translate.render()

if __name__ == "__main__":
    # Guard the launch so importing this module (e.g. from tests or another
    # app) does not start a server; `python app.py` still launches the demo.
    demo.launch()