Commit
·
e5e7d52
1
Parent(s):
2d42acb
Saving weights and logs of epoch 1
Browse files- .gitattributes +6 -0
- added_tokens.json +1609 -0
- config.json +154 -0
- distil_whisper/__init__.py +21 -0
- distil_whisper/__pycache__/__init__.cpython-310.pyc +3 -0
- distil_whisper/__pycache__/layers.cpython-310.pyc +3 -0
- distil_whisper/__pycache__/modeling_flax_whisper.cpython-310.pyc +3 -0
- distil_whisper/__pycache__/partitioner.cpython-310.pyc +3 -0
- distil_whisper/__pycache__/pipeline.cpython-310.pyc +3 -0
- distil_whisper/__pycache__/train_state.cpython-310.pyc +3 -0
- distil_whisper/layers.py +1338 -0
- distil_whisper/modeling_flax_whisper.py +2136 -0
- distil_whisper/partitioner.py +965 -0
- distil_whisper/pipeline.py +525 -0
- distil_whisper/train_state.py +118 -0
- events.out.tfevents.1696323477.t1v-n-4eccb2d4-w-0.2889203.0.v2 +3 -0
- flax_model.msgpack +3 -0
- generation_config.json +319 -0
- merges.txt +0 -0
- normalizer.json +1742 -0
- preprocessor_config.json +14 -0
- run.sh +33 -0
- run_finetuning.py +1111 -0
- special_tokens_map.json +115 -0
- tokenizer.json +0 -0
- tokenizer_config.json +0 -0
- vocab.json +0 -0
.gitattributes
CHANGED
@@ -33,3 +33,9 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
|
|
|
|
|
|
33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
36 |
+
distil_whisper/__pycache__/__init__.cpython-310.pyc filter=lfs diff=lfs merge=lfs -text
|
37 |
+
distil_whisper/__pycache__/layers.cpython-310.pyc filter=lfs diff=lfs merge=lfs -text
|
38 |
+
distil_whisper/__pycache__/modeling_flax_whisper.cpython-310.pyc filter=lfs diff=lfs merge=lfs -text
|
39 |
+
distil_whisper/__pycache__/partitioner.cpython-310.pyc filter=lfs diff=lfs merge=lfs -text
|
40 |
+
distil_whisper/__pycache__/pipeline.cpython-310.pyc filter=lfs diff=lfs merge=lfs -text
|
41 |
+
distil_whisper/__pycache__/train_state.cpython-310.pyc filter=lfs diff=lfs merge=lfs -text
|
added_tokens.json
ADDED
@@ -0,0 +1,1609 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"<|0.00|>": 50364,
|
3 |
+
"<|0.02|>": 50365,
|
4 |
+
"<|0.04|>": 50366,
|
5 |
+
"<|0.06|>": 50367,
|
6 |
+
"<|0.08|>": 50368,
|
7 |
+
"<|0.10|>": 50369,
|
8 |
+
"<|0.12|>": 50370,
|
9 |
+
"<|0.14|>": 50371,
|
10 |
+
"<|0.16|>": 50372,
|
11 |
+
"<|0.18|>": 50373,
|
12 |
+
"<|0.20|>": 50374,
|
13 |
+
"<|0.22|>": 50375,
|
14 |
+
"<|0.24|>": 50376,
|
15 |
+
"<|0.26|>": 50377,
|
16 |
+
"<|0.28|>": 50378,
|
17 |
+
"<|0.30|>": 50379,
|
18 |
+
"<|0.32|>": 50380,
|
19 |
+
"<|0.34|>": 50381,
|
20 |
+
"<|0.36|>": 50382,
|
21 |
+
"<|0.38|>": 50383,
|
22 |
+
"<|0.40|>": 50384,
|
23 |
+
"<|0.42|>": 50385,
|
24 |
+
"<|0.44|>": 50386,
|
25 |
+
"<|0.46|>": 50387,
|
26 |
+
"<|0.48|>": 50388,
|
27 |
+
"<|0.50|>": 50389,
|
28 |
+
"<|0.52|>": 50390,
|
29 |
+
"<|0.54|>": 50391,
|
30 |
+
"<|0.56|>": 50392,
|
31 |
+
"<|0.58|>": 50393,
|
32 |
+
"<|0.60|>": 50394,
|
33 |
+
"<|0.62|>": 50395,
|
34 |
+
"<|0.64|>": 50396,
|
35 |
+
"<|0.66|>": 50397,
|
36 |
+
"<|0.68|>": 50398,
|
37 |
+
"<|0.70|>": 50399,
|
38 |
+
"<|0.72|>": 50400,
|
39 |
+
"<|0.74|>": 50401,
|
40 |
+
"<|0.76|>": 50402,
|
41 |
+
"<|0.78|>": 50403,
|
42 |
+
"<|0.80|>": 50404,
|
43 |
+
"<|0.82|>": 50405,
|
44 |
+
"<|0.84|>": 50406,
|
45 |
+
"<|0.86|>": 50407,
|
46 |
+
"<|0.88|>": 50408,
|
47 |
+
"<|0.90|>": 50409,
|
48 |
+
"<|0.92|>": 50410,
|
49 |
+
"<|0.94|>": 50411,
|
50 |
+
"<|0.96|>": 50412,
|
51 |
+
"<|0.98|>": 50413,
|
52 |
+
"<|1.00|>": 50414,
|
53 |
+
"<|1.02|>": 50415,
|
54 |
+
"<|1.04|>": 50416,
|
55 |
+
"<|1.06|>": 50417,
|
56 |
+
"<|1.08|>": 50418,
|
57 |
+
"<|1.10|>": 50419,
|
58 |
+
"<|1.12|>": 50420,
|
59 |
+
"<|1.14|>": 50421,
|
60 |
+
"<|1.16|>": 50422,
|
61 |
+
"<|1.18|>": 50423,
|
62 |
+
"<|1.20|>": 50424,
|
63 |
+
"<|1.22|>": 50425,
|
64 |
+
"<|1.24|>": 50426,
|
65 |
+
"<|1.26|>": 50427,
|
66 |
+
"<|1.28|>": 50428,
|
67 |
+
"<|1.30|>": 50429,
|
68 |
+
"<|1.32|>": 50430,
|
69 |
+
"<|1.34|>": 50431,
|
70 |
+
"<|1.36|>": 50432,
|
71 |
+
"<|1.38|>": 50433,
|
72 |
+
"<|1.40|>": 50434,
|
73 |
+
"<|1.42|>": 50435,
|
74 |
+
"<|1.44|>": 50436,
|
75 |
+
"<|1.46|>": 50437,
|
76 |
+
"<|1.48|>": 50438,
|
77 |
+
"<|1.50|>": 50439,
|
78 |
+
"<|1.52|>": 50440,
|
79 |
+
"<|1.54|>": 50441,
|
80 |
+
"<|1.56|>": 50442,
|
81 |
+
"<|1.58|>": 50443,
|
82 |
+
"<|1.60|>": 50444,
|
83 |
+
"<|1.62|>": 50445,
|
84 |
+
"<|1.64|>": 50446,
|
85 |
+
"<|1.66|>": 50447,
|
86 |
+
"<|1.68|>": 50448,
|
87 |
+
"<|1.70|>": 50449,
|
88 |
+
"<|1.72|>": 50450,
|
89 |
+
"<|1.74|>": 50451,
|
90 |
+
"<|1.76|>": 50452,
|
91 |
+
"<|1.78|>": 50453,
|
92 |
+
"<|1.80|>": 50454,
|
93 |
+
"<|1.82|>": 50455,
|
94 |
+
"<|1.84|>": 50456,
|
95 |
+
"<|1.86|>": 50457,
|
96 |
+
"<|1.88|>": 50458,
|
97 |
+
"<|1.90|>": 50459,
|
98 |
+
"<|1.92|>": 50460,
|
99 |
+
"<|1.94|>": 50461,
|
100 |
+
"<|1.96|>": 50462,
|
101 |
+
"<|1.98|>": 50463,
|
102 |
+
"<|10.00|>": 50864,
|
103 |
+
"<|10.02|>": 50865,
|
104 |
+
"<|10.04|>": 50866,
|
105 |
+
"<|10.06|>": 50867,
|
106 |
+
"<|10.08|>": 50868,
|
107 |
+
"<|10.10|>": 50869,
|
108 |
+
"<|10.12|>": 50870,
|
109 |
+
"<|10.14|>": 50871,
|
110 |
+
"<|10.16|>": 50872,
|
111 |
+
"<|10.18|>": 50873,
|
112 |
+
"<|10.20|>": 50874,
|
113 |
+
"<|10.22|>": 50875,
|
114 |
+
"<|10.24|>": 50876,
|
115 |
+
"<|10.26|>": 50877,
|
116 |
+
"<|10.28|>": 50878,
|
117 |
+
"<|10.30|>": 50879,
|
118 |
+
"<|10.32|>": 50880,
|
119 |
+
"<|10.34|>": 50881,
|
120 |
+
"<|10.36|>": 50882,
|
121 |
+
"<|10.38|>": 50883,
|
122 |
+
"<|10.40|>": 50884,
|
123 |
+
"<|10.42|>": 50885,
|
124 |
+
"<|10.44|>": 50886,
|
125 |
+
"<|10.46|>": 50887,
|
126 |
+
"<|10.48|>": 50888,
|
127 |
+
"<|10.50|>": 50889,
|
128 |
+
"<|10.52|>": 50890,
|
129 |
+
"<|10.54|>": 50891,
|
130 |
+
"<|10.56|>": 50892,
|
131 |
+
"<|10.58|>": 50893,
|
132 |
+
"<|10.60|>": 50894,
|
133 |
+
"<|10.62|>": 50895,
|
134 |
+
"<|10.64|>": 50896,
|
135 |
+
"<|10.66|>": 50897,
|
136 |
+
"<|10.68|>": 50898,
|
137 |
+
"<|10.70|>": 50899,
|
138 |
+
"<|10.72|>": 50900,
|
139 |
+
"<|10.74|>": 50901,
|
140 |
+
"<|10.76|>": 50902,
|
141 |
+
"<|10.78|>": 50903,
|
142 |
+
"<|10.80|>": 50904,
|
143 |
+
"<|10.82|>": 50905,
|
144 |
+
"<|10.84|>": 50906,
|
145 |
+
"<|10.86|>": 50907,
|
146 |
+
"<|10.88|>": 50908,
|
147 |
+
"<|10.90|>": 50909,
|
148 |
+
"<|10.92|>": 50910,
|
149 |
+
"<|10.94|>": 50911,
|
150 |
+
"<|10.96|>": 50912,
|
151 |
+
"<|10.98|>": 50913,
|
152 |
+
"<|11.00|>": 50914,
|
153 |
+
"<|11.02|>": 50915,
|
154 |
+
"<|11.04|>": 50916,
|
155 |
+
"<|11.06|>": 50917,
|
156 |
+
"<|11.08|>": 50918,
|
157 |
+
"<|11.10|>": 50919,
|
158 |
+
"<|11.12|>": 50920,
|
159 |
+
"<|11.14|>": 50921,
|
160 |
+
"<|11.16|>": 50922,
|
161 |
+
"<|11.18|>": 50923,
|
162 |
+
"<|11.20|>": 50924,
|
163 |
+
"<|11.22|>": 50925,
|
164 |
+
"<|11.24|>": 50926,
|
165 |
+
"<|11.26|>": 50927,
|
166 |
+
"<|11.28|>": 50928,
|
167 |
+
"<|11.30|>": 50929,
|
168 |
+
"<|11.32|>": 50930,
|
169 |
+
"<|11.34|>": 50931,
|
170 |
+
"<|11.36|>": 50932,
|
171 |
+
"<|11.38|>": 50933,
|
172 |
+
"<|11.40|>": 50934,
|
173 |
+
"<|11.42|>": 50935,
|
174 |
+
"<|11.44|>": 50936,
|
175 |
+
"<|11.46|>": 50937,
|
176 |
+
"<|11.48|>": 50938,
|
177 |
+
"<|11.50|>": 50939,
|
178 |
+
"<|11.52|>": 50940,
|
179 |
+
"<|11.54|>": 50941,
|
180 |
+
"<|11.56|>": 50942,
|
181 |
+
"<|11.58|>": 50943,
|
182 |
+
"<|11.60|>": 50944,
|
183 |
+
"<|11.62|>": 50945,
|
184 |
+
"<|11.64|>": 50946,
|
185 |
+
"<|11.66|>": 50947,
|
186 |
+
"<|11.68|>": 50948,
|
187 |
+
"<|11.70|>": 50949,
|
188 |
+
"<|11.72|>": 50950,
|
189 |
+
"<|11.74|>": 50951,
|
190 |
+
"<|11.76|>": 50952,
|
191 |
+
"<|11.78|>": 50953,
|
192 |
+
"<|11.80|>": 50954,
|
193 |
+
"<|11.82|>": 50955,
|
194 |
+
"<|11.84|>": 50956,
|
195 |
+
"<|11.86|>": 50957,
|
196 |
+
"<|11.88|>": 50958,
|
197 |
+
"<|11.90|>": 50959,
|
198 |
+
"<|11.92|>": 50960,
|
199 |
+
"<|11.94|>": 50961,
|
200 |
+
"<|11.96|>": 50962,
|
201 |
+
"<|11.98|>": 50963,
|
202 |
+
"<|12.00|>": 50964,
|
203 |
+
"<|12.02|>": 50965,
|
204 |
+
"<|12.04|>": 50966,
|
205 |
+
"<|12.06|>": 50967,
|
206 |
+
"<|12.08|>": 50968,
|
207 |
+
"<|12.10|>": 50969,
|
208 |
+
"<|12.12|>": 50970,
|
209 |
+
"<|12.14|>": 50971,
|
210 |
+
"<|12.16|>": 50972,
|
211 |
+
"<|12.18|>": 50973,
|
212 |
+
"<|12.20|>": 50974,
|
213 |
+
"<|12.22|>": 50975,
|
214 |
+
"<|12.24|>": 50976,
|
215 |
+
"<|12.26|>": 50977,
|
216 |
+
"<|12.28|>": 50978,
|
217 |
+
"<|12.30|>": 50979,
|
218 |
+
"<|12.32|>": 50980,
|
219 |
+
"<|12.34|>": 50981,
|
220 |
+
"<|12.36|>": 50982,
|
221 |
+
"<|12.38|>": 50983,
|
222 |
+
"<|12.40|>": 50984,
|
223 |
+
"<|12.42|>": 50985,
|
224 |
+
"<|12.44|>": 50986,
|
225 |
+
"<|12.46|>": 50987,
|
226 |
+
"<|12.48|>": 50988,
|
227 |
+
"<|12.50|>": 50989,
|
228 |
+
"<|12.52|>": 50990,
|
229 |
+
"<|12.54|>": 50991,
|
230 |
+
"<|12.56|>": 50992,
|
231 |
+
"<|12.58|>": 50993,
|
232 |
+
"<|12.60|>": 50994,
|
233 |
+
"<|12.62|>": 50995,
|
234 |
+
"<|12.64|>": 50996,
|
235 |
+
"<|12.66|>": 50997,
|
236 |
+
"<|12.68|>": 50998,
|
237 |
+
"<|12.70|>": 50999,
|
238 |
+
"<|12.72|>": 51000,
|
239 |
+
"<|12.74|>": 51001,
|
240 |
+
"<|12.76|>": 51002,
|
241 |
+
"<|12.78|>": 51003,
|
242 |
+
"<|12.80|>": 51004,
|
243 |
+
"<|12.82|>": 51005,
|
244 |
+
"<|12.84|>": 51006,
|
245 |
+
"<|12.86|>": 51007,
|
246 |
+
"<|12.88|>": 51008,
|
247 |
+
"<|12.90|>": 51009,
|
248 |
+
"<|12.92|>": 51010,
|
249 |
+
"<|12.94|>": 51011,
|
250 |
+
"<|12.96|>": 51012,
|
251 |
+
"<|12.98|>": 51013,
|
252 |
+
"<|13.00|>": 51014,
|
253 |
+
"<|13.02|>": 51015,
|
254 |
+
"<|13.04|>": 51016,
|
255 |
+
"<|13.06|>": 51017,
|
256 |
+
"<|13.08|>": 51018,
|
257 |
+
"<|13.10|>": 51019,
|
258 |
+
"<|13.12|>": 51020,
|
259 |
+
"<|13.14|>": 51021,
|
260 |
+
"<|13.16|>": 51022,
|
261 |
+
"<|13.18|>": 51023,
|
262 |
+
"<|13.20|>": 51024,
|
263 |
+
"<|13.22|>": 51025,
|
264 |
+
"<|13.24|>": 51026,
|
265 |
+
"<|13.26|>": 51027,
|
266 |
+
"<|13.28|>": 51028,
|
267 |
+
"<|13.30|>": 51029,
|
268 |
+
"<|13.32|>": 51030,
|
269 |
+
"<|13.34|>": 51031,
|
270 |
+
"<|13.36|>": 51032,
|
271 |
+
"<|13.38|>": 51033,
|
272 |
+
"<|13.40|>": 51034,
|
273 |
+
"<|13.42|>": 51035,
|
274 |
+
"<|13.44|>": 51036,
|
275 |
+
"<|13.46|>": 51037,
|
276 |
+
"<|13.48|>": 51038,
|
277 |
+
"<|13.50|>": 51039,
|
278 |
+
"<|13.52|>": 51040,
|
279 |
+
"<|13.54|>": 51041,
|
280 |
+
"<|13.56|>": 51042,
|
281 |
+
"<|13.58|>": 51043,
|
282 |
+
"<|13.60|>": 51044,
|
283 |
+
"<|13.62|>": 51045,
|
284 |
+
"<|13.64|>": 51046,
|
285 |
+
"<|13.66|>": 51047,
|
286 |
+
"<|13.68|>": 51048,
|
287 |
+
"<|13.70|>": 51049,
|
288 |
+
"<|13.72|>": 51050,
|
289 |
+
"<|13.74|>": 51051,
|
290 |
+
"<|13.76|>": 51052,
|
291 |
+
"<|13.78|>": 51053,
|
292 |
+
"<|13.80|>": 51054,
|
293 |
+
"<|13.82|>": 51055,
|
294 |
+
"<|13.84|>": 51056,
|
295 |
+
"<|13.86|>": 51057,
|
296 |
+
"<|13.88|>": 51058,
|
297 |
+
"<|13.90|>": 51059,
|
298 |
+
"<|13.92|>": 51060,
|
299 |
+
"<|13.94|>": 51061,
|
300 |
+
"<|13.96|>": 51062,
|
301 |
+
"<|13.98|>": 51063,
|
302 |
+
"<|14.00|>": 51064,
|
303 |
+
"<|14.02|>": 51065,
|
304 |
+
"<|14.04|>": 51066,
|
305 |
+
"<|14.06|>": 51067,
|
306 |
+
"<|14.08|>": 51068,
|
307 |
+
"<|14.10|>": 51069,
|
308 |
+
"<|14.12|>": 51070,
|
309 |
+
"<|14.14|>": 51071,
|
310 |
+
"<|14.16|>": 51072,
|
311 |
+
"<|14.18|>": 51073,
|
312 |
+
"<|14.20|>": 51074,
|
313 |
+
"<|14.22|>": 51075,
|
314 |
+
"<|14.24|>": 51076,
|
315 |
+
"<|14.26|>": 51077,
|
316 |
+
"<|14.28|>": 51078,
|
317 |
+
"<|14.30|>": 51079,
|
318 |
+
"<|14.32|>": 51080,
|
319 |
+
"<|14.34|>": 51081,
|
320 |
+
"<|14.36|>": 51082,
|
321 |
+
"<|14.38|>": 51083,
|
322 |
+
"<|14.40|>": 51084,
|
323 |
+
"<|14.42|>": 51085,
|
324 |
+
"<|14.44|>": 51086,
|
325 |
+
"<|14.46|>": 51087,
|
326 |
+
"<|14.48|>": 51088,
|
327 |
+
"<|14.50|>": 51089,
|
328 |
+
"<|14.52|>": 51090,
|
329 |
+
"<|14.54|>": 51091,
|
330 |
+
"<|14.56|>": 51092,
|
331 |
+
"<|14.58|>": 51093,
|
332 |
+
"<|14.60|>": 51094,
|
333 |
+
"<|14.62|>": 51095,
|
334 |
+
"<|14.64|>": 51096,
|
335 |
+
"<|14.66|>": 51097,
|
336 |
+
"<|14.68|>": 51098,
|
337 |
+
"<|14.70|>": 51099,
|
338 |
+
"<|14.72|>": 51100,
|
339 |
+
"<|14.74|>": 51101,
|
340 |
+
"<|14.76|>": 51102,
|
341 |
+
"<|14.78|>": 51103,
|
342 |
+
"<|14.80|>": 51104,
|
343 |
+
"<|14.82|>": 51105,
|
344 |
+
"<|14.84|>": 51106,
|
345 |
+
"<|14.86|>": 51107,
|
346 |
+
"<|14.88|>": 51108,
|
347 |
+
"<|14.90|>": 51109,
|
348 |
+
"<|14.92|>": 51110,
|
349 |
+
"<|14.94|>": 51111,
|
350 |
+
"<|14.96|>": 51112,
|
351 |
+
"<|14.98|>": 51113,
|
352 |
+
"<|15.00|>": 51114,
|
353 |
+
"<|15.02|>": 51115,
|
354 |
+
"<|15.04|>": 51116,
|
355 |
+
"<|15.06|>": 51117,
|
356 |
+
"<|15.08|>": 51118,
|
357 |
+
"<|15.10|>": 51119,
|
358 |
+
"<|15.12|>": 51120,
|
359 |
+
"<|15.14|>": 51121,
|
360 |
+
"<|15.16|>": 51122,
|
361 |
+
"<|15.18|>": 51123,
|
362 |
+
"<|15.20|>": 51124,
|
363 |
+
"<|15.22|>": 51125,
|
364 |
+
"<|15.24|>": 51126,
|
365 |
+
"<|15.26|>": 51127,
|
366 |
+
"<|15.28|>": 51128,
|
367 |
+
"<|15.30|>": 51129,
|
368 |
+
"<|15.32|>": 51130,
|
369 |
+
"<|15.34|>": 51131,
|
370 |
+
"<|15.36|>": 51132,
|
371 |
+
"<|15.38|>": 51133,
|
372 |
+
"<|15.40|>": 51134,
|
373 |
+
"<|15.42|>": 51135,
|
374 |
+
"<|15.44|>": 51136,
|
375 |
+
"<|15.46|>": 51137,
|
376 |
+
"<|15.48|>": 51138,
|
377 |
+
"<|15.50|>": 51139,
|
378 |
+
"<|15.52|>": 51140,
|
379 |
+
"<|15.54|>": 51141,
|
380 |
+
"<|15.56|>": 51142,
|
381 |
+
"<|15.58|>": 51143,
|
382 |
+
"<|15.60|>": 51144,
|
383 |
+
"<|15.62|>": 51145,
|
384 |
+
"<|15.64|>": 51146,
|
385 |
+
"<|15.66|>": 51147,
|
386 |
+
"<|15.68|>": 51148,
|
387 |
+
"<|15.70|>": 51149,
|
388 |
+
"<|15.72|>": 51150,
|
389 |
+
"<|15.74|>": 51151,
|
390 |
+
"<|15.76|>": 51152,
|
391 |
+
"<|15.78|>": 51153,
|
392 |
+
"<|15.80|>": 51154,
|
393 |
+
"<|15.82|>": 51155,
|
394 |
+
"<|15.84|>": 51156,
|
395 |
+
"<|15.86|>": 51157,
|
396 |
+
"<|15.88|>": 51158,
|
397 |
+
"<|15.90|>": 51159,
|
398 |
+
"<|15.92|>": 51160,
|
399 |
+
"<|15.94|>": 51161,
|
400 |
+
"<|15.96|>": 51162,
|
401 |
+
"<|15.98|>": 51163,
|
402 |
+
"<|16.00|>": 51164,
|
403 |
+
"<|16.02|>": 51165,
|
404 |
+
"<|16.04|>": 51166,
|
405 |
+
"<|16.06|>": 51167,
|
406 |
+
"<|16.08|>": 51168,
|
407 |
+
"<|16.10|>": 51169,
|
408 |
+
"<|16.12|>": 51170,
|
409 |
+
"<|16.14|>": 51171,
|
410 |
+
"<|16.16|>": 51172,
|
411 |
+
"<|16.18|>": 51173,
|
412 |
+
"<|16.20|>": 51174,
|
413 |
+
"<|16.22|>": 51175,
|
414 |
+
"<|16.24|>": 51176,
|
415 |
+
"<|16.26|>": 51177,
|
416 |
+
"<|16.28|>": 51178,
|
417 |
+
"<|16.30|>": 51179,
|
418 |
+
"<|16.32|>": 51180,
|
419 |
+
"<|16.34|>": 51181,
|
420 |
+
"<|16.36|>": 51182,
|
421 |
+
"<|16.38|>": 51183,
|
422 |
+
"<|16.40|>": 51184,
|
423 |
+
"<|16.42|>": 51185,
|
424 |
+
"<|16.44|>": 51186,
|
425 |
+
"<|16.46|>": 51187,
|
426 |
+
"<|16.48|>": 51188,
|
427 |
+
"<|16.50|>": 51189,
|
428 |
+
"<|16.52|>": 51190,
|
429 |
+
"<|16.54|>": 51191,
|
430 |
+
"<|16.56|>": 51192,
|
431 |
+
"<|16.58|>": 51193,
|
432 |
+
"<|16.60|>": 51194,
|
433 |
+
"<|16.62|>": 51195,
|
434 |
+
"<|16.64|>": 51196,
|
435 |
+
"<|16.66|>": 51197,
|
436 |
+
"<|16.68|>": 51198,
|
437 |
+
"<|16.70|>": 51199,
|
438 |
+
"<|16.72|>": 51200,
|
439 |
+
"<|16.74|>": 51201,
|
440 |
+
"<|16.76|>": 51202,
|
441 |
+
"<|16.78|>": 51203,
|
442 |
+
"<|16.80|>": 51204,
|
443 |
+
"<|16.82|>": 51205,
|
444 |
+
"<|16.84|>": 51206,
|
445 |
+
"<|16.86|>": 51207,
|
446 |
+
"<|16.88|>": 51208,
|
447 |
+
"<|16.90|>": 51209,
|
448 |
+
"<|16.92|>": 51210,
|
449 |
+
"<|16.94|>": 51211,
|
450 |
+
"<|16.96|>": 51212,
|
451 |
+
"<|16.98|>": 51213,
|
452 |
+
"<|17.00|>": 51214,
|
453 |
+
"<|17.02|>": 51215,
|
454 |
+
"<|17.04|>": 51216,
|
455 |
+
"<|17.06|>": 51217,
|
456 |
+
"<|17.08|>": 51218,
|
457 |
+
"<|17.10|>": 51219,
|
458 |
+
"<|17.12|>": 51220,
|
459 |
+
"<|17.14|>": 51221,
|
460 |
+
"<|17.16|>": 51222,
|
461 |
+
"<|17.18|>": 51223,
|
462 |
+
"<|17.20|>": 51224,
|
463 |
+
"<|17.22|>": 51225,
|
464 |
+
"<|17.24|>": 51226,
|
465 |
+
"<|17.26|>": 51227,
|
466 |
+
"<|17.28|>": 51228,
|
467 |
+
"<|17.30|>": 51229,
|
468 |
+
"<|17.32|>": 51230,
|
469 |
+
"<|17.34|>": 51231,
|
470 |
+
"<|17.36|>": 51232,
|
471 |
+
"<|17.38|>": 51233,
|
472 |
+
"<|17.40|>": 51234,
|
473 |
+
"<|17.42|>": 51235,
|
474 |
+
"<|17.44|>": 51236,
|
475 |
+
"<|17.46|>": 51237,
|
476 |
+
"<|17.48|>": 51238,
|
477 |
+
"<|17.50|>": 51239,
|
478 |
+
"<|17.52|>": 51240,
|
479 |
+
"<|17.54|>": 51241,
|
480 |
+
"<|17.56|>": 51242,
|
481 |
+
"<|17.58|>": 51243,
|
482 |
+
"<|17.60|>": 51244,
|
483 |
+
"<|17.62|>": 51245,
|
484 |
+
"<|17.64|>": 51246,
|
485 |
+
"<|17.66|>": 51247,
|
486 |
+
"<|17.68|>": 51248,
|
487 |
+
"<|17.70|>": 51249,
|
488 |
+
"<|17.72|>": 51250,
|
489 |
+
"<|17.74|>": 51251,
|
490 |
+
"<|17.76|>": 51252,
|
491 |
+
"<|17.78|>": 51253,
|
492 |
+
"<|17.80|>": 51254,
|
493 |
+
"<|17.82|>": 51255,
|
494 |
+
"<|17.84|>": 51256,
|
495 |
+
"<|17.86|>": 51257,
|
496 |
+
"<|17.88|>": 51258,
|
497 |
+
"<|17.90|>": 51259,
|
498 |
+
"<|17.92|>": 51260,
|
499 |
+
"<|17.94|>": 51261,
|
500 |
+
"<|17.96|>": 51262,
|
501 |
+
"<|17.98|>": 51263,
|
502 |
+
"<|18.00|>": 51264,
|
503 |
+
"<|18.02|>": 51265,
|
504 |
+
"<|18.04|>": 51266,
|
505 |
+
"<|18.06|>": 51267,
|
506 |
+
"<|18.08|>": 51268,
|
507 |
+
"<|18.10|>": 51269,
|
508 |
+
"<|18.12|>": 51270,
|
509 |
+
"<|18.14|>": 51271,
|
510 |
+
"<|18.16|>": 51272,
|
511 |
+
"<|18.18|>": 51273,
|
512 |
+
"<|18.20|>": 51274,
|
513 |
+
"<|18.22|>": 51275,
|
514 |
+
"<|18.24|>": 51276,
|
515 |
+
"<|18.26|>": 51277,
|
516 |
+
"<|18.28|>": 51278,
|
517 |
+
"<|18.30|>": 51279,
|
518 |
+
"<|18.32|>": 51280,
|
519 |
+
"<|18.34|>": 51281,
|
520 |
+
"<|18.36|>": 51282,
|
521 |
+
"<|18.38|>": 51283,
|
522 |
+
"<|18.40|>": 51284,
|
523 |
+
"<|18.42|>": 51285,
|
524 |
+
"<|18.44|>": 51286,
|
525 |
+
"<|18.46|>": 51287,
|
526 |
+
"<|18.48|>": 51288,
|
527 |
+
"<|18.50|>": 51289,
|
528 |
+
"<|18.52|>": 51290,
|
529 |
+
"<|18.54|>": 51291,
|
530 |
+
"<|18.56|>": 51292,
|
531 |
+
"<|18.58|>": 51293,
|
532 |
+
"<|18.60|>": 51294,
|
533 |
+
"<|18.62|>": 51295,
|
534 |
+
"<|18.64|>": 51296,
|
535 |
+
"<|18.66|>": 51297,
|
536 |
+
"<|18.68|>": 51298,
|
537 |
+
"<|18.70|>": 51299,
|
538 |
+
"<|18.72|>": 51300,
|
539 |
+
"<|18.74|>": 51301,
|
540 |
+
"<|18.76|>": 51302,
|
541 |
+
"<|18.78|>": 51303,
|
542 |
+
"<|18.80|>": 51304,
|
543 |
+
"<|18.82|>": 51305,
|
544 |
+
"<|18.84|>": 51306,
|
545 |
+
"<|18.86|>": 51307,
|
546 |
+
"<|18.88|>": 51308,
|
547 |
+
"<|18.90|>": 51309,
|
548 |
+
"<|18.92|>": 51310,
|
549 |
+
"<|18.94|>": 51311,
|
550 |
+
"<|18.96|>": 51312,
|
551 |
+
"<|18.98|>": 51313,
|
552 |
+
"<|19.00|>": 51314,
|
553 |
+
"<|19.02|>": 51315,
|
554 |
+
"<|19.04|>": 51316,
|
555 |
+
"<|19.06|>": 51317,
|
556 |
+
"<|19.08|>": 51318,
|
557 |
+
"<|19.10|>": 51319,
|
558 |
+
"<|19.12|>": 51320,
|
559 |
+
"<|19.14|>": 51321,
|
560 |
+
"<|19.16|>": 51322,
|
561 |
+
"<|19.18|>": 51323,
|
562 |
+
"<|19.20|>": 51324,
|
563 |
+
"<|19.22|>": 51325,
|
564 |
+
"<|19.24|>": 51326,
|
565 |
+
"<|19.26|>": 51327,
|
566 |
+
"<|19.28|>": 51328,
|
567 |
+
"<|19.30|>": 51329,
|
568 |
+
"<|19.32|>": 51330,
|
569 |
+
"<|19.34|>": 51331,
|
570 |
+
"<|19.36|>": 51332,
|
571 |
+
"<|19.38|>": 51333,
|
572 |
+
"<|19.40|>": 51334,
|
573 |
+
"<|19.42|>": 51335,
|
574 |
+
"<|19.44|>": 51336,
|
575 |
+
"<|19.46|>": 51337,
|
576 |
+
"<|19.48|>": 51338,
|
577 |
+
"<|19.50|>": 51339,
|
578 |
+
"<|19.52|>": 51340,
|
579 |
+
"<|19.54|>": 51341,
|
580 |
+
"<|19.56|>": 51342,
|
581 |
+
"<|19.58|>": 51343,
|
582 |
+
"<|19.60|>": 51344,
|
583 |
+
"<|19.62|>": 51345,
|
584 |
+
"<|19.64|>": 51346,
|
585 |
+
"<|19.66|>": 51347,
|
586 |
+
"<|19.68|>": 51348,
|
587 |
+
"<|19.70|>": 51349,
|
588 |
+
"<|19.72|>": 51350,
|
589 |
+
"<|19.74|>": 51351,
|
590 |
+
"<|19.76|>": 51352,
|
591 |
+
"<|19.78|>": 51353,
|
592 |
+
"<|19.80|>": 51354,
|
593 |
+
"<|19.82|>": 51355,
|
594 |
+
"<|19.84|>": 51356,
|
595 |
+
"<|19.86|>": 51357,
|
596 |
+
"<|19.88|>": 51358,
|
597 |
+
"<|19.90|>": 51359,
|
598 |
+
"<|19.92|>": 51360,
|
599 |
+
"<|19.94|>": 51361,
|
600 |
+
"<|19.96|>": 51362,
|
601 |
+
"<|19.98|>": 51363,
|
602 |
+
"<|2.00|>": 50464,
|
603 |
+
"<|2.02|>": 50465,
|
604 |
+
"<|2.04|>": 50466,
|
605 |
+
"<|2.06|>": 50467,
|
606 |
+
"<|2.08|>": 50468,
|
607 |
+
"<|2.10|>": 50469,
|
608 |
+
"<|2.12|>": 50470,
|
609 |
+
"<|2.14|>": 50471,
|
610 |
+
"<|2.16|>": 50472,
|
611 |
+
"<|2.18|>": 50473,
|
612 |
+
"<|2.20|>": 50474,
|
613 |
+
"<|2.22|>": 50475,
|
614 |
+
"<|2.24|>": 50476,
|
615 |
+
"<|2.26|>": 50477,
|
616 |
+
"<|2.28|>": 50478,
|
617 |
+
"<|2.30|>": 50479,
|
618 |
+
"<|2.32|>": 50480,
|
619 |
+
"<|2.34|>": 50481,
|
620 |
+
"<|2.36|>": 50482,
|
621 |
+
"<|2.38|>": 50483,
|
622 |
+
"<|2.40|>": 50484,
|
623 |
+
"<|2.42|>": 50485,
|
624 |
+
"<|2.44|>": 50486,
|
625 |
+
"<|2.46|>": 50487,
|
626 |
+
"<|2.48|>": 50488,
|
627 |
+
"<|2.50|>": 50489,
|
628 |
+
"<|2.52|>": 50490,
|
629 |
+
"<|2.54|>": 50491,
|
630 |
+
"<|2.56|>": 50492,
|
631 |
+
"<|2.58|>": 50493,
|
632 |
+
"<|2.60|>": 50494,
|
633 |
+
"<|2.62|>": 50495,
|
634 |
+
"<|2.64|>": 50496,
|
635 |
+
"<|2.66|>": 50497,
|
636 |
+
"<|2.68|>": 50498,
|
637 |
+
"<|2.70|>": 50499,
|
638 |
+
"<|2.72|>": 50500,
|
639 |
+
"<|2.74|>": 50501,
|
640 |
+
"<|2.76|>": 50502,
|
641 |
+
"<|2.78|>": 50503,
|
642 |
+
"<|2.80|>": 50504,
|
643 |
+
"<|2.82|>": 50505,
|
644 |
+
"<|2.84|>": 50506,
|
645 |
+
"<|2.86|>": 50507,
|
646 |
+
"<|2.88|>": 50508,
|
647 |
+
"<|2.90|>": 50509,
|
648 |
+
"<|2.92|>": 50510,
|
649 |
+
"<|2.94|>": 50511,
|
650 |
+
"<|2.96|>": 50512,
|
651 |
+
"<|2.98|>": 50513,
|
652 |
+
"<|20.00|>": 51364,
|
653 |
+
"<|20.02|>": 51365,
|
654 |
+
"<|20.04|>": 51366,
|
655 |
+
"<|20.06|>": 51367,
|
656 |
+
"<|20.08|>": 51368,
|
657 |
+
"<|20.10|>": 51369,
|
658 |
+
"<|20.12|>": 51370,
|
659 |
+
"<|20.14|>": 51371,
|
660 |
+
"<|20.16|>": 51372,
|
661 |
+
"<|20.18|>": 51373,
|
662 |
+
"<|20.20|>": 51374,
|
663 |
+
"<|20.22|>": 51375,
|
664 |
+
"<|20.24|>": 51376,
|
665 |
+
"<|20.26|>": 51377,
|
666 |
+
"<|20.28|>": 51378,
|
667 |
+
"<|20.30|>": 51379,
|
668 |
+
"<|20.32|>": 51380,
|
669 |
+
"<|20.34|>": 51381,
|
670 |
+
"<|20.36|>": 51382,
|
671 |
+
"<|20.38|>": 51383,
|
672 |
+
"<|20.40|>": 51384,
|
673 |
+
"<|20.42|>": 51385,
|
674 |
+
"<|20.44|>": 51386,
|
675 |
+
"<|20.46|>": 51387,
|
676 |
+
"<|20.48|>": 51388,
|
677 |
+
"<|20.50|>": 51389,
|
678 |
+
"<|20.52|>": 51390,
|
679 |
+
"<|20.54|>": 51391,
|
680 |
+
"<|20.56|>": 51392,
|
681 |
+
"<|20.58|>": 51393,
|
682 |
+
"<|20.60|>": 51394,
|
683 |
+
"<|20.62|>": 51395,
|
684 |
+
"<|20.64|>": 51396,
|
685 |
+
"<|20.66|>": 51397,
|
686 |
+
"<|20.68|>": 51398,
|
687 |
+
"<|20.70|>": 51399,
|
688 |
+
"<|20.72|>": 51400,
|
689 |
+
"<|20.74|>": 51401,
|
690 |
+
"<|20.76|>": 51402,
|
691 |
+
"<|20.78|>": 51403,
|
692 |
+
"<|20.80|>": 51404,
|
693 |
+
"<|20.82|>": 51405,
|
694 |
+
"<|20.84|>": 51406,
|
695 |
+
"<|20.86|>": 51407,
|
696 |
+
"<|20.88|>": 51408,
|
697 |
+
"<|20.90|>": 51409,
|
698 |
+
"<|20.92|>": 51410,
|
699 |
+
"<|20.94|>": 51411,
|
700 |
+
"<|20.96|>": 51412,
|
701 |
+
"<|20.98|>": 51413,
|
702 |
+
"<|21.00|>": 51414,
|
703 |
+
"<|21.02|>": 51415,
|
704 |
+
"<|21.04|>": 51416,
|
705 |
+
"<|21.06|>": 51417,
|
706 |
+
"<|21.08|>": 51418,
|
707 |
+
"<|21.10|>": 51419,
|
708 |
+
"<|21.12|>": 51420,
|
709 |
+
"<|21.14|>": 51421,
|
710 |
+
"<|21.16|>": 51422,
|
711 |
+
"<|21.18|>": 51423,
|
712 |
+
"<|21.20|>": 51424,
|
713 |
+
"<|21.22|>": 51425,
|
714 |
+
"<|21.24|>": 51426,
|
715 |
+
"<|21.26|>": 51427,
|
716 |
+
"<|21.28|>": 51428,
|
717 |
+
"<|21.30|>": 51429,
|
718 |
+
"<|21.32|>": 51430,
|
719 |
+
"<|21.34|>": 51431,
|
720 |
+
"<|21.36|>": 51432,
|
721 |
+
"<|21.38|>": 51433,
|
722 |
+
"<|21.40|>": 51434,
|
723 |
+
"<|21.42|>": 51435,
|
724 |
+
"<|21.44|>": 51436,
|
725 |
+
"<|21.46|>": 51437,
|
726 |
+
"<|21.48|>": 51438,
|
727 |
+
"<|21.50|>": 51439,
|
728 |
+
"<|21.52|>": 51440,
|
729 |
+
"<|21.54|>": 51441,
|
730 |
+
"<|21.56|>": 51442,
|
731 |
+
"<|21.58|>": 51443,
|
732 |
+
"<|21.60|>": 51444,
|
733 |
+
"<|21.62|>": 51445,
|
734 |
+
"<|21.64|>": 51446,
|
735 |
+
"<|21.66|>": 51447,
|
736 |
+
"<|21.68|>": 51448,
|
737 |
+
"<|21.70|>": 51449,
|
738 |
+
"<|21.72|>": 51450,
|
739 |
+
"<|21.74|>": 51451,
|
740 |
+
"<|21.76|>": 51452,
|
741 |
+
"<|21.78|>": 51453,
|
742 |
+
"<|21.80|>": 51454,
|
743 |
+
"<|21.82|>": 51455,
|
744 |
+
"<|21.84|>": 51456,
|
745 |
+
"<|21.86|>": 51457,
|
746 |
+
"<|21.88|>": 51458,
|
747 |
+
"<|21.90|>": 51459,
|
748 |
+
"<|21.92|>": 51460,
|
749 |
+
"<|21.94|>": 51461,
|
750 |
+
"<|21.96|>": 51462,
|
751 |
+
"<|21.98|>": 51463,
|
752 |
+
"<|22.00|>": 51464,
|
753 |
+
"<|22.02|>": 51465,
|
754 |
+
"<|22.04|>": 51466,
|
755 |
+
"<|22.06|>": 51467,
|
756 |
+
"<|22.08|>": 51468,
|
757 |
+
"<|22.10|>": 51469,
|
758 |
+
"<|22.12|>": 51470,
|
759 |
+
"<|22.14|>": 51471,
|
760 |
+
"<|22.16|>": 51472,
|
761 |
+
"<|22.18|>": 51473,
|
762 |
+
"<|22.20|>": 51474,
|
763 |
+
"<|22.22|>": 51475,
|
764 |
+
"<|22.24|>": 51476,
|
765 |
+
"<|22.26|>": 51477,
|
766 |
+
"<|22.28|>": 51478,
|
767 |
+
"<|22.30|>": 51479,
|
768 |
+
"<|22.32|>": 51480,
|
769 |
+
"<|22.34|>": 51481,
|
770 |
+
"<|22.36|>": 51482,
|
771 |
+
"<|22.38|>": 51483,
|
772 |
+
"<|22.40|>": 51484,
|
773 |
+
"<|22.42|>": 51485,
|
774 |
+
"<|22.44|>": 51486,
|
775 |
+
"<|22.46|>": 51487,
|
776 |
+
"<|22.48|>": 51488,
|
777 |
+
"<|22.50|>": 51489,
|
778 |
+
"<|22.52|>": 51490,
|
779 |
+
"<|22.54|>": 51491,
|
780 |
+
"<|22.56|>": 51492,
|
781 |
+
"<|22.58|>": 51493,
|
782 |
+
"<|22.60|>": 51494,
|
783 |
+
"<|22.62|>": 51495,
|
784 |
+
"<|22.64|>": 51496,
|
785 |
+
"<|22.66|>": 51497,
|
786 |
+
"<|22.68|>": 51498,
|
787 |
+
"<|22.70|>": 51499,
|
788 |
+
"<|22.72|>": 51500,
|
789 |
+
"<|22.74|>": 51501,
|
790 |
+
"<|22.76|>": 51502,
|
791 |
+
"<|22.78|>": 51503,
|
792 |
+
"<|22.80|>": 51504,
|
793 |
+
"<|22.82|>": 51505,
|
794 |
+
"<|22.84|>": 51506,
|
795 |
+
"<|22.86|>": 51507,
|
796 |
+
"<|22.88|>": 51508,
|
797 |
+
"<|22.90|>": 51509,
|
798 |
+
"<|22.92|>": 51510,
|
799 |
+
"<|22.94|>": 51511,
|
800 |
+
"<|22.96|>": 51512,
|
801 |
+
"<|22.98|>": 51513,
|
802 |
+
"<|23.00|>": 51514,
|
803 |
+
"<|23.02|>": 51515,
|
804 |
+
"<|23.04|>": 51516,
|
805 |
+
"<|23.06|>": 51517,
|
806 |
+
"<|23.08|>": 51518,
|
807 |
+
"<|23.10|>": 51519,
|
808 |
+
"<|23.12|>": 51520,
|
809 |
+
"<|23.14|>": 51521,
|
810 |
+
"<|23.16|>": 51522,
|
811 |
+
"<|23.18|>": 51523,
|
812 |
+
"<|23.20|>": 51524,
|
813 |
+
"<|23.22|>": 51525,
|
814 |
+
"<|23.24|>": 51526,
|
815 |
+
"<|23.26|>": 51527,
|
816 |
+
"<|23.28|>": 51528,
|
817 |
+
"<|23.30|>": 51529,
|
818 |
+
"<|23.32|>": 51530,
|
819 |
+
"<|23.34|>": 51531,
|
820 |
+
"<|23.36|>": 51532,
|
821 |
+
"<|23.38|>": 51533,
|
822 |
+
"<|23.40|>": 51534,
|
823 |
+
"<|23.42|>": 51535,
|
824 |
+
"<|23.44|>": 51536,
|
825 |
+
"<|23.46|>": 51537,
|
826 |
+
"<|23.48|>": 51538,
|
827 |
+
"<|23.50|>": 51539,
|
828 |
+
"<|23.52|>": 51540,
|
829 |
+
"<|23.54|>": 51541,
|
830 |
+
"<|23.56|>": 51542,
|
831 |
+
"<|23.58|>": 51543,
|
832 |
+
"<|23.60|>": 51544,
|
833 |
+
"<|23.62|>": 51545,
|
834 |
+
"<|23.64|>": 51546,
|
835 |
+
"<|23.66|>": 51547,
|
836 |
+
"<|23.68|>": 51548,
|
837 |
+
"<|23.70|>": 51549,
|
838 |
+
"<|23.72|>": 51550,
|
839 |
+
"<|23.74|>": 51551,
|
840 |
+
"<|23.76|>": 51552,
|
841 |
+
"<|23.78|>": 51553,
|
842 |
+
"<|23.80|>": 51554,
|
843 |
+
"<|23.82|>": 51555,
|
844 |
+
"<|23.84|>": 51556,
|
845 |
+
"<|23.86|>": 51557,
|
846 |
+
"<|23.88|>": 51558,
|
847 |
+
"<|23.90|>": 51559,
|
848 |
+
"<|23.92|>": 51560,
|
849 |
+
"<|23.94|>": 51561,
|
850 |
+
"<|23.96|>": 51562,
|
851 |
+
"<|23.98|>": 51563,
|
852 |
+
"<|24.00|>": 51564,
|
853 |
+
"<|24.02|>": 51565,
|
854 |
+
"<|24.04|>": 51566,
|
855 |
+
"<|24.06|>": 51567,
|
856 |
+
"<|24.08|>": 51568,
|
857 |
+
"<|24.10|>": 51569,
|
858 |
+
"<|24.12|>": 51570,
|
859 |
+
"<|24.14|>": 51571,
|
860 |
+
"<|24.16|>": 51572,
|
861 |
+
"<|24.18|>": 51573,
|
862 |
+
"<|24.20|>": 51574,
|
863 |
+
"<|24.22|>": 51575,
|
864 |
+
"<|24.24|>": 51576,
|
865 |
+
"<|24.26|>": 51577,
|
866 |
+
"<|24.28|>": 51578,
|
867 |
+
"<|24.30|>": 51579,
|
868 |
+
"<|24.32|>": 51580,
|
869 |
+
"<|24.34|>": 51581,
|
870 |
+
"<|24.36|>": 51582,
|
871 |
+
"<|24.38|>": 51583,
|
872 |
+
"<|24.40|>": 51584,
|
873 |
+
"<|24.42|>": 51585,
|
874 |
+
"<|24.44|>": 51586,
|
875 |
+
"<|24.46|>": 51587,
|
876 |
+
"<|24.48|>": 51588,
|
877 |
+
"<|24.50|>": 51589,
|
878 |
+
"<|24.52|>": 51590,
|
879 |
+
"<|24.54|>": 51591,
|
880 |
+
"<|24.56|>": 51592,
|
881 |
+
"<|24.58|>": 51593,
|
882 |
+
"<|24.60|>": 51594,
|
883 |
+
"<|24.62|>": 51595,
|
884 |
+
"<|24.64|>": 51596,
|
885 |
+
"<|24.66|>": 51597,
|
886 |
+
"<|24.68|>": 51598,
|
887 |
+
"<|24.70|>": 51599,
|
888 |
+
"<|24.72|>": 51600,
|
889 |
+
"<|24.74|>": 51601,
|
890 |
+
"<|24.76|>": 51602,
|
891 |
+
"<|24.78|>": 51603,
|
892 |
+
"<|24.80|>": 51604,
|
893 |
+
"<|24.82|>": 51605,
|
894 |
+
"<|24.84|>": 51606,
|
895 |
+
"<|24.86|>": 51607,
|
896 |
+
"<|24.88|>": 51608,
|
897 |
+
"<|24.90|>": 51609,
|
898 |
+
"<|24.92|>": 51610,
|
899 |
+
"<|24.94|>": 51611,
|
900 |
+
"<|24.96|>": 51612,
|
901 |
+
"<|24.98|>": 51613,
|
902 |
+
"<|25.00|>": 51614,
|
903 |
+
"<|25.02|>": 51615,
|
904 |
+
"<|25.04|>": 51616,
|
905 |
+
"<|25.06|>": 51617,
|
906 |
+
"<|25.08|>": 51618,
|
907 |
+
"<|25.10|>": 51619,
|
908 |
+
"<|25.12|>": 51620,
|
909 |
+
"<|25.14|>": 51621,
|
910 |
+
"<|25.16|>": 51622,
|
911 |
+
"<|25.18|>": 51623,
|
912 |
+
"<|25.20|>": 51624,
|
913 |
+
"<|25.22|>": 51625,
|
914 |
+
"<|25.24|>": 51626,
|
915 |
+
"<|25.26|>": 51627,
|
916 |
+
"<|25.28|>": 51628,
|
917 |
+
"<|25.30|>": 51629,
|
918 |
+
"<|25.32|>": 51630,
|
919 |
+
"<|25.34|>": 51631,
|
920 |
+
"<|25.36|>": 51632,
|
921 |
+
"<|25.38|>": 51633,
|
922 |
+
"<|25.40|>": 51634,
|
923 |
+
"<|25.42|>": 51635,
|
924 |
+
"<|25.44|>": 51636,
|
925 |
+
"<|25.46|>": 51637,
|
926 |
+
"<|25.48|>": 51638,
|
927 |
+
"<|25.50|>": 51639,
|
928 |
+
"<|25.52|>": 51640,
|
929 |
+
"<|25.54|>": 51641,
|
930 |
+
"<|25.56|>": 51642,
|
931 |
+
"<|25.58|>": 51643,
|
932 |
+
"<|25.60|>": 51644,
|
933 |
+
"<|25.62|>": 51645,
|
934 |
+
"<|25.64|>": 51646,
|
935 |
+
"<|25.66|>": 51647,
|
936 |
+
"<|25.68|>": 51648,
|
937 |
+
"<|25.70|>": 51649,
|
938 |
+
"<|25.72|>": 51650,
|
939 |
+
"<|25.74|>": 51651,
|
940 |
+
"<|25.76|>": 51652,
|
941 |
+
"<|25.78|>": 51653,
|
942 |
+
"<|25.80|>": 51654,
|
943 |
+
"<|25.82|>": 51655,
|
944 |
+
"<|25.84|>": 51656,
|
945 |
+
"<|25.86|>": 51657,
|
946 |
+
"<|25.88|>": 51658,
|
947 |
+
"<|25.90|>": 51659,
|
948 |
+
"<|25.92|>": 51660,
|
949 |
+
"<|25.94|>": 51661,
|
950 |
+
"<|25.96|>": 51662,
|
951 |
+
"<|25.98|>": 51663,
|
952 |
+
"<|26.00|>": 51664,
|
953 |
+
"<|26.02|>": 51665,
|
954 |
+
"<|26.04|>": 51666,
|
955 |
+
"<|26.06|>": 51667,
|
956 |
+
"<|26.08|>": 51668,
|
957 |
+
"<|26.10|>": 51669,
|
958 |
+
"<|26.12|>": 51670,
|
959 |
+
"<|26.14|>": 51671,
|
960 |
+
"<|26.16|>": 51672,
|
961 |
+
"<|26.18|>": 51673,
|
962 |
+
"<|26.20|>": 51674,
|
963 |
+
"<|26.22|>": 51675,
|
964 |
+
"<|26.24|>": 51676,
|
965 |
+
"<|26.26|>": 51677,
|
966 |
+
"<|26.28|>": 51678,
|
967 |
+
"<|26.30|>": 51679,
|
968 |
+
"<|26.32|>": 51680,
|
969 |
+
"<|26.34|>": 51681,
|
970 |
+
"<|26.36|>": 51682,
|
971 |
+
"<|26.38|>": 51683,
|
972 |
+
"<|26.40|>": 51684,
|
973 |
+
"<|26.42|>": 51685,
|
974 |
+
"<|26.44|>": 51686,
|
975 |
+
"<|26.46|>": 51687,
|
976 |
+
"<|26.48|>": 51688,
|
977 |
+
"<|26.50|>": 51689,
|
978 |
+
"<|26.52|>": 51690,
|
979 |
+
"<|26.54|>": 51691,
|
980 |
+
"<|26.56|>": 51692,
|
981 |
+
"<|26.58|>": 51693,
|
982 |
+
"<|26.60|>": 51694,
|
983 |
+
"<|26.62|>": 51695,
|
984 |
+
"<|26.64|>": 51696,
|
985 |
+
"<|26.66|>": 51697,
|
986 |
+
"<|26.68|>": 51698,
|
987 |
+
"<|26.70|>": 51699,
|
988 |
+
"<|26.72|>": 51700,
|
989 |
+
"<|26.74|>": 51701,
|
990 |
+
"<|26.76|>": 51702,
|
991 |
+
"<|26.78|>": 51703,
|
992 |
+
"<|26.80|>": 51704,
|
993 |
+
"<|26.82|>": 51705,
|
994 |
+
"<|26.84|>": 51706,
|
995 |
+
"<|26.86|>": 51707,
|
996 |
+
"<|26.88|>": 51708,
|
997 |
+
"<|26.90|>": 51709,
|
998 |
+
"<|26.92|>": 51710,
|
999 |
+
"<|26.94|>": 51711,
|
1000 |
+
"<|26.96|>": 51712,
|
1001 |
+
"<|26.98|>": 51713,
|
1002 |
+
"<|27.00|>": 51714,
|
1003 |
+
"<|27.02|>": 51715,
|
1004 |
+
"<|27.04|>": 51716,
|
1005 |
+
"<|27.06|>": 51717,
|
1006 |
+
"<|27.08|>": 51718,
|
1007 |
+
"<|27.10|>": 51719,
|
1008 |
+
"<|27.12|>": 51720,
|
1009 |
+
"<|27.14|>": 51721,
|
1010 |
+
"<|27.16|>": 51722,
|
1011 |
+
"<|27.18|>": 51723,
|
1012 |
+
"<|27.20|>": 51724,
|
1013 |
+
"<|27.22|>": 51725,
|
1014 |
+
"<|27.24|>": 51726,
|
1015 |
+
"<|27.26|>": 51727,
|
1016 |
+
"<|27.28|>": 51728,
|
1017 |
+
"<|27.30|>": 51729,
|
1018 |
+
"<|27.32|>": 51730,
|
1019 |
+
"<|27.34|>": 51731,
|
1020 |
+
"<|27.36|>": 51732,
|
1021 |
+
"<|27.38|>": 51733,
|
1022 |
+
"<|27.40|>": 51734,
|
1023 |
+
"<|27.42|>": 51735,
|
1024 |
+
"<|27.44|>": 51736,
|
1025 |
+
"<|27.46|>": 51737,
|
1026 |
+
"<|27.48|>": 51738,
|
1027 |
+
"<|27.50|>": 51739,
|
1028 |
+
"<|27.52|>": 51740,
|
1029 |
+
"<|27.54|>": 51741,
|
1030 |
+
"<|27.56|>": 51742,
|
1031 |
+
"<|27.58|>": 51743,
|
1032 |
+
"<|27.60|>": 51744,
|
1033 |
+
"<|27.62|>": 51745,
|
1034 |
+
"<|27.64|>": 51746,
|
1035 |
+
"<|27.66|>": 51747,
|
1036 |
+
"<|27.68|>": 51748,
|
1037 |
+
"<|27.70|>": 51749,
|
1038 |
+
"<|27.72|>": 51750,
|
1039 |
+
"<|27.74|>": 51751,
|
1040 |
+
"<|27.76|>": 51752,
|
1041 |
+
"<|27.78|>": 51753,
|
1042 |
+
"<|27.80|>": 51754,
|
1043 |
+
"<|27.82|>": 51755,
|
1044 |
+
"<|27.84|>": 51756,
|
1045 |
+
"<|27.86|>": 51757,
|
1046 |
+
"<|27.88|>": 51758,
|
1047 |
+
"<|27.90|>": 51759,
|
1048 |
+
"<|27.92|>": 51760,
|
1049 |
+
"<|27.94|>": 51761,
|
1050 |
+
"<|27.96|>": 51762,
|
1051 |
+
"<|27.98|>": 51763,
|
1052 |
+
"<|28.00|>": 51764,
|
1053 |
+
"<|28.02|>": 51765,
|
1054 |
+
"<|28.04|>": 51766,
|
1055 |
+
"<|28.06|>": 51767,
|
1056 |
+
"<|28.08|>": 51768,
|
1057 |
+
"<|28.10|>": 51769,
|
1058 |
+
"<|28.12|>": 51770,
|
1059 |
+
"<|28.14|>": 51771,
|
1060 |
+
"<|28.16|>": 51772,
|
1061 |
+
"<|28.18|>": 51773,
|
1062 |
+
"<|28.20|>": 51774,
|
1063 |
+
"<|28.22|>": 51775,
|
1064 |
+
"<|28.24|>": 51776,
|
1065 |
+
"<|28.26|>": 51777,
|
1066 |
+
"<|28.28|>": 51778,
|
1067 |
+
"<|28.30|>": 51779,
|
1068 |
+
"<|28.32|>": 51780,
|
1069 |
+
"<|28.34|>": 51781,
|
1070 |
+
"<|28.36|>": 51782,
|
1071 |
+
"<|28.38|>": 51783,
|
1072 |
+
"<|28.40|>": 51784,
|
1073 |
+
"<|28.42|>": 51785,
|
1074 |
+
"<|28.44|>": 51786,
|
1075 |
+
"<|28.46|>": 51787,
|
1076 |
+
"<|28.48|>": 51788,
|
1077 |
+
"<|28.50|>": 51789,
|
1078 |
+
"<|28.52|>": 51790,
|
1079 |
+
"<|28.54|>": 51791,
|
1080 |
+
"<|28.56|>": 51792,
|
1081 |
+
"<|28.58|>": 51793,
|
1082 |
+
"<|28.60|>": 51794,
|
1083 |
+
"<|28.62|>": 51795,
|
1084 |
+
"<|28.64|>": 51796,
|
1085 |
+
"<|28.66|>": 51797,
|
1086 |
+
"<|28.68|>": 51798,
|
1087 |
+
"<|28.70|>": 51799,
|
1088 |
+
"<|28.72|>": 51800,
|
1089 |
+
"<|28.74|>": 51801,
|
1090 |
+
"<|28.76|>": 51802,
|
1091 |
+
"<|28.78|>": 51803,
|
1092 |
+
"<|28.80|>": 51804,
|
1093 |
+
"<|28.82|>": 51805,
|
1094 |
+
"<|28.84|>": 51806,
|
1095 |
+
"<|28.86|>": 51807,
|
1096 |
+
"<|28.88|>": 51808,
|
1097 |
+
"<|28.90|>": 51809,
|
1098 |
+
"<|28.92|>": 51810,
|
1099 |
+
"<|28.94|>": 51811,
|
1100 |
+
"<|28.96|>": 51812,
|
1101 |
+
"<|28.98|>": 51813,
|
1102 |
+
"<|29.00|>": 51814,
|
1103 |
+
"<|29.02|>": 51815,
|
1104 |
+
"<|29.04|>": 51816,
|
1105 |
+
"<|29.06|>": 51817,
|
1106 |
+
"<|29.08|>": 51818,
|
1107 |
+
"<|29.10|>": 51819,
|
1108 |
+
"<|29.12|>": 51820,
|
1109 |
+
"<|29.14|>": 51821,
|
1110 |
+
"<|29.16|>": 51822,
|
1111 |
+
"<|29.18|>": 51823,
|
1112 |
+
"<|29.20|>": 51824,
|
1113 |
+
"<|29.22|>": 51825,
|
1114 |
+
"<|29.24|>": 51826,
|
1115 |
+
"<|29.26|>": 51827,
|
1116 |
+
"<|29.28|>": 51828,
|
1117 |
+
"<|29.30|>": 51829,
|
1118 |
+
"<|29.32|>": 51830,
|
1119 |
+
"<|29.34|>": 51831,
|
1120 |
+
"<|29.36|>": 51832,
|
1121 |
+
"<|29.38|>": 51833,
|
1122 |
+
"<|29.40|>": 51834,
|
1123 |
+
"<|29.42|>": 51835,
|
1124 |
+
"<|29.44|>": 51836,
|
1125 |
+
"<|29.46|>": 51837,
|
1126 |
+
"<|29.48|>": 51838,
|
1127 |
+
"<|29.50|>": 51839,
|
1128 |
+
"<|29.52|>": 51840,
|
1129 |
+
"<|29.54|>": 51841,
|
1130 |
+
"<|29.56|>": 51842,
|
1131 |
+
"<|29.58|>": 51843,
|
1132 |
+
"<|29.60|>": 51844,
|
1133 |
+
"<|29.62|>": 51845,
|
1134 |
+
"<|29.64|>": 51846,
|
1135 |
+
"<|29.66|>": 51847,
|
1136 |
+
"<|29.68|>": 51848,
|
1137 |
+
"<|29.70|>": 51849,
|
1138 |
+
"<|29.72|>": 51850,
|
1139 |
+
"<|29.74|>": 51851,
|
1140 |
+
"<|29.76|>": 51852,
|
1141 |
+
"<|29.78|>": 51853,
|
1142 |
+
"<|29.80|>": 51854,
|
1143 |
+
"<|29.82|>": 51855,
|
1144 |
+
"<|29.84|>": 51856,
|
1145 |
+
"<|29.86|>": 51857,
|
1146 |
+
"<|29.88|>": 51858,
|
1147 |
+
"<|29.90|>": 51859,
|
1148 |
+
"<|29.92|>": 51860,
|
1149 |
+
"<|29.94|>": 51861,
|
1150 |
+
"<|29.96|>": 51862,
|
1151 |
+
"<|29.98|>": 51863,
|
1152 |
+
"<|3.00|>": 50514,
|
1153 |
+
"<|3.02|>": 50515,
|
1154 |
+
"<|3.04|>": 50516,
|
1155 |
+
"<|3.06|>": 50517,
|
1156 |
+
"<|3.08|>": 50518,
|
1157 |
+
"<|3.10|>": 50519,
|
1158 |
+
"<|3.12|>": 50520,
|
1159 |
+
"<|3.14|>": 50521,
|
1160 |
+
"<|3.16|>": 50522,
|
1161 |
+
"<|3.18|>": 50523,
|
1162 |
+
"<|3.20|>": 50524,
|
1163 |
+
"<|3.22|>": 50525,
|
1164 |
+
"<|3.24|>": 50526,
|
1165 |
+
"<|3.26|>": 50527,
|
1166 |
+
"<|3.28|>": 50528,
|
1167 |
+
"<|3.30|>": 50529,
|
1168 |
+
"<|3.32|>": 50530,
|
1169 |
+
"<|3.34|>": 50531,
|
1170 |
+
"<|3.36|>": 50532,
|
1171 |
+
"<|3.38|>": 50533,
|
1172 |
+
"<|3.40|>": 50534,
|
1173 |
+
"<|3.42|>": 50535,
|
1174 |
+
"<|3.44|>": 50536,
|
1175 |
+
"<|3.46|>": 50537,
|
1176 |
+
"<|3.48|>": 50538,
|
1177 |
+
"<|3.50|>": 50539,
|
1178 |
+
"<|3.52|>": 50540,
|
1179 |
+
"<|3.54|>": 50541,
|
1180 |
+
"<|3.56|>": 50542,
|
1181 |
+
"<|3.58|>": 50543,
|
1182 |
+
"<|3.60|>": 50544,
|
1183 |
+
"<|3.62|>": 50545,
|
1184 |
+
"<|3.64|>": 50546,
|
1185 |
+
"<|3.66|>": 50547,
|
1186 |
+
"<|3.68|>": 50548,
|
1187 |
+
"<|3.70|>": 50549,
|
1188 |
+
"<|3.72|>": 50550,
|
1189 |
+
"<|3.74|>": 50551,
|
1190 |
+
"<|3.76|>": 50552,
|
1191 |
+
"<|3.78|>": 50553,
|
1192 |
+
"<|3.80|>": 50554,
|
1193 |
+
"<|3.82|>": 50555,
|
1194 |
+
"<|3.84|>": 50556,
|
1195 |
+
"<|3.86|>": 50557,
|
1196 |
+
"<|3.88|>": 50558,
|
1197 |
+
"<|3.90|>": 50559,
|
1198 |
+
"<|3.92|>": 50560,
|
1199 |
+
"<|3.94|>": 50561,
|
1200 |
+
"<|3.96|>": 50562,
|
1201 |
+
"<|3.98|>": 50563,
|
1202 |
+
"<|30.00|>": 51864,
|
1203 |
+
"<|4.00|>": 50564,
|
1204 |
+
"<|4.02|>": 50565,
|
1205 |
+
"<|4.04|>": 50566,
|
1206 |
+
"<|4.06|>": 50567,
|
1207 |
+
"<|4.08|>": 50568,
|
1208 |
+
"<|4.10|>": 50569,
|
1209 |
+
"<|4.12|>": 50570,
|
1210 |
+
"<|4.14|>": 50571,
|
1211 |
+
"<|4.16|>": 50572,
|
1212 |
+
"<|4.18|>": 50573,
|
1213 |
+
"<|4.20|>": 50574,
|
1214 |
+
"<|4.22|>": 50575,
|
1215 |
+
"<|4.24|>": 50576,
|
1216 |
+
"<|4.26|>": 50577,
|
1217 |
+
"<|4.28|>": 50578,
|
1218 |
+
"<|4.30|>": 50579,
|
1219 |
+
"<|4.32|>": 50580,
|
1220 |
+
"<|4.34|>": 50581,
|
1221 |
+
"<|4.36|>": 50582,
|
1222 |
+
"<|4.38|>": 50583,
|
1223 |
+
"<|4.40|>": 50584,
|
1224 |
+
"<|4.42|>": 50585,
|
1225 |
+
"<|4.44|>": 50586,
|
1226 |
+
"<|4.46|>": 50587,
|
1227 |
+
"<|4.48|>": 50588,
|
1228 |
+
"<|4.50|>": 50589,
|
1229 |
+
"<|4.52|>": 50590,
|
1230 |
+
"<|4.54|>": 50591,
|
1231 |
+
"<|4.56|>": 50592,
|
1232 |
+
"<|4.58|>": 50593,
|
1233 |
+
"<|4.60|>": 50594,
|
1234 |
+
"<|4.62|>": 50595,
|
1235 |
+
"<|4.64|>": 50596,
|
1236 |
+
"<|4.66|>": 50597,
|
1237 |
+
"<|4.68|>": 50598,
|
1238 |
+
"<|4.70|>": 50599,
|
1239 |
+
"<|4.72|>": 50600,
|
1240 |
+
"<|4.74|>": 50601,
|
1241 |
+
"<|4.76|>": 50602,
|
1242 |
+
"<|4.78|>": 50603,
|
1243 |
+
"<|4.80|>": 50604,
|
1244 |
+
"<|4.82|>": 50605,
|
1245 |
+
"<|4.84|>": 50606,
|
1246 |
+
"<|4.86|>": 50607,
|
1247 |
+
"<|4.88|>": 50608,
|
1248 |
+
"<|4.90|>": 50609,
|
1249 |
+
"<|4.92|>": 50610,
|
1250 |
+
"<|4.94|>": 50611,
|
1251 |
+
"<|4.96|>": 50612,
|
1252 |
+
"<|4.98|>": 50613,
|
1253 |
+
"<|5.00|>": 50614,
|
1254 |
+
"<|5.02|>": 50615,
|
1255 |
+
"<|5.04|>": 50616,
|
1256 |
+
"<|5.06|>": 50617,
|
1257 |
+
"<|5.08|>": 50618,
|
1258 |
+
"<|5.10|>": 50619,
|
1259 |
+
"<|5.12|>": 50620,
|
1260 |
+
"<|5.14|>": 50621,
|
1261 |
+
"<|5.16|>": 50622,
|
1262 |
+
"<|5.18|>": 50623,
|
1263 |
+
"<|5.20|>": 50624,
|
1264 |
+
"<|5.22|>": 50625,
|
1265 |
+
"<|5.24|>": 50626,
|
1266 |
+
"<|5.26|>": 50627,
|
1267 |
+
"<|5.28|>": 50628,
|
1268 |
+
"<|5.30|>": 50629,
|
1269 |
+
"<|5.32|>": 50630,
|
1270 |
+
"<|5.34|>": 50631,
|
1271 |
+
"<|5.36|>": 50632,
|
1272 |
+
"<|5.38|>": 50633,
|
1273 |
+
"<|5.40|>": 50634,
|
1274 |
+
"<|5.42|>": 50635,
|
1275 |
+
"<|5.44|>": 50636,
|
1276 |
+
"<|5.46|>": 50637,
|
1277 |
+
"<|5.48|>": 50638,
|
1278 |
+
"<|5.50|>": 50639,
|
1279 |
+
"<|5.52|>": 50640,
|
1280 |
+
"<|5.54|>": 50641,
|
1281 |
+
"<|5.56|>": 50642,
|
1282 |
+
"<|5.58|>": 50643,
|
1283 |
+
"<|5.60|>": 50644,
|
1284 |
+
"<|5.62|>": 50645,
|
1285 |
+
"<|5.64|>": 50646,
|
1286 |
+
"<|5.66|>": 50647,
|
1287 |
+
"<|5.68|>": 50648,
|
1288 |
+
"<|5.70|>": 50649,
|
1289 |
+
"<|5.72|>": 50650,
|
1290 |
+
"<|5.74|>": 50651,
|
1291 |
+
"<|5.76|>": 50652,
|
1292 |
+
"<|5.78|>": 50653,
|
1293 |
+
"<|5.80|>": 50654,
|
1294 |
+
"<|5.82|>": 50655,
|
1295 |
+
"<|5.84|>": 50656,
|
1296 |
+
"<|5.86|>": 50657,
|
1297 |
+
"<|5.88|>": 50658,
|
1298 |
+
"<|5.90|>": 50659,
|
1299 |
+
"<|5.92|>": 50660,
|
1300 |
+
"<|5.94|>": 50661,
|
1301 |
+
"<|5.96|>": 50662,
|
1302 |
+
"<|5.98|>": 50663,
|
1303 |
+
"<|6.00|>": 50664,
|
1304 |
+
"<|6.02|>": 50665,
|
1305 |
+
"<|6.04|>": 50666,
|
1306 |
+
"<|6.06|>": 50667,
|
1307 |
+
"<|6.08|>": 50668,
|
1308 |
+
"<|6.10|>": 50669,
|
1309 |
+
"<|6.12|>": 50670,
|
1310 |
+
"<|6.14|>": 50671,
|
1311 |
+
"<|6.16|>": 50672,
|
1312 |
+
"<|6.18|>": 50673,
|
1313 |
+
"<|6.20|>": 50674,
|
1314 |
+
"<|6.22|>": 50675,
|
1315 |
+
"<|6.24|>": 50676,
|
1316 |
+
"<|6.26|>": 50677,
|
1317 |
+
"<|6.28|>": 50678,
|
1318 |
+
"<|6.30|>": 50679,
|
1319 |
+
"<|6.32|>": 50680,
|
1320 |
+
"<|6.34|>": 50681,
|
1321 |
+
"<|6.36|>": 50682,
|
1322 |
+
"<|6.38|>": 50683,
|
1323 |
+
"<|6.40|>": 50684,
|
1324 |
+
"<|6.42|>": 50685,
|
1325 |
+
"<|6.44|>": 50686,
|
1326 |
+
"<|6.46|>": 50687,
|
1327 |
+
"<|6.48|>": 50688,
|
1328 |
+
"<|6.50|>": 50689,
|
1329 |
+
"<|6.52|>": 50690,
|
1330 |
+
"<|6.54|>": 50691,
|
1331 |
+
"<|6.56|>": 50692,
|
1332 |
+
"<|6.58|>": 50693,
|
1333 |
+
"<|6.60|>": 50694,
|
1334 |
+
"<|6.62|>": 50695,
|
1335 |
+
"<|6.64|>": 50696,
|
1336 |
+
"<|6.66|>": 50697,
|
1337 |
+
"<|6.68|>": 50698,
|
1338 |
+
"<|6.70|>": 50699,
|
1339 |
+
"<|6.72|>": 50700,
|
1340 |
+
"<|6.74|>": 50701,
|
1341 |
+
"<|6.76|>": 50702,
|
1342 |
+
"<|6.78|>": 50703,
|
1343 |
+
"<|6.80|>": 50704,
|
1344 |
+
"<|6.82|>": 50705,
|
1345 |
+
"<|6.84|>": 50706,
|
1346 |
+
"<|6.86|>": 50707,
|
1347 |
+
"<|6.88|>": 50708,
|
1348 |
+
"<|6.90|>": 50709,
|
1349 |
+
"<|6.92|>": 50710,
|
1350 |
+
"<|6.94|>": 50711,
|
1351 |
+
"<|6.96|>": 50712,
|
1352 |
+
"<|6.98|>": 50713,
|
1353 |
+
"<|7.00|>": 50714,
|
1354 |
+
"<|7.02|>": 50715,
|
1355 |
+
"<|7.04|>": 50716,
|
1356 |
+
"<|7.06|>": 50717,
|
1357 |
+
"<|7.08|>": 50718,
|
1358 |
+
"<|7.10|>": 50719,
|
1359 |
+
"<|7.12|>": 50720,
|
1360 |
+
"<|7.14|>": 50721,
|
1361 |
+
"<|7.16|>": 50722,
|
1362 |
+
"<|7.18|>": 50723,
|
1363 |
+
"<|7.20|>": 50724,
|
1364 |
+
"<|7.22|>": 50725,
|
1365 |
+
"<|7.24|>": 50726,
|
1366 |
+
"<|7.26|>": 50727,
|
1367 |
+
"<|7.28|>": 50728,
|
1368 |
+
"<|7.30|>": 50729,
|
1369 |
+
"<|7.32|>": 50730,
|
1370 |
+
"<|7.34|>": 50731,
|
1371 |
+
"<|7.36|>": 50732,
|
1372 |
+
"<|7.38|>": 50733,
|
1373 |
+
"<|7.40|>": 50734,
|
1374 |
+
"<|7.42|>": 50735,
|
1375 |
+
"<|7.44|>": 50736,
|
1376 |
+
"<|7.46|>": 50737,
|
1377 |
+
"<|7.48|>": 50738,
|
1378 |
+
"<|7.50|>": 50739,
|
1379 |
+
"<|7.52|>": 50740,
|
1380 |
+
"<|7.54|>": 50741,
|
1381 |
+
"<|7.56|>": 50742,
|
1382 |
+
"<|7.58|>": 50743,
|
1383 |
+
"<|7.60|>": 50744,
|
1384 |
+
"<|7.62|>": 50745,
|
1385 |
+
"<|7.64|>": 50746,
|
1386 |
+
"<|7.66|>": 50747,
|
1387 |
+
"<|7.68|>": 50748,
|
1388 |
+
"<|7.70|>": 50749,
|
1389 |
+
"<|7.72|>": 50750,
|
1390 |
+
"<|7.74|>": 50751,
|
1391 |
+
"<|7.76|>": 50752,
|
1392 |
+
"<|7.78|>": 50753,
|
1393 |
+
"<|7.80|>": 50754,
|
1394 |
+
"<|7.82|>": 50755,
|
1395 |
+
"<|7.84|>": 50756,
|
1396 |
+
"<|7.86|>": 50757,
|
1397 |
+
"<|7.88|>": 50758,
|
1398 |
+
"<|7.90|>": 50759,
|
1399 |
+
"<|7.92|>": 50760,
|
1400 |
+
"<|7.94|>": 50761,
|
1401 |
+
"<|7.96|>": 50762,
|
1402 |
+
"<|7.98|>": 50763,
|
1403 |
+
"<|8.00|>": 50764,
|
1404 |
+
"<|8.02|>": 50765,
|
1405 |
+
"<|8.04|>": 50766,
|
1406 |
+
"<|8.06|>": 50767,
|
1407 |
+
"<|8.08|>": 50768,
|
1408 |
+
"<|8.10|>": 50769,
|
1409 |
+
"<|8.12|>": 50770,
|
1410 |
+
"<|8.14|>": 50771,
|
1411 |
+
"<|8.16|>": 50772,
|
1412 |
+
"<|8.18|>": 50773,
|
1413 |
+
"<|8.20|>": 50774,
|
1414 |
+
"<|8.22|>": 50775,
|
1415 |
+
"<|8.24|>": 50776,
|
1416 |
+
"<|8.26|>": 50777,
|
1417 |
+
"<|8.28|>": 50778,
|
1418 |
+
"<|8.30|>": 50779,
|
1419 |
+
"<|8.32|>": 50780,
|
1420 |
+
"<|8.34|>": 50781,
|
1421 |
+
"<|8.36|>": 50782,
|
1422 |
+
"<|8.38|>": 50783,
|
1423 |
+
"<|8.40|>": 50784,
|
1424 |
+
"<|8.42|>": 50785,
|
1425 |
+
"<|8.44|>": 50786,
|
1426 |
+
"<|8.46|>": 50787,
|
1427 |
+
"<|8.48|>": 50788,
|
1428 |
+
"<|8.50|>": 50789,
|
1429 |
+
"<|8.52|>": 50790,
|
1430 |
+
"<|8.54|>": 50791,
|
1431 |
+
"<|8.56|>": 50792,
|
1432 |
+
"<|8.58|>": 50793,
|
1433 |
+
"<|8.60|>": 50794,
|
1434 |
+
"<|8.62|>": 50795,
|
1435 |
+
"<|8.64|>": 50796,
|
1436 |
+
"<|8.66|>": 50797,
|
1437 |
+
"<|8.68|>": 50798,
|
1438 |
+
"<|8.70|>": 50799,
|
1439 |
+
"<|8.72|>": 50800,
|
1440 |
+
"<|8.74|>": 50801,
|
1441 |
+
"<|8.76|>": 50802,
|
1442 |
+
"<|8.78|>": 50803,
|
1443 |
+
"<|8.80|>": 50804,
|
1444 |
+
"<|8.82|>": 50805,
|
1445 |
+
"<|8.84|>": 50806,
|
1446 |
+
"<|8.86|>": 50807,
|
1447 |
+
"<|8.88|>": 50808,
|
1448 |
+
"<|8.90|>": 50809,
|
1449 |
+
"<|8.92|>": 50810,
|
1450 |
+
"<|8.94|>": 50811,
|
1451 |
+
"<|8.96|>": 50812,
|
1452 |
+
"<|8.98|>": 50813,
|
1453 |
+
"<|9.00|>": 50814,
|
1454 |
+
"<|9.02|>": 50815,
|
1455 |
+
"<|9.04|>": 50816,
|
1456 |
+
"<|9.06|>": 50817,
|
1457 |
+
"<|9.08|>": 50818,
|
1458 |
+
"<|9.10|>": 50819,
|
1459 |
+
"<|9.12|>": 50820,
|
1460 |
+
"<|9.14|>": 50821,
|
1461 |
+
"<|9.16|>": 50822,
|
1462 |
+
"<|9.18|>": 50823,
|
1463 |
+
"<|9.20|>": 50824,
|
1464 |
+
"<|9.22|>": 50825,
|
1465 |
+
"<|9.24|>": 50826,
|
1466 |
+
"<|9.26|>": 50827,
|
1467 |
+
"<|9.28|>": 50828,
|
1468 |
+
"<|9.30|>": 50829,
|
1469 |
+
"<|9.32|>": 50830,
|
1470 |
+
"<|9.34|>": 50831,
|
1471 |
+
"<|9.36|>": 50832,
|
1472 |
+
"<|9.38|>": 50833,
|
1473 |
+
"<|9.40|>": 50834,
|
1474 |
+
"<|9.42|>": 50835,
|
1475 |
+
"<|9.44|>": 50836,
|
1476 |
+
"<|9.46|>": 50837,
|
1477 |
+
"<|9.48|>": 50838,
|
1478 |
+
"<|9.50|>": 50839,
|
1479 |
+
"<|9.52|>": 50840,
|
1480 |
+
"<|9.54|>": 50841,
|
1481 |
+
"<|9.56|>": 50842,
|
1482 |
+
"<|9.58|>": 50843,
|
1483 |
+
"<|9.60|>": 50844,
|
1484 |
+
"<|9.62|>": 50845,
|
1485 |
+
"<|9.64|>": 50846,
|
1486 |
+
"<|9.66|>": 50847,
|
1487 |
+
"<|9.68|>": 50848,
|
1488 |
+
"<|9.70|>": 50849,
|
1489 |
+
"<|9.72|>": 50850,
|
1490 |
+
"<|9.74|>": 50851,
|
1491 |
+
"<|9.76|>": 50852,
|
1492 |
+
"<|9.78|>": 50853,
|
1493 |
+
"<|9.80|>": 50854,
|
1494 |
+
"<|9.82|>": 50855,
|
1495 |
+
"<|9.84|>": 50856,
|
1496 |
+
"<|9.86|>": 50857,
|
1497 |
+
"<|9.88|>": 50858,
|
1498 |
+
"<|9.90|>": 50859,
|
1499 |
+
"<|9.92|>": 50860,
|
1500 |
+
"<|9.94|>": 50861,
|
1501 |
+
"<|9.96|>": 50862,
|
1502 |
+
"<|9.98|>": 50863,
|
1503 |
+
"<|af|>": 50327,
|
1504 |
+
"<|am|>": 50334,
|
1505 |
+
"<|ar|>": 50272,
|
1506 |
+
"<|as|>": 50350,
|
1507 |
+
"<|az|>": 50304,
|
1508 |
+
"<|ba|>": 50355,
|
1509 |
+
"<|be|>": 50330,
|
1510 |
+
"<|bg|>": 50292,
|
1511 |
+
"<|bn|>": 50302,
|
1512 |
+
"<|bo|>": 50347,
|
1513 |
+
"<|br|>": 50309,
|
1514 |
+
"<|bs|>": 50315,
|
1515 |
+
"<|ca|>": 50270,
|
1516 |
+
"<|cs|>": 50283,
|
1517 |
+
"<|cy|>": 50297,
|
1518 |
+
"<|da|>": 50285,
|
1519 |
+
"<|de|>": 50261,
|
1520 |
+
"<|el|>": 50281,
|
1521 |
+
"<|en|>": 50259,
|
1522 |
+
"<|es|>": 50262,
|
1523 |
+
"<|et|>": 50307,
|
1524 |
+
"<|eu|>": 50310,
|
1525 |
+
"<|fa|>": 50300,
|
1526 |
+
"<|fi|>": 50277,
|
1527 |
+
"<|fo|>": 50338,
|
1528 |
+
"<|fr|>": 50265,
|
1529 |
+
"<|gl|>": 50319,
|
1530 |
+
"<|gu|>": 50333,
|
1531 |
+
"<|haw|>": 50352,
|
1532 |
+
"<|ha|>": 50354,
|
1533 |
+
"<|he|>": 50279,
|
1534 |
+
"<|hi|>": 50276,
|
1535 |
+
"<|hr|>": 50291,
|
1536 |
+
"<|ht|>": 50339,
|
1537 |
+
"<|hu|>": 50286,
|
1538 |
+
"<|hy|>": 50312,
|
1539 |
+
"<|id|>": 50275,
|
1540 |
+
"<|is|>": 50311,
|
1541 |
+
"<|it|>": 50274,
|
1542 |
+
"<|ja|>": 50266,
|
1543 |
+
"<|jw|>": 50356,
|
1544 |
+
"<|ka|>": 50329,
|
1545 |
+
"<|kk|>": 50316,
|
1546 |
+
"<|km|>": 50323,
|
1547 |
+
"<|kn|>": 50306,
|
1548 |
+
"<|ko|>": 50264,
|
1549 |
+
"<|la|>": 50294,
|
1550 |
+
"<|lb|>": 50345,
|
1551 |
+
"<|ln|>": 50353,
|
1552 |
+
"<|lo|>": 50336,
|
1553 |
+
"<|lt|>": 50293,
|
1554 |
+
"<|lv|>": 50301,
|
1555 |
+
"<|mg|>": 50349,
|
1556 |
+
"<|mi|>": 50295,
|
1557 |
+
"<|mk|>": 50308,
|
1558 |
+
"<|ml|>": 50296,
|
1559 |
+
"<|mn|>": 50314,
|
1560 |
+
"<|mr|>": 50320,
|
1561 |
+
"<|ms|>": 50282,
|
1562 |
+
"<|mt|>": 50343,
|
1563 |
+
"<|my|>": 50346,
|
1564 |
+
"<|ne|>": 50313,
|
1565 |
+
"<|nl|>": 50271,
|
1566 |
+
"<|nn|>": 50342,
|
1567 |
+
"<|nocaptions|>": 50362,
|
1568 |
+
"<|notimestamps|>": 50363,
|
1569 |
+
"<|no|>": 50288,
|
1570 |
+
"<|oc|>": 50328,
|
1571 |
+
"<|pa|>": 50321,
|
1572 |
+
"<|pl|>": 50269,
|
1573 |
+
"<|ps|>": 50340,
|
1574 |
+
"<|pt|>": 50267,
|
1575 |
+
"<|ro|>": 50284,
|
1576 |
+
"<|ru|>": 50263,
|
1577 |
+
"<|sa|>": 50344,
|
1578 |
+
"<|sd|>": 50332,
|
1579 |
+
"<|si|>": 50322,
|
1580 |
+
"<|sk|>": 50298,
|
1581 |
+
"<|sl|>": 50305,
|
1582 |
+
"<|sn|>": 50324,
|
1583 |
+
"<|so|>": 50326,
|
1584 |
+
"<|sq|>": 50317,
|
1585 |
+
"<|sr|>": 50303,
|
1586 |
+
"<|startoflm|>": 50360,
|
1587 |
+
"<|startofprev|>": 50361,
|
1588 |
+
"<|startoftranscript|>": 50258,
|
1589 |
+
"<|su|>": 50357,
|
1590 |
+
"<|sv|>": 50273,
|
1591 |
+
"<|sw|>": 50318,
|
1592 |
+
"<|ta|>": 50287,
|
1593 |
+
"<|te|>": 50299,
|
1594 |
+
"<|tg|>": 50331,
|
1595 |
+
"<|th|>": 50289,
|
1596 |
+
"<|tk|>": 50341,
|
1597 |
+
"<|tl|>": 50348,
|
1598 |
+
"<|transcribe|>": 50359,
|
1599 |
+
"<|translate|>": 50358,
|
1600 |
+
"<|tr|>": 50268,
|
1601 |
+
"<|tt|>": 50351,
|
1602 |
+
"<|uk|>": 50280,
|
1603 |
+
"<|ur|>": 50290,
|
1604 |
+
"<|uz|>": 50337,
|
1605 |
+
"<|vi|>": 50278,
|
1606 |
+
"<|yi|>": 50335,
|
1607 |
+
"<|yo|>": 50325,
|
1608 |
+
"<|zh|>": 50260
|
1609 |
+
}
|
config.json
ADDED
@@ -0,0 +1,154 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"_name_or_path": "openai/whisper-large-v2",
|
3 |
+
"activation_dropout": 0.0,
|
4 |
+
"activation_function": "gelu",
|
5 |
+
"apply_spec_augment": false,
|
6 |
+
"architectures": [
|
7 |
+
"WhisperForConditionalGeneration"
|
8 |
+
],
|
9 |
+
"attention_dropout": 0.0,
|
10 |
+
"begin_suppress_tokens": [
|
11 |
+
220,
|
12 |
+
50257
|
13 |
+
],
|
14 |
+
"bos_token_id": 50257,
|
15 |
+
"classifier_proj_size": 256,
|
16 |
+
"d_model": 1280,
|
17 |
+
"decoder_attention_heads": 20,
|
18 |
+
"decoder_ffn_dim": 5120,
|
19 |
+
"decoder_layerdrop": 0.0,
|
20 |
+
"decoder_layers": 32,
|
21 |
+
"decoder_start_token_id": 50258,
|
22 |
+
"dropout": 0.0,
|
23 |
+
"encoder_attention_heads": 20,
|
24 |
+
"encoder_ffn_dim": 5120,
|
25 |
+
"encoder_layerdrop": 0.0,
|
26 |
+
"encoder_layers": 32,
|
27 |
+
"eos_token_id": 50257,
|
28 |
+
"forced_decoder_ids": [
|
29 |
+
[
|
30 |
+
1,
|
31 |
+
50259
|
32 |
+
],
|
33 |
+
[
|
34 |
+
2,
|
35 |
+
50359
|
36 |
+
],
|
37 |
+
[
|
38 |
+
3,
|
39 |
+
50363
|
40 |
+
]
|
41 |
+
],
|
42 |
+
"init_std": 0.02,
|
43 |
+
"is_encoder_decoder": true,
|
44 |
+
"mask_feature_length": 10,
|
45 |
+
"mask_feature_min_masks": 0,
|
46 |
+
"mask_feature_prob": 0.0,
|
47 |
+
"mask_time_length": 10,
|
48 |
+
"mask_time_min_masks": 2,
|
49 |
+
"mask_time_prob": 0.05,
|
50 |
+
"max_length": 448,
|
51 |
+
"max_source_positions": 1500,
|
52 |
+
"max_target_positions": 448,
|
53 |
+
"median_filter_width": 7,
|
54 |
+
"model_type": "whisper",
|
55 |
+
"num_hidden_layers": 32,
|
56 |
+
"num_mel_bins": 80,
|
57 |
+
"pad_token_id": 50257,
|
58 |
+
"scale_embedding": false,
|
59 |
+
"suppress_tokens": [
|
60 |
+
1,
|
61 |
+
2,
|
62 |
+
7,
|
63 |
+
8,
|
64 |
+
9,
|
65 |
+
10,
|
66 |
+
14,
|
67 |
+
25,
|
68 |
+
26,
|
69 |
+
27,
|
70 |
+
28,
|
71 |
+
29,
|
72 |
+
31,
|
73 |
+
58,
|
74 |
+
59,
|
75 |
+
60,
|
76 |
+
61,
|
77 |
+
62,
|
78 |
+
63,
|
79 |
+
90,
|
80 |
+
91,
|
81 |
+
92,
|
82 |
+
93,
|
83 |
+
359,
|
84 |
+
503,
|
85 |
+
522,
|
86 |
+
542,
|
87 |
+
873,
|
88 |
+
893,
|
89 |
+
902,
|
90 |
+
918,
|
91 |
+
922,
|
92 |
+
931,
|
93 |
+
1350,
|
94 |
+
1853,
|
95 |
+
1982,
|
96 |
+
2460,
|
97 |
+
2627,
|
98 |
+
3246,
|
99 |
+
3253,
|
100 |
+
3268,
|
101 |
+
3536,
|
102 |
+
3846,
|
103 |
+
3961,
|
104 |
+
4183,
|
105 |
+
4667,
|
106 |
+
6585,
|
107 |
+
6647,
|
108 |
+
7273,
|
109 |
+
9061,
|
110 |
+
9383,
|
111 |
+
10428,
|
112 |
+
10929,
|
113 |
+
11938,
|
114 |
+
12033,
|
115 |
+
12331,
|
116 |
+
12562,
|
117 |
+
13793,
|
118 |
+
14157,
|
119 |
+
14635,
|
120 |
+
15265,
|
121 |
+
15618,
|
122 |
+
16553,
|
123 |
+
16604,
|
124 |
+
18362,
|
125 |
+
18956,
|
126 |
+
20075,
|
127 |
+
21675,
|
128 |
+
22520,
|
129 |
+
26130,
|
130 |
+
26161,
|
131 |
+
26435,
|
132 |
+
28279,
|
133 |
+
29464,
|
134 |
+
31650,
|
135 |
+
32302,
|
136 |
+
32470,
|
137 |
+
36865,
|
138 |
+
42863,
|
139 |
+
47425,
|
140 |
+
49870,
|
141 |
+
50254,
|
142 |
+
50258,
|
143 |
+
50358,
|
144 |
+
50359,
|
145 |
+
50360,
|
146 |
+
50361,
|
147 |
+
50362
|
148 |
+
],
|
149 |
+
"torch_dtype": "float32",
|
150 |
+
"transformers_version": "4.34.0.dev0",
|
151 |
+
"use_cache": true,
|
152 |
+
"use_weighted_layer_sum": false,
|
153 |
+
"vocab_size": 51865
|
154 |
+
}
|
distil_whisper/__init__.py
ADDED
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# coding=utf-8
|
2 |
+
# Copyright 2023 The HuggingFace Inc. team.
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
|
16 |
+
__version__ = "0.0.1"
|
17 |
+
|
18 |
+
from .modeling_flax_whisper import FlaxWhisperForConditionalGeneration
|
19 |
+
from .partitioner import PjitPartitioner
|
20 |
+
from .pipeline import FlaxWhisperPipeline
|
21 |
+
from .train_state import InferenceState
|
distil_whisper/__pycache__/__init__.cpython-310.pyc
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:af9ae4a8a1ff7fb46887a99add169dda2308b6f168ae7176a9a9b934b7d6ecc5
|
3 |
+
size 414
|
distil_whisper/__pycache__/layers.cpython-310.pyc
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:0bd814358d750f7e74a787df4c376919c74beae05b93e900549940669fde1da4
|
3 |
+
size 41875
|
distil_whisper/__pycache__/modeling_flax_whisper.cpython-310.pyc
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:6442b8035c452676ffda775161984de1aa5a6a279403bdd258e100567f796d49
|
3 |
+
size 54030
|
distil_whisper/__pycache__/partitioner.cpython-310.pyc
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:7e1027212ff84b04b61fe4a3a5fa194e076374a58ab2fb96a38c8328319b21a3
|
3 |
+
size 33252
|
distil_whisper/__pycache__/pipeline.cpython-310.pyc
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:b034728c1456114b4f1f484421a9d960df0533d30bbd315e7951aa32401c87bf
|
3 |
+
size 16762
|
distil_whisper/__pycache__/train_state.cpython-310.pyc
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:178817b09c84d935eadab40cace6d0fb677871eac5f0dd403e907e555bf8a12e
|
3 |
+
size 4106
|
distil_whisper/layers.py
ADDED
@@ -0,0 +1,1338 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2022 The T5X Authors.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
"""Dense attention classes and mask/weighting functions."""
|
16 |
+
|
17 |
+
# pylint: disable=attribute-defined-outside-init,g-bare-generic
|
18 |
+
|
19 |
+
import dataclasses
|
20 |
+
import functools
|
21 |
+
import operator
|
22 |
+
from typing import Any, Callable, Iterable, List, Optional, Sequence, Tuple, Union
|
23 |
+
|
24 |
+
import jax
|
25 |
+
import jax.numpy as jnp
|
26 |
+
import numpy as np
|
27 |
+
from flax import linen as nn
|
28 |
+
from flax.linen import partitioning as nn_partitioning
|
29 |
+
from flax.linen.dtypes import promote_dtype
|
30 |
+
from jax import lax, random
|
31 |
+
|
32 |
+
|
33 |
+
# from flax.linen.partitioning import param_with_axes, with_sharding_constraint
|
34 |
+
param_with_axes = nn_partitioning.param_with_axes
|
35 |
+
with_sharding_constraint = nn_partitioning.with_sharding_constraint
|
36 |
+
|
37 |
+
|
38 |
+
# Type annotations
|
39 |
+
Array = jnp.ndarray
|
40 |
+
DType = jnp.dtype
|
41 |
+
PRNGKey = jnp.ndarray
|
42 |
+
Shape = Iterable[int]
|
43 |
+
Activation = Callable[..., Array]
|
44 |
+
PrecisionLike = Union[None, str, lax.Precision, Tuple[str, str], Tuple[lax.Precision, lax.Precision]]
|
45 |
+
DotGeneralT = Callable[..., Array]
|
46 |
+
ConvGeneralDilatedT = Callable[..., Array]
|
47 |
+
PaddingLike = Union[str, int, Sequence[Union[int, Tuple[int, int]]]]
|
48 |
+
LaxPadding = Union[str, Sequence[Tuple[int, int]]]
|
49 |
+
|
50 |
+
# Parameter initializers.
|
51 |
+
Initializer = Callable[[PRNGKey, Shape, DType], Array]
|
52 |
+
InitializerAxis = Union[int, Tuple[int, ...]]
|
53 |
+
NdInitializer = Callable[[PRNGKey, Shape, DType, InitializerAxis, InitializerAxis], Array]
|
54 |
+
|
55 |
+
default_embed_init = nn.initializers.variance_scaling(1.0, "fan_in", "normal", out_axis=0)
|
56 |
+
|
57 |
+
|
58 |
+
# ------------------------------------------------------------------------------
|
59 |
+
# Temporary inlined JAX N-d initializer code
|
60 |
+
# TODO(levskaya): remove once new JAX release is out.
|
61 |
+
# ------------------------------------------------------------------------------
|
62 |
+
def _compute_fans(shape: jax.core.NamedShape, in_axis=-2, out_axis=-1):
|
63 |
+
"""Inlined JAX `nn.initializer._compute_fans`."""
|
64 |
+
if isinstance(in_axis, int):
|
65 |
+
in_size = shape[in_axis]
|
66 |
+
else:
|
67 |
+
in_size = int(np.prod([shape[i] for i in in_axis]))
|
68 |
+
if isinstance(out_axis, int):
|
69 |
+
out_size = shape[out_axis]
|
70 |
+
else:
|
71 |
+
out_size = int(np.prod([shape[i] for i in out_axis]))
|
72 |
+
receptive_field_size = shape.total / in_size / out_size
|
73 |
+
fan_in = in_size * receptive_field_size
|
74 |
+
fan_out = out_size * receptive_field_size
|
75 |
+
return fan_in, fan_out
|
76 |
+
|
77 |
+
|
78 |
+
def variance_scaling(scale, mode, distribution, in_axis=-2, out_axis=-1, dtype=jnp.float_):
|
79 |
+
"""Inlined JAX `nn.initializer.variance_scaling`."""
|
80 |
+
|
81 |
+
def init(key, shape, dtype=dtype):
|
82 |
+
return jnp.zeros(shape, dtype=dtype)
|
83 |
+
dtype = jax.dtypes.canonicalize_dtype(dtype)
|
84 |
+
shape = jax.core.as_named_shape(shape)
|
85 |
+
fan_in, fan_out = _compute_fans(shape, in_axis, out_axis)
|
86 |
+
if mode == "fan_in":
|
87 |
+
denominator = fan_in
|
88 |
+
elif mode == "fan_out":
|
89 |
+
denominator = fan_out
|
90 |
+
elif mode == "fan_avg":
|
91 |
+
denominator = (fan_in + fan_out) / 2
|
92 |
+
else:
|
93 |
+
raise ValueError("invalid mode for variance scaling initializer: {}".format(mode))
|
94 |
+
variance = jnp.array(scale / denominator, dtype=dtype)
|
95 |
+
|
96 |
+
if distribution == "truncated_normal":
|
97 |
+
# constant is stddev of standard normal truncated to (-2, 2)
|
98 |
+
stddev = jnp.sqrt(variance) / jnp.array(0.87962566103423978, dtype)
|
99 |
+
return random.truncated_normal(key, -2, 2, shape, dtype) * stddev
|
100 |
+
elif distribution == "normal":
|
101 |
+
return random.normal(key, shape, dtype) * jnp.sqrt(variance)
|
102 |
+
elif distribution == "uniform":
|
103 |
+
return random.uniform(key, shape, dtype, -1) * jnp.sqrt(3 * variance)
|
104 |
+
else:
|
105 |
+
raise ValueError("invalid distribution for variance scaling initializer: {}".format(distribution))
|
106 |
+
|
107 |
+
return init
|
108 |
+
|
109 |
+
|
110 |
+
# ------------------------------------------------------------------------------
|
111 |
+
|
112 |
+
|
113 |
+
def nd_dense_init(scale, mode, distribution):
|
114 |
+
"""Initializer with in_axis, out_axis set at call time."""
|
115 |
+
|
116 |
+
def init_fn(key, shape, dtype, in_axis, out_axis):
|
117 |
+
fn = variance_scaling(scale, mode, distribution, in_axis, out_axis)
|
118 |
+
return fn(key, shape, dtype)
|
119 |
+
|
120 |
+
return init_fn
|
121 |
+
|
122 |
+
|
123 |
+
def dot_product_attention(
|
124 |
+
query: Array,
|
125 |
+
key: Array,
|
126 |
+
value: Array,
|
127 |
+
bias: Optional[Array] = None,
|
128 |
+
dropout_rng: Optional[PRNGKey] = None,
|
129 |
+
dropout_rate: float = 0.0,
|
130 |
+
deterministic: bool = False,
|
131 |
+
dtype: DType = jnp.float32,
|
132 |
+
float32_logits: bool = False,
|
133 |
+
):
|
134 |
+
"""Computes dot-product attention given query, key, and value.
|
135 |
+
|
136 |
+
This is the core function for applying attention based on
|
137 |
+
https://arxiv.org/abs/1706.03762. It calculates the attention weights given
|
138 |
+
query and key and combines the values using the attention weights.
|
139 |
+
|
140 |
+
Args:
|
141 |
+
query: queries for calculating attention with shape of `[batch, q_length,
|
142 |
+
num_heads, qk_depth_per_head]`.
|
143 |
+
key: keys for calculating attention with shape of `[batch, kv_length,
|
144 |
+
num_heads, qk_depth_per_head]`.
|
145 |
+
value: values to be used in attention with shape of `[batch, kv_length,
|
146 |
+
num_heads, v_depth_per_head]`.
|
147 |
+
bias: bias for the attention weights. This should be broadcastable to the
|
148 |
+
shape `[batch, num_heads, q_length, kv_length]` This can be used for
|
149 |
+
incorporating causal masks, padding masks, proximity bias, etc.
|
150 |
+
dropout_rng: JAX PRNGKey: to be used for dropout
|
151 |
+
dropout_rate: dropout rate
|
152 |
+
deterministic: bool, deterministic or not (to apply dropout)
|
153 |
+
dtype: the dtype of the computation (default: float32)
|
154 |
+
float32_logits: bool, if True then compute logits in float32 to avoid
|
155 |
+
numerical issues with bfloat16.
|
156 |
+
|
157 |
+
Returns:
|
158 |
+
Output of shape `[batch, length, num_heads, v_depth_per_head]`.
|
159 |
+
"""
|
160 |
+
assert key.ndim == query.ndim == value.ndim, "q, k, v must have same rank."
|
161 |
+
assert query.shape[:-3] == key.shape[:-3] == value.shape[:-3], "q, k, v batch dims must match."
|
162 |
+
assert query.shape[-2] == key.shape[-2] == value.shape[-2], "q, k, v num_heads must match."
|
163 |
+
assert key.shape[-3] == value.shape[-3], "k, v lengths must match."
|
164 |
+
assert query.shape[-1] == key.shape[-1], "q, k depths must match."
|
165 |
+
|
166 |
+
# Casting logits and softmax computation for float32 for model stability.
|
167 |
+
if float32_logits:
|
168 |
+
query = query.astype(jnp.float32)
|
169 |
+
key = key.astype(jnp.float32)
|
170 |
+
|
171 |
+
# `attn_weights`: [batch, num_heads, q_length, kv_length]
|
172 |
+
attn_weights = jnp.einsum("bqhd,bkhd->bhqk", query, key)
|
173 |
+
|
174 |
+
# Apply attention bias: masking, dropout, proximity bias, etc.
|
175 |
+
if bias is not None:
|
176 |
+
attn_weights = attn_weights + bias.astype(attn_weights.dtype)
|
177 |
+
|
178 |
+
# Normalize the attention weights across `kv_length` dimension.
|
179 |
+
attn_weights = jax.nn.softmax(attn_weights).astype(dtype)
|
180 |
+
|
181 |
+
# Apply attention dropout.
|
182 |
+
if not deterministic and dropout_rate > 0.0:
|
183 |
+
keep_prob = 1.0 - dropout_rate
|
184 |
+
# T5 broadcasts along the "length" dim, but unclear which one that
|
185 |
+
# corresponds to in positional dimensions here, assuming query dim.
|
186 |
+
dropout_shape = list(attn_weights.shape)
|
187 |
+
dropout_shape[-2] = 1
|
188 |
+
keep = random.bernoulli(dropout_rng, keep_prob, dropout_shape)
|
189 |
+
keep = jnp.broadcast_to(keep, attn_weights.shape)
|
190 |
+
multiplier = keep.astype(attn_weights.dtype) / jnp.asarray(keep_prob, dtype=dtype)
|
191 |
+
attn_weights = attn_weights * multiplier
|
192 |
+
|
193 |
+
# Take the linear combination of `value`.
|
194 |
+
return jnp.einsum("bhqk,bkhd->bqhd", attn_weights, value)
|
195 |
+
|
196 |
+
|
197 |
+
dynamic_vector_slice_in_dim = jax.vmap(lax.dynamic_slice_in_dim, in_axes=(None, 0, None, None))
|
198 |
+
|
199 |
+
|
200 |
+
class MultiHeadDotProductAttention(nn.Module):
|
201 |
+
"""Multi-head dot-product attention.
|
202 |
+
|
203 |
+
Attributes:
|
204 |
+
num_heads: number of attention heads. Features (i.e. inputs_q.shape[-1])
|
205 |
+
should be divisible by the number of heads.
|
206 |
+
head_dim: dimension of each head.
|
207 |
+
dtype: the dtype of the computation.
|
208 |
+
dropout_rate: dropout rate
|
209 |
+
kernel_init: initializer for the kernel of the Dense layers.
|
210 |
+
float32_logits: bool, if True then compute logits in float32 to avoid
|
211 |
+
numerical issues with bfloat16.
|
212 |
+
"""
|
213 |
+
|
214 |
+
num_heads: int
|
215 |
+
head_dim: int
|
216 |
+
dtype: DType = jnp.float32
|
217 |
+
dropout_rate: float = 0.0
|
218 |
+
kernel_init: NdInitializer = nd_dense_init(1.0, "fan_in", "normal")
|
219 |
+
float32_logits: bool = False # computes logits in float32 for stability.
|
220 |
+
|
221 |
+
@nn.compact
|
222 |
+
def __call__(
|
223 |
+
self,
|
224 |
+
inputs_q: Array,
|
225 |
+
inputs_kv: Array,
|
226 |
+
mask: Optional[Array] = None,
|
227 |
+
bias: Optional[Array] = None,
|
228 |
+
*,
|
229 |
+
decode: bool = False,
|
230 |
+
deterministic: bool = False,
|
231 |
+
) -> Array:
|
232 |
+
"""Applies multi-head dot product attention on the input data.
|
233 |
+
|
234 |
+
Projects the inputs into multi-headed query, key, and value vectors,
|
235 |
+
applies dot-product attention and project the results to an output vector.
|
236 |
+
|
237 |
+
There are two modes: decoding and non-decoding (e.g., training). The mode is
|
238 |
+
determined by `decode` argument. For decoding, this method is called twice,
|
239 |
+
first to initialize the cache and then for an actual decoding process. The
|
240 |
+
two calls are differentiated by the presence of 'cached_key' in the variable
|
241 |
+
dict. In the cache initialization stage, the cache variables are initialized
|
242 |
+
as zeros and will be filled in the subsequent decoding process.
|
243 |
+
|
244 |
+
In the cache initialization call, `inputs_q` has a shape [batch, length,
|
245 |
+
q_features] and `inputs_kv`: [batch, length, kv_features]. During the
|
246 |
+
incremental decoding stage, query, key and value all have the shape [batch,
|
247 |
+
1, qkv_features] corresponding to a single step.
|
248 |
+
|
249 |
+
Args:
|
250 |
+
inputs_q: input queries of shape `[batch, q_length, q_features]`.
|
251 |
+
inputs_kv: key/values of shape `[batch, kv_length, kv_features]`.
|
252 |
+
mask: attention mask of shape `[batch, num_heads, q_length, kv_length]`.
|
253 |
+
bias: attention bias of shape `[batch, num_heads, q_length, kv_length]`.
|
254 |
+
decode: Whether to prepare and use an autoregressive cache.
|
255 |
+
deterministic: Disables dropout if set to True.
|
256 |
+
|
257 |
+
Returns:
|
258 |
+
output of shape `[batch, length, q_features]`.
|
259 |
+
"""
|
260 |
+
projection = functools.partial(
|
261 |
+
DenseGeneral,
|
262 |
+
axis=-1,
|
263 |
+
features=(self.num_heads, self.head_dim),
|
264 |
+
kernel_axes=("embed", "heads", "kv"),
|
265 |
+
dtype=self.dtype,
|
266 |
+
)
|
267 |
+
|
268 |
+
# NOTE: T5 does not explicitly rescale the attention logits by
|
269 |
+
# 1/sqrt(depth_kq)! This is folded into the initializers of the
|
270 |
+
# linear transformations, which is equivalent under Adafactor.
|
271 |
+
depth_scaling = jnp.sqrt(self.head_dim).astype(self.dtype)
|
272 |
+
|
273 |
+
def query_init(*args):
|
274 |
+
return self.kernel_init(*args) / depth_scaling
|
275 |
+
|
276 |
+
# Project inputs_q to multi-headed q/k/v
|
277 |
+
# dimensions are then [batch, length, num_heads, head_dim]
|
278 |
+
query = projection(kernel_init=query_init, name="query")(inputs_q)
|
279 |
+
key = projection(kernel_init=self.kernel_init, name="key")(inputs_kv)
|
280 |
+
value = projection(kernel_init=self.kernel_init, name="value")(inputs_kv)
|
281 |
+
|
282 |
+
query = with_sharding_constraint(query, ("batch", "length", "heads", "kv"))
|
283 |
+
key = with_sharding_constraint(key, ("batch", "length", "heads", "kv"))
|
284 |
+
value = with_sharding_constraint(value, ("batch", "length", "heads", "kv"))
|
285 |
+
|
286 |
+
if decode:
|
287 |
+
# Detect if we're initializing by absence of existing cache data.
|
288 |
+
is_initialized = self.has_variable("cache", "cached_key")
|
289 |
+
|
290 |
+
# The key and value have dimension [batch, length, num_heads, head_dim],
|
291 |
+
# but we cache them as [batch, num_heads, head_dim, length] as a TPU
|
292 |
+
# fusion optimization. This also enables the "scatter via one-hot
|
293 |
+
# broadcast" trick, which means we do a one-hot broadcast instead of a
|
294 |
+
# scatter/gather operations, resulting in a 3-4x speedup in practice.
|
295 |
+
def swap_dims(x):
|
296 |
+
return x[:-3] + tuple(x[i] for i in [-2, -1, -3])
|
297 |
+
|
298 |
+
cached_key = self.variable("cache", "cached_key", jnp.zeros, swap_dims(key.shape), key.dtype)
|
299 |
+
cached_value = self.variable("cache", "cached_value", jnp.zeros, swap_dims(value.shape), value.dtype)
|
300 |
+
cache_index = self.variable("cache", "cache_index", lambda: jnp.array(0, dtype=jnp.int32))
|
301 |
+
if is_initialized:
|
302 |
+
batch, num_heads, head_dim, length = cached_key.value.shape
|
303 |
+
# During fast autoregressive decoding, we feed one position at a time,
|
304 |
+
# and cache the keys and values step by step.
|
305 |
+
# Sanity shape check of cached key against input query.
|
306 |
+
expected_shape = (batch, 1, num_heads, head_dim)
|
307 |
+
if expected_shape != query.shape:
|
308 |
+
raise ValueError(
|
309 |
+
"Autoregressive cache shape error, "
|
310 |
+
"expected query shape %s instead got %s." % (expected_shape, query.shape)
|
311 |
+
)
|
312 |
+
|
313 |
+
# Create a OHE of the current index. NOTE: the index is increased below.
|
314 |
+
cur_index = cache_index.value
|
315 |
+
one_hot_indices = jax.nn.one_hot(cur_index, length, dtype=key.dtype)
|
316 |
+
# In order to update the key, value caches with the current key and
|
317 |
+
# value, we move the length axis to the back, similar to what we did for
|
318 |
+
# the cached ones above.
|
319 |
+
# Note these are currently the key and value of a single position, since
|
320 |
+
# we feed one position at a time.
|
321 |
+
one_token_key = jnp.moveaxis(key, -3, -1)
|
322 |
+
one_token_value = jnp.moveaxis(value, -3, -1)
|
323 |
+
# Update key, value caches with our new 1d spatial slices.
|
324 |
+
# We implement an efficient scatter into the cache via one-hot
|
325 |
+
# broadcast and addition.
|
326 |
+
key = cached_key.value + one_token_key * one_hot_indices
|
327 |
+
value = cached_value.value + one_token_value * one_hot_indices
|
328 |
+
cached_key.value = key
|
329 |
+
cached_value.value = value
|
330 |
+
cache_index.value = cache_index.value + 1
|
331 |
+
# Move the keys and values back to their original shapes.
|
332 |
+
key = jnp.moveaxis(key, -1, -3)
|
333 |
+
value = jnp.moveaxis(value, -1, -3)
|
334 |
+
|
335 |
+
# Causal mask for cached decoder self-attention: our single query
|
336 |
+
# position should only attend to those key positions that have already
|
337 |
+
# been generated and cached, not the remaining zero elements.
|
338 |
+
mask = combine_masks(
|
339 |
+
mask,
|
340 |
+
jnp.broadcast_to(
|
341 |
+
jnp.arange(length) <= cur_index,
|
342 |
+
# (1, 1, length) represent (head dim, query length, key length)
|
343 |
+
# query length is 1 because during decoding we deal with one
|
344 |
+
# index.
|
345 |
+
# The same mask is applied to all batch elements and heads.
|
346 |
+
(batch, 1, 1, length),
|
347 |
+
),
|
348 |
+
)
|
349 |
+
|
350 |
+
# Grab the correct relative attention bias during decoding. This is
|
351 |
+
# only required during single step decoding.
|
352 |
+
if bias is not None:
|
353 |
+
# The bias is a full attention matrix, but during decoding we only
|
354 |
+
# have to take a slice of it.
|
355 |
+
# This is equivalent to bias[..., cur_index:cur_index+1, :].
|
356 |
+
bias = dynamic_vector_slice_in_dim(jnp.squeeze(bias, axis=0), jnp.reshape(cur_index, (-1)), 1, -2)
|
357 |
+
|
358 |
+
# Convert the boolean attention mask to an attention bias.
|
359 |
+
if mask is not None:
|
360 |
+
# attention mask in the form of attention bias
|
361 |
+
attention_bias = lax.select(
|
362 |
+
mask > 0,
|
363 |
+
jnp.full(mask.shape, 0.0).astype(self.dtype),
|
364 |
+
jnp.full(mask.shape, -1e10).astype(self.dtype),
|
365 |
+
)
|
366 |
+
else:
|
367 |
+
attention_bias = None
|
368 |
+
|
369 |
+
# Add provided bias term (e.g. relative position embedding).
|
370 |
+
if bias is not None:
|
371 |
+
attention_bias = combine_biases(attention_bias, bias)
|
372 |
+
|
373 |
+
dropout_rng = None
|
374 |
+
if not deterministic and self.dropout_rate > 0.0:
|
375 |
+
dropout_rng = self.make_rng("dropout")
|
376 |
+
|
377 |
+
# Apply attention.
|
378 |
+
x = dot_product_attention(
|
379 |
+
query,
|
380 |
+
key,
|
381 |
+
value,
|
382 |
+
bias=attention_bias,
|
383 |
+
dropout_rng=dropout_rng,
|
384 |
+
dropout_rate=self.dropout_rate,
|
385 |
+
deterministic=deterministic,
|
386 |
+
dtype=self.dtype,
|
387 |
+
float32_logits=self.float32_logits,
|
388 |
+
)
|
389 |
+
|
390 |
+
# Back to the original inputs dimensions.
|
391 |
+
out = DenseGeneral(
|
392 |
+
features=inputs_q.shape[-1], # output dim is set to the input dim.
|
393 |
+
axis=(-2, -1),
|
394 |
+
kernel_init=self.kernel_init,
|
395 |
+
kernel_axes=("heads", "kv", "embed"),
|
396 |
+
dtype=self.dtype,
|
397 |
+
name="out",
|
398 |
+
)(x)
|
399 |
+
return out
|
400 |
+
|
401 |
+
|
402 |
+
def _normalize_axes(axes: Iterable[int], ndim: int) -> Tuple[int]:
|
403 |
+
# A tuple by convention. len(axes_tuple) then also gives the rank efficiently.
|
404 |
+
return tuple([ax if ax >= 0 else ndim + ax for ax in axes])
|
405 |
+
|
406 |
+
|
407 |
+
def _canonicalize_tuple(x):
|
408 |
+
if isinstance(x, Iterable):
|
409 |
+
return tuple(x)
|
410 |
+
else:
|
411 |
+
return (x,)
|
412 |
+
|
413 |
+
|
414 |
+
# ------------------------------------------------------------------------------
|
415 |
+
# DenseGeneral for attention layers.
|
416 |
+
# ------------------------------------------------------------------------------
|
417 |
+
class DenseGeneral(nn.Module):
|
418 |
+
"""A linear transformation (without bias) with flexible axes.
|
419 |
+
|
420 |
+
Attributes:
|
421 |
+
features: tuple with numbers of output features.
|
422 |
+
axis: tuple with axes to apply the transformation on.
|
423 |
+
dtype: the dtype of the computation (default: float32).
|
424 |
+
kernel_init: initializer function for the weight matrix.
|
425 |
+
"""
|
426 |
+
|
427 |
+
features: Union[Iterable[int], int]
|
428 |
+
axis: Union[Iterable[int], int] = -1
|
429 |
+
dtype: DType = jnp.float32
|
430 |
+
params_dtype: DType = jnp.float32
|
431 |
+
kernel_init: NdInitializer = nd_dense_init(1.0, "fan_in", "normal")
|
432 |
+
kernel_axes: Tuple[str, ...] = ()
|
433 |
+
use_bias: bool = True
|
434 |
+
bias_init: Any = nn.initializers.zeros
|
435 |
+
|
436 |
+
@nn.compact
|
437 |
+
def __call__(self, inputs: Array) -> Array:
|
438 |
+
"""Applies a linear transformation to the inputs along multiple dimensions.
|
439 |
+
|
440 |
+
Args:
|
441 |
+
inputs: The nd-array to be transformed.
|
442 |
+
|
443 |
+
Returns:
|
444 |
+
The transformed input.
|
445 |
+
"""
|
446 |
+
features = _canonicalize_tuple(self.features)
|
447 |
+
axis = _canonicalize_tuple(self.axis)
|
448 |
+
|
449 |
+
inputs = jnp.asarray(inputs, self.dtype)
|
450 |
+
axis = _normalize_axes(axis, inputs.ndim)
|
451 |
+
|
452 |
+
kernel_shape = tuple([inputs.shape[ax] for ax in axis]) + features
|
453 |
+
kernel_in_axis = np.arange(len(axis))
|
454 |
+
kernel_out_axis = np.arange(len(axis), len(axis) + len(features))
|
455 |
+
kernel = param_with_axes(
|
456 |
+
"kernel",
|
457 |
+
self.kernel_init,
|
458 |
+
kernel_shape,
|
459 |
+
self.params_dtype,
|
460 |
+
kernel_in_axis,
|
461 |
+
kernel_out_axis,
|
462 |
+
axes=self.kernel_axes,
|
463 |
+
)
|
464 |
+
if self.use_bias:
|
465 |
+
bias = param_with_axes(
|
466 |
+
"bias",
|
467 |
+
self.bias_init,
|
468 |
+
features,
|
469 |
+
self.params_dtype,
|
470 |
+
axes=(self.kernel_axes[-1],),
|
471 |
+
)
|
472 |
+
kernel = jnp.asarray(kernel, self.dtype)
|
473 |
+
|
474 |
+
contract_ind = tuple(range(0, len(axis)))
|
475 |
+
y = lax.dot_general(inputs, kernel, ((axis, contract_ind), ((), ())))
|
476 |
+
if self.use_bias:
|
477 |
+
bias = jnp.asarray(bias, self.dtype)
|
478 |
+
# y += jnp.reshape(bias, (1,) * (y.ndim - 1) + (-1,))
|
479 |
+
y += jnp.reshape(bias, (1,) * (len(features) - y.ndim) + bias.shape[:])
|
480 |
+
return y
|
481 |
+
|
482 |
+
|
483 |
+
def _convert_to_activation_function(fn_or_string: Union[str, Callable]) -> Callable:
|
484 |
+
"""Convert a string to an activation function."""
|
485 |
+
if fn_or_string == "linear":
|
486 |
+
return lambda x: x
|
487 |
+
elif isinstance(fn_or_string, str):
|
488 |
+
return getattr(nn, fn_or_string)
|
489 |
+
elif callable(fn_or_string):
|
490 |
+
return fn_or_string
|
491 |
+
else:
|
492 |
+
raise ValueError("don't know how to convert %s to an activation function" % (fn_or_string,))
|
493 |
+
|
494 |
+
|
495 |
+
class MlpBlock(nn.Module):
|
496 |
+
"""Transformer MLP / feed-forward block.
|
497 |
+
|
498 |
+
Attributes:
|
499 |
+
intermediate_dim: Shared dimension of hidden layers.
|
500 |
+
activations: Type of activations for each layer. Each element is either
|
501 |
+
'linear', a string function name in flax.linen, or a function.
|
502 |
+
kernel_init: Kernel function, passed to the dense layers.
|
503 |
+
deterministic: Whether the dropout layers should be deterministic.
|
504 |
+
intermediate_dropout_rate: Dropout rate used after the intermediate layers.
|
505 |
+
dtype: Type for the dense layer.
|
506 |
+
"""
|
507 |
+
|
508 |
+
intermediate_dim: int = 2048
|
509 |
+
activations: Sequence[Union[str, Callable]] = ("relu",)
|
510 |
+
kernel_init: NdInitializer = nd_dense_init(1.0, "fan_in", "truncated_normal")
|
511 |
+
intermediate_dropout_rate: float = 0.1
|
512 |
+
dtype: Any = jnp.float32
|
513 |
+
|
514 |
+
@nn.compact
|
515 |
+
def __call__(self, inputs, decode: bool = False, deterministic: bool = False):
|
516 |
+
"""Applies Transformer MlpBlock module."""
|
517 |
+
# Iterate over specified MLP input activation functions.
|
518 |
+
# e.g. ('relu',) or ('gelu', 'linear') for gated-gelu.
|
519 |
+
activations = []
|
520 |
+
for idx, act_fn in enumerate(self.activations):
|
521 |
+
dense_name = "wi" if len(self.activations) == 1 else f"wi_{idx}"
|
522 |
+
x = DenseGeneral(
|
523 |
+
self.intermediate_dim,
|
524 |
+
dtype=self.dtype,
|
525 |
+
kernel_init=self.kernel_init,
|
526 |
+
kernel_axes=("embed", "mlp"),
|
527 |
+
name=dense_name,
|
528 |
+
)(inputs)
|
529 |
+
x = _convert_to_activation_function(act_fn)(x)
|
530 |
+
activations.append(x)
|
531 |
+
|
532 |
+
# Take elementwise product of above intermediate activations.
|
533 |
+
x = functools.reduce(operator.mul, activations)
|
534 |
+
# Apply dropout and final dense output projection.
|
535 |
+
x = nn.Dropout(rate=self.intermediate_dropout_rate, broadcast_dims=(-2,))(
|
536 |
+
x, deterministic=deterministic
|
537 |
+
) # Broadcast along length.
|
538 |
+
x = with_sharding_constraint(x, ("batch", "length", "mlp"))
|
539 |
+
output = DenseGeneral(
|
540 |
+
inputs.shape[-1],
|
541 |
+
dtype=self.dtype,
|
542 |
+
kernel_init=self.kernel_init,
|
543 |
+
kernel_axes=("mlp", "embed"),
|
544 |
+
name="wo",
|
545 |
+
)(x)
|
546 |
+
return output
|
547 |
+
|
548 |
+
|
549 |
+
class Embed(nn.Module):
|
550 |
+
"""A parameterized function from integers [0, n) to d-dimensional vectors.
|
551 |
+
|
552 |
+
Attributes:
|
553 |
+
num_embeddings: number of embeddings.
|
554 |
+
features: number of feature dimensions for each embedding.
|
555 |
+
dtype: the dtype of the embedding vectors (default: float32).
|
556 |
+
embedding_init: embedding initializer.
|
557 |
+
one_hot: performs the gather with a one-hot contraction rather than a true
|
558 |
+
gather. This is currently needed for SPMD partitioning.
|
559 |
+
"""
|
560 |
+
|
561 |
+
num_embeddings: int
|
562 |
+
features: int
|
563 |
+
cast_input_dtype: Optional[DType] = None
|
564 |
+
dtype: DType = jnp.float32
|
565 |
+
params_dtype: DType = jnp.float32
|
566 |
+
attend_dtype: Optional[DType] = None
|
567 |
+
embedding_init: Initializer = default_embed_init
|
568 |
+
one_hot: bool = True
|
569 |
+
embedding: Array = dataclasses.field(init=False)
|
570 |
+
|
571 |
+
def setup(self):
|
572 |
+
self.embedding = param_with_axes(
|
573 |
+
"embedding",
|
574 |
+
self.embedding_init,
|
575 |
+
(self.num_embeddings, self.features),
|
576 |
+
self.params_dtype,
|
577 |
+
axes=("vocab", "embed"),
|
578 |
+
)
|
579 |
+
|
580 |
+
def __call__(self, inputs: Array) -> Array:
|
581 |
+
"""Embeds the inputs along the last dimension.
|
582 |
+
|
583 |
+
Args:
|
584 |
+
inputs: input data, all dimensions are considered batch dimensions.
|
585 |
+
|
586 |
+
Returns:
|
587 |
+
Output which is embedded input data. The output shape follows the input,
|
588 |
+
with an additional `features` dimension appended.
|
589 |
+
"""
|
590 |
+
if self.cast_input_dtype:
|
591 |
+
inputs = inputs.astype(self.cast_input_dtype)
|
592 |
+
if not jnp.issubdtype(inputs.dtype, jnp.integer):
|
593 |
+
raise ValueError("Input type must be an integer or unsigned integer.")
|
594 |
+
if self.one_hot:
|
595 |
+
iota = lax.iota(jnp.int32, self.num_embeddings)
|
596 |
+
one_hot = jnp.array(inputs[..., jnp.newaxis] == iota, dtype=self.dtype)
|
597 |
+
output = jnp.dot(one_hot, jnp.asarray(self.embedding, self.dtype))
|
598 |
+
else:
|
599 |
+
output = jnp.asarray(self.embedding, self.dtype)[inputs]
|
600 |
+
output = with_sharding_constraint(output, ("batch", "length", "embed"))
|
601 |
+
return output
|
602 |
+
|
603 |
+
def attend(self, query: Array) -> Array:
|
604 |
+
"""Attend over the embedding using a query array.
|
605 |
+
|
606 |
+
Args:
|
607 |
+
query: array with last dimension equal the feature depth `features` of the
|
608 |
+
embedding.
|
609 |
+
|
610 |
+
Returns:
|
611 |
+
An array with final dim `num_embeddings` corresponding to the batched
|
612 |
+
inner-product of the array of query vectors against each embedding.
|
613 |
+
Commonly used for weight-sharing between embeddings and logit transform
|
614 |
+
in NLP models.
|
615 |
+
"""
|
616 |
+
dtype = self.attend_dtype if self.attend_dtype is not None else self.dtype
|
617 |
+
return jnp.dot(query, jnp.asarray(self.embedding, dtype).T)
|
618 |
+
|
619 |
+
|
620 |
+
class RelativePositionBiases(nn.Module):
|
621 |
+
"""Adds T5-style relative positional embeddings to the attention logits.
|
622 |
+
|
623 |
+
Attributes:
|
624 |
+
num_buckets: Number of buckets to bucket distances between key and query
|
625 |
+
positions into.
|
626 |
+
max_distance: Maximum distance before everything is lumped into the last
|
627 |
+
distance bucket.
|
628 |
+
num_heads: Number of heads in the attention layer. Each head will get a
|
629 |
+
different relative position weighting.
|
630 |
+
dtype: Type of arrays through this module.
|
631 |
+
embedding_init: initializer for relative embedding table.
|
632 |
+
"""
|
633 |
+
|
634 |
+
num_buckets: int
|
635 |
+
max_distance: int
|
636 |
+
num_heads: int
|
637 |
+
dtype: Any
|
638 |
+
embedding_init: Callable[..., Array] = nn.linear.default_embed_init
|
639 |
+
|
640 |
+
@staticmethod
|
641 |
+
def _relative_position_bucket(relative_position, bidirectional=True, num_buckets=32, max_distance=128):
|
642 |
+
"""Translate relative position to a bucket number for relative attention.
|
643 |
+
|
644 |
+
The relative position is defined as memory_position - query_position, i.e.
|
645 |
+
the distance in tokens from the attending position to the attended-to
|
646 |
+
position. If bidirectional=False, then positive relative positions are
|
647 |
+
invalid.
|
648 |
+
We use smaller buckets for small absolute relative_position and larger
|
649 |
+
buckets for larger absolute relative_positions. All relative
|
650 |
+
positions >=max_distance map to the same bucket. All relative
|
651 |
+
positions <=-max_distance map to the same bucket. This should allow for
|
652 |
+
more graceful generalization to longer sequences than the model has been
|
653 |
+
trained on.
|
654 |
+
|
655 |
+
Args:
|
656 |
+
relative_position: an int32 array
|
657 |
+
bidirectional: a boolean - whether the attention is bidirectional
|
658 |
+
num_buckets: an integer
|
659 |
+
max_distance: an integer
|
660 |
+
|
661 |
+
Returns:
|
662 |
+
a Tensor with the same shape as relative_position, containing int32
|
663 |
+
values in the range [0, num_buckets)
|
664 |
+
"""
|
665 |
+
ret = 0
|
666 |
+
n = -relative_position
|
667 |
+
if bidirectional:
|
668 |
+
num_buckets //= 2
|
669 |
+
ret += (n < 0).astype(np.int32) * num_buckets
|
670 |
+
n = np.abs(n)
|
671 |
+
else:
|
672 |
+
n = np.maximum(n, 0)
|
673 |
+
# now n is in the range [0, inf)
|
674 |
+
max_exact = num_buckets // 2
|
675 |
+
is_small = n < max_exact
|
676 |
+
val_if_large = max_exact + (
|
677 |
+
np.log(n.astype(np.float32) / max_exact + np.finfo(np.float32).eps)
|
678 |
+
/ np.log(max_distance / max_exact)
|
679 |
+
* (num_buckets - max_exact)
|
680 |
+
).astype(np.int32)
|
681 |
+
val_if_large = np.minimum(val_if_large, num_buckets - 1)
|
682 |
+
ret += np.where(is_small, n, val_if_large)
|
683 |
+
return ret
|
684 |
+
|
685 |
+
@nn.compact
|
686 |
+
def __call__(self, qlen, klen, bidirectional=True):
|
687 |
+
"""Produce relative position embedding attention biases.
|
688 |
+
|
689 |
+
Args:
|
690 |
+
qlen: attention query length.
|
691 |
+
klen: attention key length.
|
692 |
+
bidirectional: whether to allow positive memory-query relative position
|
693 |
+
embeddings.
|
694 |
+
|
695 |
+
Returns:
|
696 |
+
output: `(1, len, q_len, k_len)` attention bias
|
697 |
+
"""
|
698 |
+
# TODO(levskaya): should we be computing this w. numpy as a program
|
699 |
+
# constant?
|
700 |
+
context_position = np.arange(qlen, dtype=jnp.int32)[:, None]
|
701 |
+
memory_position = np.arange(klen, dtype=jnp.int32)[None, :]
|
702 |
+
relative_position = memory_position - context_position # shape (qlen, klen)
|
703 |
+
rp_bucket = self._relative_position_bucket(
|
704 |
+
relative_position,
|
705 |
+
bidirectional=bidirectional,
|
706 |
+
num_buckets=self.num_buckets,
|
707 |
+
max_distance=self.max_distance,
|
708 |
+
)
|
709 |
+
relative_attention_bias = param_with_axes(
|
710 |
+
"rel_embedding",
|
711 |
+
self.embedding_init,
|
712 |
+
(self.num_heads, self.num_buckets),
|
713 |
+
jnp.float32,
|
714 |
+
axes=("heads", "relpos_buckets"),
|
715 |
+
)
|
716 |
+
|
717 |
+
relative_attention_bias = jnp.asarray(relative_attention_bias, self.dtype)
|
718 |
+
# Instead of using a slow gather, we create a leading-dimension one-hot
|
719 |
+
# array from rp_bucket and use it to perform the gather-equivalent via a
|
720 |
+
# contraction, i.e.:
|
721 |
+
# (num_head, num_buckets) x (num_buckets one-hot, qlen, klen).
|
722 |
+
# This is equivalent to relative_attention_bias[:, rp_bucket]
|
723 |
+
bcast_iota = lax.broadcasted_iota(jnp.int32, (self.num_buckets, 1, 1), 0)
|
724 |
+
rp_bucket_one_hot = jnp.array(rp_bucket[jnp.newaxis, ...] == bcast_iota, dtype=self.dtype)
|
725 |
+
# --> shape (qlen, klen, num_heads)
|
726 |
+
values = lax.dot_general(
|
727 |
+
relative_attention_bias,
|
728 |
+
rp_bucket_one_hot,
|
729 |
+
(((1,), (0,)), ((), ())), # rhs, lhs contracting dims
|
730 |
+
) # no batched dims
|
731 |
+
# Add a singleton batch dimension.
|
732 |
+
# --> shape (1, num_heads, qlen, klen)
|
733 |
+
return values[jnp.newaxis, ...]
|
734 |
+
|
735 |
+
|
736 |
+
# ------------------------------------------------------------------------------
|
737 |
+
# T5 Layernorm - no subtraction of mean or bias.
|
738 |
+
# ------------------------------------------------------------------------------
|
739 |
+
# class LayerNorm(nn.Module):
|
740 |
+
# """T5 Layer normalization operating on the last axis of the input data."""
|
741 |
+
# epsilon: float = 1e-6
|
742 |
+
# dtype: Any = jnp.float32
|
743 |
+
# scale_init: Initializer = nn.initializers.ones
|
744 |
+
|
745 |
+
# @nn.compact
|
746 |
+
# def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
|
747 |
+
# """Applies layer normalization on the input."""
|
748 |
+
# x = jnp.asarray(x, jnp.float32)
|
749 |
+
# features = x.shape[-1]
|
750 |
+
# mean2 = jnp.mean(lax.square(x), axis=-1, keepdims=True)
|
751 |
+
# y = jnp.asarray(x * lax.rsqrt(mean2 + self.epsilon), self.dtype)
|
752 |
+
# scale = param_with_axes(
|
753 |
+
# 'scale', self.scale_init, (features,), jnp.float32, axes=('embed',))
|
754 |
+
|
755 |
+
# scale = jnp.asarray(scale, self.dtype)
|
756 |
+
# return y * scale
|
757 |
+
|
758 |
+
|
759 |
+
class LayerNorm(nn.Module):
|
760 |
+
"""Layer normalization (https://arxiv.org/abs/1607.06450).
|
761 |
+
Operates on the last axis of the input data.
|
762 |
+
It normalizes the activations of the layer for each given example in a
|
763 |
+
batch independently, rather than across a batch like Batch Normalization.
|
764 |
+
i.e. applies a transformation that maintains the mean activation within
|
765 |
+
each example close to 0 and the activation standard deviation close to 1.
|
766 |
+
Attributes:
|
767 |
+
epsilon: A small float added to variance to avoid dividing by zero.
|
768 |
+
dtype: the dtype of the computation (default: float32).
|
769 |
+
use_bias: If True, bias (beta) is added.
|
770 |
+
use_scale: If True, multiply by scale (gamma). When the next layer is linear
|
771 |
+
(also e.g. nn.relu), this can be disabled since the scaling will be done
|
772 |
+
by the next layer.
|
773 |
+
bias_init: Initializer for bias, by default, zero.
|
774 |
+
scale_init: Initializer for scale, by default, one.
|
775 |
+
"""
|
776 |
+
|
777 |
+
epsilon: float = 1e-6
|
778 |
+
dtype: Any = jnp.float32
|
779 |
+
params_dtype: DType = jnp.float32
|
780 |
+
use_bias: bool = True
|
781 |
+
use_scale: bool = True
|
782 |
+
bias_init: Callable[[PRNGKey, Shape, Any], Array] = nn.initializers.zeros
|
783 |
+
scale_init: Callable[[PRNGKey, Shape, Any], Array] = nn.initializers.ones
|
784 |
+
|
785 |
+
@nn.compact
|
786 |
+
def __call__(self, x):
|
787 |
+
"""Applies layer normalization on the input.
|
788 |
+
Args:
|
789 |
+
x: the inputs
|
790 |
+
Returns:
|
791 |
+
Normalized inputs (the same shape as inputs).
|
792 |
+
"""
|
793 |
+
x = jnp.asarray(x, jnp.float32)
|
794 |
+
features = x.shape[-1]
|
795 |
+
mean = jnp.mean(x, axis=-1, keepdims=True)
|
796 |
+
mean2 = jnp.mean(lax.square(x), axis=-1, keepdims=True)
|
797 |
+
var = mean2 - lax.square(mean)
|
798 |
+
mul = lax.rsqrt(var + self.epsilon)
|
799 |
+
if self.use_scale:
|
800 |
+
scale = param_with_axes(
|
801 |
+
"scale",
|
802 |
+
self.scale_init,
|
803 |
+
(features,),
|
804 |
+
self.params_dtype,
|
805 |
+
axes=("embed",),
|
806 |
+
)
|
807 |
+
mul = mul * jnp.asarray(scale, self.dtype)
|
808 |
+
y = (x - mean) * mul
|
809 |
+
if self.use_bias:
|
810 |
+
bias = param_with_axes("bias", self.bias_init, (features,), self.params_dtype, axes=("embed",))
|
811 |
+
y = y + jnp.asarray(bias, self.dtype)
|
812 |
+
return jnp.asarray(y, self.dtype)
|
813 |
+
|
814 |
+
|
815 |
+
# ------------------------------------------------------------------------------
|
816 |
+
# Mask-making utility functions.
|
817 |
+
# ------------------------------------------------------------------------------
|
818 |
+
def make_attention_mask(
|
819 |
+
query_input: Array,
|
820 |
+
key_input: Array,
|
821 |
+
pairwise_fn: Callable = jnp.multiply,
|
822 |
+
extra_batch_dims: int = 0,
|
823 |
+
dtype: DType = jnp.float32,
|
824 |
+
) -> Array:
|
825 |
+
"""Mask-making helper for attention weights.
|
826 |
+
|
827 |
+
In case of 1d inputs (i.e., `[batch, len_q]`, `[batch, len_kv]`, the
|
828 |
+
attention weights will be `[batch, heads, len_q, len_kv]` and this
|
829 |
+
function will produce `[batch, 1, len_q, len_kv]`.
|
830 |
+
|
831 |
+
Args:
|
832 |
+
query_input: a batched, flat input of query_length size
|
833 |
+
key_input: a batched, flat input of key_length size
|
834 |
+
pairwise_fn: broadcasting elementwise comparison function
|
835 |
+
extra_batch_dims: number of extra batch dims to add singleton axes for, none
|
836 |
+
by default
|
837 |
+
dtype: mask return dtype
|
838 |
+
|
839 |
+
Returns:
|
840 |
+
A `[batch, 1, len_q, len_kv]` shaped mask for 1d attention.
|
841 |
+
"""
|
842 |
+
# [batch, len_q, len_kv]
|
843 |
+
mask = pairwise_fn(
|
844 |
+
# [batch, len_q] -> [batch, len_q, 1]
|
845 |
+
jnp.expand_dims(query_input, axis=-1),
|
846 |
+
# [batch, len_q] -> [batch, 1, len_kv]
|
847 |
+
jnp.expand_dims(key_input, axis=-2),
|
848 |
+
)
|
849 |
+
|
850 |
+
# [batch, 1, len_q, len_kv]. This creates the head dim.
|
851 |
+
mask = jnp.expand_dims(mask, axis=-3)
|
852 |
+
mask = jnp.expand_dims(mask, axis=tuple(range(extra_batch_dims)))
|
853 |
+
return mask.astype(dtype)
|
854 |
+
|
855 |
+
|
856 |
+
def make_causal_mask(x: Array, extra_batch_dims: int = 0, dtype: DType = jnp.float32) -> Array:
|
857 |
+
"""Make a causal mask for self-attention.
|
858 |
+
|
859 |
+
In case of 1d inputs (i.e., `[batch, len]`, the self-attention weights
|
860 |
+
will be `[batch, heads, len, len]` and this function will produce a
|
861 |
+
causal mask of shape `[batch, 1, len, len]`.
|
862 |
+
|
863 |
+
Note that a causal mask does not depend on the values of x; it only depends on
|
864 |
+
the shape. If x has padding elements, they will not be treated in a special
|
865 |
+
manner.
|
866 |
+
|
867 |
+
Args:
|
868 |
+
x: input array of shape `[batch, len]`
|
869 |
+
extra_batch_dims: number of batch dims to add singleton axes for, none by
|
870 |
+
default
|
871 |
+
dtype: mask return dtype
|
872 |
+
|
873 |
+
Returns:
|
874 |
+
A `[batch, 1, len, len]` shaped causal mask for 1d attention.
|
875 |
+
"""
|
876 |
+
idxs = jnp.broadcast_to(jnp.arange(x.shape[-1], dtype=jnp.int32), x.shape)
|
877 |
+
return make_attention_mask(idxs, idxs, jnp.greater_equal, extra_batch_dims=extra_batch_dims, dtype=dtype)
|
878 |
+
|
879 |
+
|
880 |
+
def combine_masks(*masks: Optional[Array], dtype: DType = jnp.float32):
|
881 |
+
"""Combine attention masks.
|
882 |
+
|
883 |
+
Args:
|
884 |
+
*masks: set of attention mask arguments to combine, some can be None.
|
885 |
+
dtype: final mask dtype
|
886 |
+
|
887 |
+
Returns:
|
888 |
+
Combined mask, reduced by logical and, returns None if no masks given.
|
889 |
+
"""
|
890 |
+
masks = [m for m in masks if m is not None]
|
891 |
+
if not masks:
|
892 |
+
return None
|
893 |
+
assert all(
|
894 |
+
(x.ndim == masks[0].ndim for x in masks)
|
895 |
+
), f"masks must have same rank: {tuple((x.ndim for x in masks))}"
|
896 |
+
mask, *other_masks = masks
|
897 |
+
for other_mask in other_masks:
|
898 |
+
mask = jnp.logical_and(mask, other_mask)
|
899 |
+
return mask.astype(dtype)
|
900 |
+
|
901 |
+
|
902 |
+
def combine_biases(*masks: Optional[Array]):
|
903 |
+
"""Combine attention biases.
|
904 |
+
|
905 |
+
Args:
|
906 |
+
*masks: set of attention bias arguments to combine, some can be None.
|
907 |
+
|
908 |
+
Returns:
|
909 |
+
Combined mask, reduced by summation, returns None if no masks given.
|
910 |
+
"""
|
911 |
+
masks = [m for m in masks if m is not None]
|
912 |
+
if not masks:
|
913 |
+
return None
|
914 |
+
assert all(
|
915 |
+
(x.ndim == masks[0].ndim for x in masks)
|
916 |
+
), f"masks must have same rank: {tuple((x.ndim for x in masks))}"
|
917 |
+
mask, *other_masks = masks
|
918 |
+
for other_mask in other_masks:
|
919 |
+
mask = mask + other_mask
|
920 |
+
return mask
|
921 |
+
|
922 |
+
|
923 |
+
def make_decoder_mask(
|
924 |
+
decoder_target_tokens: Array,
|
925 |
+
dtype: DType,
|
926 |
+
decoder_causal_attention: Optional[Array] = None,
|
927 |
+
decoder_segment_ids: Optional[Array] = None,
|
928 |
+
) -> Array:
|
929 |
+
"""Compute the self-attention mask for a decoder.
|
930 |
+
|
931 |
+
Decoder mask is formed by combining a causal mask, a padding mask and an
|
932 |
+
optional packing mask. If decoder_causal_attention is passed, it makes the
|
933 |
+
masking non-causal for positions that have value of 1.
|
934 |
+
|
935 |
+
A prefix LM is applied to a dataset which has a notion of "inputs" and
|
936 |
+
"targets", e.g., a machine translation task. The inputs and targets are
|
937 |
+
concatenated to form a new target. `decoder_target_tokens` is the concatenated
|
938 |
+
decoder output tokens.
|
939 |
+
|
940 |
+
The "inputs" portion of the concatenated sequence can attend to other "inputs"
|
941 |
+
tokens even for those at a later time steps. In order to control this
|
942 |
+
behavior, `decoder_causal_attention` is necessary. This is a binary mask with
|
943 |
+
a value of 1 indicating that the position belonged to "inputs" portion of the
|
944 |
+
original dataset.
|
945 |
+
|
946 |
+
Example:
|
947 |
+
|
948 |
+
Suppose we have a dataset with two examples.
|
949 |
+
|
950 |
+
ds = [{"inputs": [6, 7], "targets": [8]},
|
951 |
+
{"inputs": [3, 4], "targets": [5]}]
|
952 |
+
|
953 |
+
After the data preprocessing with packing, the two examples are packed into
|
954 |
+
one example with the following three fields (some fields are skipped for
|
955 |
+
simplicity).
|
956 |
+
|
957 |
+
decoder_target_tokens = [[6, 7, 8, 3, 4, 5, 0]]
|
958 |
+
decoder_segment_ids = [[1, 1, 1, 2, 2, 2, 0]]
|
959 |
+
decoder_causal_attention = [[1, 1, 0, 1, 1, 0, 0]]
|
960 |
+
|
961 |
+
where each array has [batch, length] shape with batch size being 1. Then,
|
962 |
+
this function computes the following mask.
|
963 |
+
|
964 |
+
mask = [[[[1, 1, 0, 0, 0, 0, 0],
|
965 |
+
[1, 1, 0, 0, 0, 0, 0],
|
966 |
+
[1, 1, 1, 0, 0, 0, 0],
|
967 |
+
[0, 0, 0, 1, 1, 0, 0],
|
968 |
+
[0, 0, 0, 1, 1, 0, 0],
|
969 |
+
[0, 0, 0, 1, 1, 1, 0],
|
970 |
+
[0, 0, 0, 0, 0, 0, 0]]]]
|
971 |
+
|
972 |
+
mask[b, 1, :, :] represents the mask for the example `b` in the batch.
|
973 |
+
Because mask is for a self-attention layer, the mask's shape is a square of
|
974 |
+
shape [query length, key length].
|
975 |
+
|
976 |
+
mask[b, 1, i, j] = 1 means that the query token at position i can attend to
|
977 |
+
the key token at position j.
|
978 |
+
|
979 |
+
Args:
|
980 |
+
decoder_target_tokens: decoder output tokens. [batch, length]
|
981 |
+
dtype: dtype of the output mask.
|
982 |
+
decoder_causal_attention: a binary mask indicating which position should
|
983 |
+
only attend to earlier positions in the sequence. Others will attend
|
984 |
+
bidirectionally. [batch, length]
|
985 |
+
decoder_segment_ids: decoder segmentation info for packed examples. [batch,
|
986 |
+
length]
|
987 |
+
|
988 |
+
Returns:
|
989 |
+
the combined decoder mask.
|
990 |
+
"""
|
991 |
+
masks = []
|
992 |
+
# The same mask is applied to all attention heads. So the head dimension is 1,
|
993 |
+
# i.e., the mask will be broadcast along the heads dim.
|
994 |
+
# [batch, 1, length, length]
|
995 |
+
causal_mask = make_causal_mask(decoder_target_tokens, dtype=dtype)
|
996 |
+
|
997 |
+
# Positions with value 1 in `decoder_causal_attneition` can attend
|
998 |
+
# bidirectionally.
|
999 |
+
if decoder_causal_attention is not None:
|
1000 |
+
# [batch, 1, length, length]
|
1001 |
+
inputs_mask = make_attention_mask(
|
1002 |
+
decoder_causal_attention,
|
1003 |
+
decoder_causal_attention,
|
1004 |
+
jnp.logical_and,
|
1005 |
+
dtype=dtype,
|
1006 |
+
)
|
1007 |
+
masks.append(jnp.logical_or(causal_mask, inputs_mask).astype(dtype))
|
1008 |
+
else:
|
1009 |
+
masks.append(causal_mask)
|
1010 |
+
|
1011 |
+
# Padding mask.
|
1012 |
+
masks.append(make_attention_mask(decoder_target_tokens > 0, decoder_target_tokens > 0, dtype=dtype))
|
1013 |
+
|
1014 |
+
# Packing mask
|
1015 |
+
if decoder_segment_ids is not None:
|
1016 |
+
masks.append(make_attention_mask(decoder_segment_ids, decoder_segment_ids, jnp.equal, dtype=dtype))
|
1017 |
+
|
1018 |
+
return combine_masks(*masks, dtype=dtype)
|
1019 |
+
|
1020 |
+
|
1021 |
+
def canonicalize_padding(padding: PaddingLike, rank: int) -> LaxPadding:
|
1022 |
+
""" "Canonicalizes conv padding to a jax.lax supported format."""
|
1023 |
+
if isinstance(padding, str):
|
1024 |
+
return padding
|
1025 |
+
if isinstance(padding, int):
|
1026 |
+
return [(padding, padding)] * rank
|
1027 |
+
if isinstance(padding, Sequence) and len(padding) == rank:
|
1028 |
+
new_pad = []
|
1029 |
+
for p in padding:
|
1030 |
+
if isinstance(p, int):
|
1031 |
+
new_pad.append((p, p))
|
1032 |
+
elif isinstance(p, tuple) and len(p) == 2:
|
1033 |
+
new_pad.append(p)
|
1034 |
+
else:
|
1035 |
+
break
|
1036 |
+
if len(new_pad) == rank:
|
1037 |
+
return new_pad
|
1038 |
+
raise ValueError(
|
1039 |
+
f"Invalid padding format: {padding}, should be str, int,"
|
1040 |
+
f" or a sequence of len {rank} where each element is an"
|
1041 |
+
" int or pair of ints."
|
1042 |
+
)
|
1043 |
+
|
1044 |
+
|
1045 |
+
def _conv_dimension_numbers(input_shape):
|
1046 |
+
"""Computes the dimension numbers based on the input shape."""
|
1047 |
+
ndim = len(input_shape)
|
1048 |
+
lhs_spec = (0, ndim - 1) + tuple(range(1, ndim - 1))
|
1049 |
+
rhs_spec = (ndim - 1, ndim - 2) + tuple(range(0, ndim - 2))
|
1050 |
+
out_spec = lhs_spec
|
1051 |
+
return lax.ConvDimensionNumbers(lhs_spec, rhs_spec, out_spec)
|
1052 |
+
|
1053 |
+
|
1054 |
+
class _Conv(nn.Module):
|
1055 |
+
"""Convolution Module wrapping `lax.conv_general_dilated[_local]`.
|
1056 |
+
|
1057 |
+
Attributes:
|
1058 |
+
features: number of convolution filters.
|
1059 |
+
kernel_size: shape of the convolutional kernel. For 1D convolution,
|
1060 |
+
the kernel size can be passed as an integer. For all other cases, it must
|
1061 |
+
be a sequence of integers.
|
1062 |
+
strides: an integer or a sequence of `n` integers, representing the
|
1063 |
+
inter-window strides (default: 1).
|
1064 |
+
padding: either the string `'SAME'`, the string `'VALID'`, the string
|
1065 |
+
`'CIRCULAR'` (periodic boundary conditions), or a sequence of `n` `(low,
|
1066 |
+
high)` integer pairs that give the padding to apply before and after each
|
1067 |
+
spatial dimension. A single int is interpeted as applying the same padding
|
1068 |
+
in all dims and passign a single int in a sequence causes the same padding
|
1069 |
+
to be used on both sides. `'CAUSAL'` padding for a 1D convolution will
|
1070 |
+
left-pad the convolution axis, resulting in same-sized output.
|
1071 |
+
input_dilation: an integer or a sequence of `n` integers, giving the
|
1072 |
+
dilation factor to apply in each spatial dimension of `inputs`
|
1073 |
+
(default: 1). Convolution with input dilation `d` is equivalent to
|
1074 |
+
transposed convolution with stride `d`.
|
1075 |
+
kernel_dilation: an integer or a sequence of `n` integers, giving the
|
1076 |
+
dilation factor to apply in each spatial dimension of the convolution
|
1077 |
+
kernel (default: 1). Convolution with kernel dilation
|
1078 |
+
is also known as 'atrous convolution'.
|
1079 |
+
feature_group_count: integer, default 1. If specified divides the input
|
1080 |
+
features into groups.
|
1081 |
+
use_bias: whether to add a bias to the output (default: True).
|
1082 |
+
mask: Optional mask for the weights during masked convolution. The mask must
|
1083 |
+
be the same shape as the convolution weight matrix.
|
1084 |
+
dtype: the dtype of the computation (default: infer from input and params).
|
1085 |
+
params_dtype: the dtype passed to parameter initializers (default: float32).
|
1086 |
+
precision: numerical precision of the computation see `jax.lax.Precision`
|
1087 |
+
for details.
|
1088 |
+
kernel_init: initializer for the convolutional kernel.
|
1089 |
+
bias_init: initializer for the bias.
|
1090 |
+
"""
|
1091 |
+
|
1092 |
+
features: int
|
1093 |
+
kernel_size: Sequence[int]
|
1094 |
+
strides: Union[None, int, Sequence[int]] = 1
|
1095 |
+
padding: PaddingLike = "SAME"
|
1096 |
+
input_dilation: Union[None, int, Sequence[int]] = 1
|
1097 |
+
kernel_dilation: Union[None, int, Sequence[int]] = 1
|
1098 |
+
feature_group_count: int = 1
|
1099 |
+
use_bias: bool = True
|
1100 |
+
mask: Optional[Array] = None
|
1101 |
+
dtype: Optional[DType] = None
|
1102 |
+
params_dtype: DType = jnp.float32
|
1103 |
+
precision: PrecisionLike = None
|
1104 |
+
kernel_init: Callable[[PRNGKey, Shape, DType], Array] = nn.initializers.lecun_normal()
|
1105 |
+
bias_init: Callable[[PRNGKey, Shape, DType], Array] = nn.initializers.zeros
|
1106 |
+
conv_general_dilated: ConvGeneralDilatedT = lax.conv_general_dilated
|
1107 |
+
kernel_axes: Tuple[str, ...] = ()
|
1108 |
+
|
1109 |
+
@property
|
1110 |
+
def shared_weights(self) -> bool: # type: ignore
|
1111 |
+
"""Defines whether weights are shared or not between different pixels.
|
1112 |
+
|
1113 |
+
Returns:
|
1114 |
+
`True` to use shared weights in convolution (regular convolution).
|
1115 |
+
`False` to use different weights at different pixels, a.k.a.
|
1116 |
+
"locally connected layer", "unshared convolution", or "local convolution".
|
1117 |
+
|
1118 |
+
"""
|
1119 |
+
...
|
1120 |
+
|
1121 |
+
@nn.compact
|
1122 |
+
def __call__(self, inputs: Array) -> Array:
|
1123 |
+
"""Applies a (potentially unshared) convolution to the inputs.
|
1124 |
+
|
1125 |
+
Args:
|
1126 |
+
inputs: input data with dimensions (*batch_dims, spatial_dims...,
|
1127 |
+
features). This is the channels-last convention, i.e. NHWC for a 2d
|
1128 |
+
convolution and NDHWC for a 3D convolution. Note: this is different from
|
1129 |
+
the input convention used by `lax.conv_general_dilated`, which puts the
|
1130 |
+
spatial dimensions last.
|
1131 |
+
Note: If the input has more than 1 batch dimension, all batch dimensions
|
1132 |
+
are flattened into a single dimension for the convolution and restored
|
1133 |
+
before returning. In some cases directly vmap'ing the layer may yield
|
1134 |
+
better performance than this default flattening approach. If the input
|
1135 |
+
lacks a batch dimension it will be added for the convolution and removed
|
1136 |
+
n return, an allowance made to enable writing single-example code.
|
1137 |
+
|
1138 |
+
Returns:
|
1139 |
+
The convolved data.
|
1140 |
+
"""
|
1141 |
+
|
1142 |
+
if isinstance(self.kernel_size, int):
|
1143 |
+
raise TypeError(
|
1144 |
+
"Expected Conv kernel_size to be a"
|
1145 |
+
" tuple/list of integers (eg.: [3, 3]) but got"
|
1146 |
+
f" {self.kernel_size}."
|
1147 |
+
)
|
1148 |
+
else:
|
1149 |
+
kernel_size = tuple(self.kernel_size)
|
1150 |
+
|
1151 |
+
def maybe_broadcast(x: Optional[Union[int, Sequence[int]]]) -> Tuple[int, ...]:
|
1152 |
+
if x is None:
|
1153 |
+
# backward compatibility with using None as sentinel for
|
1154 |
+
# broadcast 1
|
1155 |
+
x = 1
|
1156 |
+
if isinstance(x, int):
|
1157 |
+
return (x,) * len(kernel_size)
|
1158 |
+
return tuple(x)
|
1159 |
+
|
1160 |
+
# Combine all input batch dimensions into a single leading batch axis.
|
1161 |
+
num_batch_dimensions = inputs.ndim - (len(kernel_size) + 1)
|
1162 |
+
if num_batch_dimensions != 1:
|
1163 |
+
input_batch_shape = inputs.shape[:num_batch_dimensions]
|
1164 |
+
total_batch_size = int(np.prod(input_batch_shape))
|
1165 |
+
flat_input_shape = (total_batch_size,) + inputs.shape[num_batch_dimensions:]
|
1166 |
+
inputs = jnp.reshape(inputs, flat_input_shape)
|
1167 |
+
|
1168 |
+
# self.strides or (1,) * (inputs.ndim - 2)
|
1169 |
+
strides = maybe_broadcast(self.strides)
|
1170 |
+
input_dilation = maybe_broadcast(self.input_dilation)
|
1171 |
+
kernel_dilation = maybe_broadcast(self.kernel_dilation)
|
1172 |
+
|
1173 |
+
padding_lax = canonicalize_padding(self.padding, len(kernel_size))
|
1174 |
+
if padding_lax == "CIRCULAR":
|
1175 |
+
kernel_size_dilated = [(k - 1) * d + 1 for k, d in zip(kernel_size, kernel_dilation)]
|
1176 |
+
zero_pad: List[Tuple[int, int]] = [(0, 0)]
|
1177 |
+
pads = zero_pad + [((k - 1) // 2, k // 2) for k in kernel_size_dilated] + [(0, 0)]
|
1178 |
+
inputs = jnp.pad(inputs, pads, mode="wrap")
|
1179 |
+
padding_lax = "VALID"
|
1180 |
+
elif padding_lax == "CAUSAL":
|
1181 |
+
if len(kernel_size) != 1:
|
1182 |
+
raise ValueError("Causal padding is only implemented for 1D convolutions.")
|
1183 |
+
left_pad = kernel_dilation[0] * (kernel_size[0] - 1)
|
1184 |
+
pads = [(0, 0), (left_pad, 0), (0, 0)]
|
1185 |
+
inputs = jnp.pad(inputs, pads)
|
1186 |
+
padding_lax = "VALID"
|
1187 |
+
|
1188 |
+
dimension_numbers = _conv_dimension_numbers(inputs.shape)
|
1189 |
+
in_features = jnp.shape(inputs)[-1]
|
1190 |
+
|
1191 |
+
if self.shared_weights:
|
1192 |
+
# One shared convolutional kernel for all pixels in the output.
|
1193 |
+
assert in_features % self.feature_group_count == 0
|
1194 |
+
kernel_shape = kernel_size + (
|
1195 |
+
in_features // self.feature_group_count,
|
1196 |
+
self.features,
|
1197 |
+
)
|
1198 |
+
|
1199 |
+
else:
|
1200 |
+
if self.feature_group_count != 1:
|
1201 |
+
raise NotImplementedError(
|
1202 |
+
"`lax.conv_general_dilated_local` does not support "
|
1203 |
+
f"`feature_group_count != 1`, got `{self.feature_group_count}`."
|
1204 |
+
)
|
1205 |
+
|
1206 |
+
# Need to know the spatial output shape of a standard convolution to
|
1207 |
+
# create the unshared convolution kernel.
|
1208 |
+
conv_output_shape = jax.eval_shape(
|
1209 |
+
lambda lhs, rhs: self.conv_general_dilated( # pylint: disable=g-long-lambda
|
1210 |
+
lhs=lhs,
|
1211 |
+
rhs=rhs,
|
1212 |
+
window_strides=strides,
|
1213 |
+
padding=padding_lax,
|
1214 |
+
dimension_numbers=dimension_numbers,
|
1215 |
+
lhs_dilation=input_dilation,
|
1216 |
+
rhs_dilation=kernel_dilation,
|
1217 |
+
),
|
1218 |
+
inputs,
|
1219 |
+
jax.ShapedArray(kernel_size + (in_features, self.features), inputs.dtype),
|
1220 |
+
).shape
|
1221 |
+
|
1222 |
+
# One (unshared) convolutional kernel per each pixel in the output.
|
1223 |
+
kernel_shape = conv_output_shape[1:-1] + (
|
1224 |
+
np.prod(kernel_size) * in_features,
|
1225 |
+
self.features,
|
1226 |
+
)
|
1227 |
+
|
1228 |
+
if self.mask is not None and self.mask.shape != kernel_shape:
|
1229 |
+
raise ValueError(
|
1230 |
+
"Mask needs to have the same shape as weights. " f"Shapes are: {self.mask.shape}, {kernel_shape}"
|
1231 |
+
)
|
1232 |
+
|
1233 |
+
kernel = param_with_axes(
|
1234 |
+
"kernel",
|
1235 |
+
self.kernel_init,
|
1236 |
+
kernel_shape,
|
1237 |
+
self.params_dtype,
|
1238 |
+
axes=self.kernel_axes,
|
1239 |
+
)
|
1240 |
+
|
1241 |
+
if self.mask is not None:
|
1242 |
+
kernel *= self.mask
|
1243 |
+
|
1244 |
+
if self.use_bias:
|
1245 |
+
if self.shared_weights:
|
1246 |
+
# One bias weight per output channel, shared between pixels.
|
1247 |
+
bias_shape = (self.features,)
|
1248 |
+
else:
|
1249 |
+
# One bias weight per output entry, unshared betwen pixels.
|
1250 |
+
bias_shape = conv_output_shape[1:]
|
1251 |
+
|
1252 |
+
bias = param_with_axes(
|
1253 |
+
"bias",
|
1254 |
+
self.bias_init,
|
1255 |
+
bias_shape,
|
1256 |
+
self.params_dtype,
|
1257 |
+
axes=(self.kernel_axes[-1],),
|
1258 |
+
)
|
1259 |
+
else:
|
1260 |
+
bias = None
|
1261 |
+
|
1262 |
+
inputs, kernel, bias = promote_dtype(inputs, kernel, bias, dtype=self.dtype)
|
1263 |
+
if self.shared_weights:
|
1264 |
+
y = self.conv_general_dilated(
|
1265 |
+
inputs,
|
1266 |
+
kernel,
|
1267 |
+
strides,
|
1268 |
+
padding_lax,
|
1269 |
+
lhs_dilation=input_dilation,
|
1270 |
+
rhs_dilation=kernel_dilation,
|
1271 |
+
dimension_numbers=dimension_numbers,
|
1272 |
+
feature_group_count=self.feature_group_count,
|
1273 |
+
precision=self.precision,
|
1274 |
+
)
|
1275 |
+
else:
|
1276 |
+
y = lax.conv_general_dilated_local(
|
1277 |
+
lhs=inputs,
|
1278 |
+
rhs=kernel,
|
1279 |
+
window_strides=strides,
|
1280 |
+
padding=padding_lax,
|
1281 |
+
filter_shape=kernel_size,
|
1282 |
+
lhs_dilation=input_dilation,
|
1283 |
+
rhs_dilation=kernel_dilation,
|
1284 |
+
dimension_numbers=dimension_numbers,
|
1285 |
+
precision=self.precision,
|
1286 |
+
)
|
1287 |
+
|
1288 |
+
if self.use_bias:
|
1289 |
+
bias = bias.reshape((1,) * (y.ndim - bias.ndim) + bias.shape)
|
1290 |
+
y += bias
|
1291 |
+
|
1292 |
+
if num_batch_dimensions != 1:
|
1293 |
+
output_shape = input_batch_shape + y.shape[1:]
|
1294 |
+
y = jnp.reshape(y, output_shape)
|
1295 |
+
return y
|
1296 |
+
|
1297 |
+
|
1298 |
+
class Conv(_Conv):
|
1299 |
+
"""Convolution Module wrapping `lax.conv_general_dilated`.
|
1300 |
+
|
1301 |
+
Attributes:
|
1302 |
+
features: number of convolution filters.
|
1303 |
+
kernel_size: shape of the convolutional kernel. For 1D convolution,
|
1304 |
+
the kernel size can be passed as an integer. For all other cases, it must
|
1305 |
+
be a sequence of integers.
|
1306 |
+
strides: an integer or a sequence of `n` integers, representing the
|
1307 |
+
inter-window strides (default: 1).
|
1308 |
+
padding: either the string `'SAME'`, the string `'VALID'`, the string
|
1309 |
+
`'CIRCULAR'` (periodic boundary conditions), or a sequence of `n` `(low,
|
1310 |
+
high)` integer pairs that give the padding to apply before and after each
|
1311 |
+
spatial dimension. A single int is interpeted as applying the same padding
|
1312 |
+
in all dims and passign a single int in a sequence causes the same padding
|
1313 |
+
to be used on both sides. `'CAUSAL'` padding for a 1D convolution will
|
1314 |
+
left-pad the convolution axis, resulting in same-sized output.
|
1315 |
+
input_dilation: an integer or a sequence of `n` integers, giving the
|
1316 |
+
dilation factor to apply in each spatial dimension of `inputs`
|
1317 |
+
(default: 1). Convolution with input dilation `d` is equivalent to
|
1318 |
+
transposed convolution with stride `d`.
|
1319 |
+
kernel_dilation: an integer or a sequence of `n` integers, giving the
|
1320 |
+
dilation factor to apply in each spatial dimension of the convolution
|
1321 |
+
kernel (default: 1). Convolution with kernel dilation
|
1322 |
+
is also known as 'atrous convolution'.
|
1323 |
+
feature_group_count: integer, default 1. If specified divides the input
|
1324 |
+
features into groups.
|
1325 |
+
use_bias: whether to add a bias to the output (default: True).
|
1326 |
+
mask: Optional mask for the weights during masked convolution. The mask must
|
1327 |
+
be the same shape as the convolution weight matrix.
|
1328 |
+
dtype: the dtype of the computation (default: infer from input and params).
|
1329 |
+
params_dtype: the dtype passed to parameter initializers (default: float32).
|
1330 |
+
precision: numerical precision of the computation see `jax.lax.Precision`
|
1331 |
+
for details.
|
1332 |
+
kernel_init: initializer for the convolutional kernel.
|
1333 |
+
bias_init: initializer for the bias.
|
1334 |
+
"""
|
1335 |
+
|
1336 |
+
@property
|
1337 |
+
def shared_weights(self) -> bool:
|
1338 |
+
return True
|
distil_whisper/modeling_flax_whisper.py
ADDED
@@ -0,0 +1,2136 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# coding=utf-8
|
2 |
+
# Copyright 2023 The OpenAI Authors and The HuggingFace Inc. team. All rights reserved.
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
""" Flax whisper model."""
|
16 |
+
|
17 |
+
import random
|
18 |
+
from functools import partial
|
19 |
+
from typing import Dict, Optional, Tuple, Union
|
20 |
+
|
21 |
+
import flax.linen as nn
|
22 |
+
import jax
|
23 |
+
import jax.numpy as jnp
|
24 |
+
from flax.core.frozen_dict import FrozenDict, freeze, unfreeze
|
25 |
+
from flax.linen import combine_masks, make_causal_mask
|
26 |
+
from flax.linen.attention import dot_product_attention_weights
|
27 |
+
from flax.linen.partitioning import remat, scan_with_axes
|
28 |
+
from flax.traverse_util import flatten_dict, unflatten_dict
|
29 |
+
from jax import lax
|
30 |
+
from jax.random import PRNGKey
|
31 |
+
from transformers import WhisperConfig
|
32 |
+
from transformers.generation.flax_logits_process import (
|
33 |
+
FlaxLogitsProcessor,
|
34 |
+
FlaxLogitsProcessorList,
|
35 |
+
FlaxWhisperTimeStampLogitsProcessor,
|
36 |
+
)
|
37 |
+
from transformers.modeling_flax_outputs import (
|
38 |
+
FlaxBaseModelOutput,
|
39 |
+
FlaxBaseModelOutputWithPastAndCrossAttentions,
|
40 |
+
FlaxCausalLMOutputWithCrossAttentions,
|
41 |
+
FlaxSeq2SeqLMOutput,
|
42 |
+
FlaxSeq2SeqModelOutput,
|
43 |
+
)
|
44 |
+
from transformers.modeling_flax_utils import (
|
45 |
+
ACT2FN,
|
46 |
+
FlaxPreTrainedModel,
|
47 |
+
append_call_sample_docstring,
|
48 |
+
append_replace_return_docstrings,
|
49 |
+
overwrite_call_docstring,
|
50 |
+
)
|
51 |
+
from transformers.utils import (
|
52 |
+
add_start_docstrings,
|
53 |
+
add_start_docstrings_to_model_forward,
|
54 |
+
logging,
|
55 |
+
replace_return_docstrings,
|
56 |
+
)
|
57 |
+
|
58 |
+
from .layers import Conv, DenseGeneral, Embed, LayerNorm, with_sharding_constraint
|
59 |
+
|
60 |
+
|
61 |
+
logger = logging.get_logger(__name__)
|
62 |
+
|
63 |
+
|
64 |
+
_CHECKPOINT_FOR_DOC = "openai/whisper-tiny"
|
65 |
+
_CONFIG_FOR_DOC = "WhisperConfig"
|
66 |
+
|
67 |
+
|
68 |
+
WHISPER_START_DOCSTRING = r"""
|
69 |
+
This model inherits from [`FlaxPreTrainedModel`]. Check the superclass documentation for the generic methods the
|
70 |
+
library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads
|
71 |
+
etc.) This model is also a Flax Linen
|
72 |
+
[flax.nn.Module](https://flax.readthedocs.io/en/latest/_autosummary/flax.nn.module.html) subclass. Use it as a
|
73 |
+
regular Flax Module and refer to the Flax documentation for all matter related to general usage and behavior.
|
74 |
+
Finally, this model supports inherent JAX features such as:
|
75 |
+
- [Just-In-Time (JIT) compilation](https://jax.readthedocs.io/en/latest/jax.html#just-in-time-compilation-jit)
|
76 |
+
- [Automatic Differentiation](https://jax.readthedocs.io/en/latest/jax.html#automatic-differentiation)
|
77 |
+
- [Vectorization](https://jax.readthedocs.io/en/latest/jax.html#vectorization-vmap)
|
78 |
+
- [Parallelization](https://jax.readthedocs.io/en/latest/jax.html#parallelization-pmap)
|
79 |
+
|
80 |
+
Parameters:
|
81 |
+
config ([`WhisperConfig`]): Model configuration class with all the parameters of the model.
|
82 |
+
Initializing with a config file does not load the weights associated with the model, only the
|
83 |
+
configuration. Check out the [`~FlaxPreTrainedModel.from_pretrained`] method to load the model weights.
|
84 |
+
dtype (`jax.numpy.dtype`, *optional*, defaults to `jax.numpy.float32`):
|
85 |
+
The data type of the computation. Can be one of `jax.numpy.float32`, `jax.numpy.float16` (on GPUs) and
|
86 |
+
`jax.numpy.bfloat16` (on TPUs). This can be used to enable mixed-precision training or half-precision
|
87 |
+
inference on GPUs or TPUs. If specified all the computation will be performed with the given `dtype`.
|
88 |
+
**Note that this only specifies the dtype of the computation and does not influence the dtype of model
|
89 |
+
parameters.** If you wish to change the dtype of the model parameters, see [`~FlaxPreTrainedModel.to_fp16`]
|
90 |
+
and [`~FlaxPreTrainedModel.to_bf16`].
|
91 |
+
"""
|
92 |
+
|
93 |
+
WHISPER_INPUTS_DOCSTRING = r"""
|
94 |
+
Args:
|
95 |
+
input_features (`numpy.ndarray` of shape `(batch_size, feature_size, sequence_length)`):
|
96 |
+
Float values mel features extracted from the raw speech waveform. Raw speech waveform can be obtained by
|
97 |
+
loading a `.flac` or `.wav` audio file into an array of type `List[float]` or a `numpy.ndarray`, *e.g.* via
|
98 |
+
the soundfile library (`pip install soundfile`). To prepare the array into `input_features`, the
|
99 |
+
[`WhisperFeatureExtractor`] should be used for extracting the features, padding and conversion into a
|
100 |
+
tensor of type `numpy.ndarray`. See [`~WhisperFeatureExtractor.__call__`]
|
101 |
+
attention_mask (`numpy.ndarray` of shape `(batch_size, sequence_length)`, *optional*):
|
102 |
+
Whisper does not support masking of the `input_features`, this argument is preserved for compatibility, but
|
103 |
+
is not used. By default the silence in the input log mel spectrogram are ignored.
|
104 |
+
decoder_input_ids (`numpy.ndarray` of shape `(batch_size, target_sequence_length)`, *optional*):
|
105 |
+
Indices of decoder input sequence tokens in the vocabulary. Indices can be obtained using
|
106 |
+
[`WhisperTokenizer`]. See [`PreTrainedTokenizer.encode`] and [`PreTrainedTokenizer.__call__`] for details.
|
107 |
+
[What are decoder input IDs?](../glossary#decoder-input-ids) Whisper uses the `decoder_start_token_id` as
|
108 |
+
the starting token for `decoder_input_ids` generation.
|
109 |
+
decoder_attention_mask (`numpy.ndarray` of shape `(batch_size, target_sequence_length)`, *optional*):
|
110 |
+
Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also
|
111 |
+
be used by default. If you want to change padding behavior, you should modify to your needs. See diagram 1
|
112 |
+
in [the paper](https://arxiv.org/abs/1910.13461) for more information on the default strategy.
|
113 |
+
position_ids (`numpy.ndarray` of shape `(batch_size, sequence_length)`, *optional*):
|
114 |
+
Whisper does not use `position_ids` in the encoder as `input_features` is always the same size and doesn't
|
115 |
+
use masking, but this argument is preserved for compatibility. By default the silence in the input log mel
|
116 |
+
spectrogram are ignored.
|
117 |
+
decoder_position_ids (`numpy.ndarray` of shape `(batch_size, sequence_length)`, *optional*):
|
118 |
+
Indices of positions of each decoder input sequence tokens in the position embeddings. Selected in the
|
119 |
+
range `[0, config.max_position_embeddings - 1]`.
|
120 |
+
output_attentions (`bool`, *optional*):
|
121 |
+
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
|
122 |
+
tensors for more detail.
|
123 |
+
output_hidden_states (`bool`, *optional*):
|
124 |
+
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
|
125 |
+
more detail.
|
126 |
+
return_dict (`bool`, *optional*):
|
127 |
+
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
|
128 |
+
"""
|
129 |
+
|
130 |
+
WHISPER_ENCODE_INPUTS_DOCSTRING = r"""
|
131 |
+
Args:
|
132 |
+
input_features (`numpy.ndarray` of shape `(batch_size, feature_size, sequence_length)`):
|
133 |
+
Float values mel features extracted from the raw speech waveform. Raw speech waveform can be obtained by
|
134 |
+
loading a `.flac` or `.wav` audio file into an array of type `List[float]` or a `numpy.ndarray`, *e.g.* via
|
135 |
+
the soundfile library (`pip install soundfile`). To prepare the array into `input_features`, the
|
136 |
+
[`WhisperFeatureExtractor`] should be used for extracting the mel features, padding and conversion into a
|
137 |
+
tensor of type `numpy.ndarray`. See [`~WhisperFeatureExtractor.__call__`].
|
138 |
+
attention_mask (`numpy.ndarray` of shape `(batch_size, sequence_length)`, *optional*):
|
139 |
+
Whisper does not support masking of the `input_features`, this argument is preserved for compatibility, but
|
140 |
+
is not used. By default the silence in the input log mel spectrogram are ignored.
|
141 |
+
output_attentions (`bool`, *optional*):
|
142 |
+
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
|
143 |
+
tensors for more detail.
|
144 |
+
output_hidden_states (`bool`, *optional*):
|
145 |
+
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
|
146 |
+
more detail.
|
147 |
+
return_dict (`bool`, *optional*):
|
148 |
+
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
|
149 |
+
"""
|
150 |
+
|
151 |
+
WHISPER_DECODE_INPUTS_DOCSTRING = r"""
|
152 |
+
Args:
|
153 |
+
decoder_input_ids (`numpy.ndarray` of shape `(batch_size, target_sequence_length)`):
|
154 |
+
Indices of decoder input sequence tokens in the vocabulary. Indices can be obtained using
|
155 |
+
[`WhisperTokenizer`]. See [`PreTrainedTokenizer.encode`] and [`PreTrainedTokenizer.__call__`] for details.
|
156 |
+
[What are decoder input IDs?](../glossary#decoder-input-ids)
|
157 |
+
encoder_outputs (`tuple(tuple(numpy.ndarray)`):
|
158 |
+
Tuple consists of (`last_hidden_state`, *optional*: `hidden_states`, *optional*: `attentions`)
|
159 |
+
`last_hidden_state` of shape `(batch_size, sequence_length, hidden_size)`, *optional*) is a sequence of
|
160 |
+
hidden-states at the output of the last layer of the encoder. Used in the cross-attention of the decoder.
|
161 |
+
encoder_attention_mask (`numpy.ndarray` of shape `(batch_size, sequence_length)`, *optional*):
|
162 |
+
Whisper does not support masking of the `input_features`, this argument is preserved for compatibility,
|
163 |
+
but it is not used. By default the silence in the input log mel spectrogram are ignored.
|
164 |
+
decoder_attention_mask (`numpy.ndarray` of shape `(batch_size, target_sequence_length)`, *optional*):
|
165 |
+
Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also
|
166 |
+
be used by default. If you want to change padding behavior, you should modify to your needs. See diagram 1
|
167 |
+
in [the paper](https://arxiv.org/abs/1910.13461) for more information on the default strategy.
|
168 |
+
decoder_position_ids (`numpy.ndarray` of shape `(batch_size, sequence_length)`, *optional*):
|
169 |
+
Indices of positions of each decoder input sequence tokens in the position embeddings. Selected in the
|
170 |
+
range `[0, config.max_position_embeddings - 1]`.
|
171 |
+
past_key_values (`Dict[str, numpy.ndarray]`, *optional*, returned by `init_cache` or when passing previous `past_key_values`):
|
172 |
+
Dictionary of pre-computed hidden-states (key and values in the attention blocks) that can be used for fast
|
173 |
+
auto-regressive decoding. Pre-computed key and value hidden-states are of shape *[batch_size, max_length]*.
|
174 |
+
output_attentions (`bool`, *optional*):
|
175 |
+
Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
|
176 |
+
tensors for more detail.
|
177 |
+
output_hidden_states (`bool`, *optional*):
|
178 |
+
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
|
179 |
+
more detail.
|
180 |
+
return_dict (`bool`, *optional*):
|
181 |
+
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
|
182 |
+
"""
|
183 |
+
|
184 |
+
|
185 |
+
class FlaxStaticForceTokensLogitsProcessor(FlaxLogitsProcessor):
|
186 |
+
r"""
|
187 |
+
[`FlaxLogitsProcessor`] that takes a list of pairs of integers which indicates a mapping from generation indices to
|
188 |
+
token indices that will be forced before sampling. The processor will set their log probs to 0 and all other tokens
|
189 |
+
to `-inf` so that they are sampled at their corresponding index. This is a static version of the `transformers` logit
|
190 |
+
processor [`FlaxForceTokensLogitsProcessor`] that is compatible with sharded forced tokens.
|
191 |
+
|
192 |
+
Args:
|
193 |
+
force_token_map (`list`):
|
194 |
+
Map giving token ids and indices where they will be forced to be sampled.
|
195 |
+
"""
|
196 |
+
|
197 |
+
def __init__(self, force_token_map):
|
198 |
+
# The generic `transformers` logit processor builds `force_token_array` as a dictionary - this is not a valid
|
199 |
+
# JAX type, and so we switch to using a JAX array instead
|
200 |
+
force_token_map = jnp.array(force_token_map)
|
201 |
+
# Converts the array of format [[index, token]] containing the tokens to be forced to an array, where the
|
202 |
+
# index of the array corresponds to the index of the token to be forced. For XLA compatibility,
|
203 |
+
# indexes without forced tokens will have a negative value. Note that the last token we ever need to force in
|
204 |
+
# Whisper is at position 3, so we only construct an array up to this index. The native version constructs a tensor
|
205 |
+
# dynamically according to the length of the `force_token_map`. Array shapes need to be concrete for XLA compatibility,
|
206 |
+
# so this is not permitted here.
|
207 |
+
force_token_array = jnp.ones(3, dtype=jnp.int32) * -1
|
208 |
+
for index, token in force_token_map:
|
209 |
+
force_token_array = force_token_array.at[index].set(token)
|
210 |
+
self.force_token_array = jnp.int32(force_token_array)
|
211 |
+
|
212 |
+
def __call__(self, input_ids: jnp.ndarray, scores: jnp.ndarray, cur_len: int) -> jnp.ndarray:
|
213 |
+
def _force_token(generation_idx):
|
214 |
+
batch_size = scores.shape[0]
|
215 |
+
current_token = self.force_token_array[generation_idx]
|
216 |
+
|
217 |
+
new_scores = jnp.ones_like(scores, dtype=scores.dtype) * -float("inf")
|
218 |
+
updates = jnp.zeros((batch_size, 1), dtype=scores.dtype)
|
219 |
+
new_scores = lax.dynamic_update_slice(new_scores, updates, (0, current_token))
|
220 |
+
return new_scores
|
221 |
+
|
222 |
+
scores = lax.cond(
|
223 |
+
cur_len >= self.force_token_array.shape[0],
|
224 |
+
# If the current length is geq than the length of force_token_array, the processor does nothing.
|
225 |
+
lambda: scores,
|
226 |
+
# Otherwise, it may force a certain token.
|
227 |
+
lambda: lax.cond(
|
228 |
+
self.force_token_array[cur_len] >= 0,
|
229 |
+
# Only valid (positive) tokens are forced
|
230 |
+
lambda: _force_token(cur_len),
|
231 |
+
# Otherwise, the processor does nothing.
|
232 |
+
lambda: scores,
|
233 |
+
),
|
234 |
+
)
|
235 |
+
return scores
|
236 |
+
|
237 |
+
|
238 |
+
class FlaxWhisperAttention(nn.Module):
|
239 |
+
config: WhisperConfig
|
240 |
+
embed_dim: int
|
241 |
+
num_heads: int
|
242 |
+
dropout: float = 0.0
|
243 |
+
causal: bool = False
|
244 |
+
bias: bool = True
|
245 |
+
dtype: jnp.dtype = jnp.float32
|
246 |
+
params_dtype: jnp.dtype = jnp.float32
|
247 |
+
|
248 |
+
def setup(self) -> None:
|
249 |
+
self.head_dim = self.embed_dim // self.num_heads
|
250 |
+
if self.head_dim * self.num_heads != self.embed_dim:
|
251 |
+
raise ValueError(
|
252 |
+
"embed_dim must be divisible by num_heads (got `embed_dim`:"
|
253 |
+
f" {self.embed_dim} and `num_heads`: {self.num_heads})."
|
254 |
+
)
|
255 |
+
|
256 |
+
dense = partial(
|
257 |
+
DenseGeneral,
|
258 |
+
self.embed_dim,
|
259 |
+
axis=-1,
|
260 |
+
dtype=self.dtype,
|
261 |
+
params_dtype=self.params_dtype,
|
262 |
+
kernel_axes=("embed", "joined_kv"),
|
263 |
+
)
|
264 |
+
|
265 |
+
self.q_proj = dense(use_bias=self.bias)
|
266 |
+
self.k_proj = dense(use_bias=False)
|
267 |
+
self.v_proj = dense(use_bias=self.bias)
|
268 |
+
|
269 |
+
self.out_proj = DenseGeneral(
|
270 |
+
self.embed_dim,
|
271 |
+
axis=-1,
|
272 |
+
dtype=self.dtype,
|
273 |
+
params_dtype=self.params_dtype,
|
274 |
+
kernel_axes=("joined_kv", "embed"),
|
275 |
+
use_bias=self.bias,
|
276 |
+
)
|
277 |
+
|
278 |
+
if self.causal:
|
279 |
+
self.causal_mask = make_causal_mask(
|
280 |
+
jnp.ones((1, self.config.max_target_positions), dtype="bool"),
|
281 |
+
dtype="bool",
|
282 |
+
)
|
283 |
+
|
284 |
+
def __call__(
|
285 |
+
self,
|
286 |
+
hidden_states: jnp.ndarray,
|
287 |
+
key_value_states: Optional[jnp.ndarray] = None,
|
288 |
+
attention_mask: Optional[jnp.ndarray] = None,
|
289 |
+
init_cache: bool = False,
|
290 |
+
deterministic: bool = True,
|
291 |
+
) -> Tuple[jnp.ndarray]:
|
292 |
+
is_cross_attention = key_value_states is not None
|
293 |
+
batch_size = hidden_states.shape[0]
|
294 |
+
|
295 |
+
query_states = self.q_proj(hidden_states)
|
296 |
+
|
297 |
+
if is_cross_attention:
|
298 |
+
key_states = self.k_proj(key_value_states)
|
299 |
+
value_states = self.v_proj(key_value_states)
|
300 |
+
else:
|
301 |
+
key_states = self.k_proj(hidden_states)
|
302 |
+
value_states = self.v_proj(hidden_states)
|
303 |
+
|
304 |
+
query_states = self._split_heads(query_states)
|
305 |
+
key_states = self._split_heads(key_states)
|
306 |
+
value_states = self._split_heads(value_states)
|
307 |
+
|
308 |
+
query_states = with_sharding_constraint(query_states, ("batch", "length", "heads", "kv"))
|
309 |
+
key_states = with_sharding_constraint(key_states, ("batch", "length", "heads", "kv"))
|
310 |
+
value_states = with_sharding_constraint(value_states, ("batch", "length", "heads", "kv"))
|
311 |
+
|
312 |
+
if self.causal:
|
313 |
+
query_length, key_length = query_states.shape[1], key_states.shape[1]
|
314 |
+
if self.has_variable("cache", "cached_key"):
|
315 |
+
mask_shift = self.variables["cache"]["cache_index"]
|
316 |
+
# max_length of cached_key is last dim
|
317 |
+
max_decoder_length = self.variables["cache"]["cached_key"].shape[-1]
|
318 |
+
causal_mask = lax.dynamic_slice(
|
319 |
+
self.causal_mask,
|
320 |
+
(0, 0, mask_shift, 0),
|
321 |
+
(1, 1, query_length, max_decoder_length),
|
322 |
+
)
|
323 |
+
else:
|
324 |
+
causal_mask = self.causal_mask[:, :, :query_length, :key_length]
|
325 |
+
causal_mask = jnp.broadcast_to(causal_mask, (batch_size,) + causal_mask.shape[1:])
|
326 |
+
|
327 |
+
# combine masks if needed
|
328 |
+
if attention_mask is not None and self.causal:
|
329 |
+
attention_mask = jnp.broadcast_to(jnp.expand_dims(attention_mask, axis=(-3, -2)), causal_mask.shape)
|
330 |
+
attention_mask = combine_masks(attention_mask, causal_mask)
|
331 |
+
elif self.causal:
|
332 |
+
attention_mask = causal_mask
|
333 |
+
elif attention_mask is not None:
|
334 |
+
attention_mask = jnp.expand_dims(attention_mask, axis=(-3, -2))
|
335 |
+
|
336 |
+
# During fast autoregressive decoding, we feed one position at a time,
|
337 |
+
# and cache the keys and values step by step.
|
338 |
+
|
339 |
+
if self.causal and (self.has_variable("cache", "cached_key") or init_cache):
|
340 |
+
key_states, value_states, attention_mask = self._concatenate_to_cache(
|
341 |
+
key_states, value_states, query_states, attention_mask
|
342 |
+
)
|
343 |
+
|
344 |
+
# Convert the boolean attention mask to an attention bias.
|
345 |
+
if attention_mask is not None:
|
346 |
+
# attention mask in the form of attention bias
|
347 |
+
attention_bias = lax.select(
|
348 |
+
attention_mask > 0,
|
349 |
+
jnp.full(attention_mask.shape, 0.0).astype(self.dtype),
|
350 |
+
jnp.full(attention_mask.shape, jnp.finfo(self.dtype).min).astype(self.dtype),
|
351 |
+
)
|
352 |
+
else:
|
353 |
+
attention_bias = None
|
354 |
+
|
355 |
+
dropout_rng = None
|
356 |
+
if not deterministic and self.dropout > 0.0:
|
357 |
+
dropout_rng = self.make_rng("dropout")
|
358 |
+
|
359 |
+
attn_weights = dot_product_attention_weights(
|
360 |
+
query_states,
|
361 |
+
key_states,
|
362 |
+
bias=attention_bias,
|
363 |
+
dropout_rng=dropout_rng,
|
364 |
+
dropout_rate=self.dropout,
|
365 |
+
broadcast_dropout=True,
|
366 |
+
deterministic=deterministic,
|
367 |
+
dtype=self.dtype,
|
368 |
+
precision=None,
|
369 |
+
)
|
370 |
+
|
371 |
+
attn_output = jnp.einsum("...hqk,...khd->...qhd", attn_weights, value_states)
|
372 |
+
attn_output = self._merge_heads(attn_output)
|
373 |
+
attn_output = self.out_proj(attn_output)
|
374 |
+
|
375 |
+
return attn_output, attn_weights
|
376 |
+
|
377 |
+
def _split_heads(self, hidden_state) -> jnp.ndarray:
|
378 |
+
return hidden_state.reshape(hidden_state.shape[:2] + (self.num_heads, self.head_dim))
|
379 |
+
|
380 |
+
def _merge_heads(self, hidden_state) -> jnp.ndarray:
|
381 |
+
return hidden_state.reshape(hidden_state.shape[:2] + (self.embed_dim,))
|
382 |
+
|
383 |
+
@nn.compact
|
384 |
+
def _concatenate_to_cache(self, key, value, query, attention_mask):
|
385 |
+
# The following code is largely copied from: https://github.com/google-research/t5x/blob/63d9addf628c6d8c547a407a32095fcb527bb20b/t5x/examples/scalable_t5/layers.py#L280-L284
|
386 |
+
is_initialized = self.has_variable("cache", "cached_key")
|
387 |
+
|
388 |
+
# The key and value have dimension [batch_size, seq_length, num_heads, head_dim],
|
389 |
+
# but we cache them as [batch_size, num_heads, head_dim, seq_length] as a TPU
|
390 |
+
# fusion optimization. This also enables the "scatter via one-hot
|
391 |
+
# broadcast" trick, which means we do a one-hot broadcast instead of a
|
392 |
+
# scatter/gather operations, resulting in a 3-4x speedup in practice.
|
393 |
+
def swap_dims(x):
|
394 |
+
return x[:-3] + tuple(x[i] for i in [-2, -1, -3])
|
395 |
+
|
396 |
+
cached_key = self.variable("cache", "cached_key", jnp.zeros, swap_dims(key.shape), key.dtype)
|
397 |
+
cached_value = self.variable("cache", "cached_value", jnp.zeros, swap_dims(value.shape), value.dtype)
|
398 |
+
cache_index = self.variable("cache", "cache_index", lambda: jnp.array(0, dtype=jnp.int32))
|
399 |
+
|
400 |
+
if is_initialized:
|
401 |
+
batch_size, num_heads, head_dim, seq_length = cached_key.value.shape
|
402 |
+
# During fast autoregressive decoding, we feed one position at a time,
|
403 |
+
# and cache the keys and values step by step.
|
404 |
+
# Sanity shape check of cached key against input query.
|
405 |
+
num_updated_cache_vectors = query.shape[1]
|
406 |
+
expected_shape = (batch_size, 1, num_heads, head_dim)
|
407 |
+
if num_updated_cache_vectors == 1 and expected_shape != query.shape:
|
408 |
+
raise ValueError(
|
409 |
+
"Autoregressive cache shape error, expected query shape"
|
410 |
+
f" {expected_shape} instead got {query.shape}"
|
411 |
+
)
|
412 |
+
|
413 |
+
# Create a OHE of the current index. NOTE: the index is increased below.
|
414 |
+
cur_index = cache_index.value
|
415 |
+
|
416 |
+
# In order to update the key, value caches with the current key and
|
417 |
+
# value, we move the seq_length axis to the back, similar to what we did for
|
418 |
+
# the cached ones above.
|
419 |
+
# Note these are currently the key and value of a single position, since
|
420 |
+
# we feed one position at a time.
|
421 |
+
one_token_key = jnp.moveaxis(key, -3, -1)
|
422 |
+
one_token_value = jnp.moveaxis(value, -3, -1)
|
423 |
+
|
424 |
+
# Update key, value caches with our new 1d spatial slices.
|
425 |
+
# We implement an efficient scatter into the cache via one-hot
|
426 |
+
# broadcast and addition.
|
427 |
+
if num_updated_cache_vectors > 1:
|
428 |
+
indices = jnp.eye(num_updated_cache_vectors, seq_length)[None, None]
|
429 |
+
key = cached_key.value + jnp.matmul(one_token_key, indices)
|
430 |
+
value = cached_value.value + jnp.matmul(one_token_value, indices)
|
431 |
+
else:
|
432 |
+
one_hot_indices = jax.nn.one_hot(cur_index, seq_length, dtype=key.dtype)
|
433 |
+
key = cached_key.value + one_token_key * one_hot_indices
|
434 |
+
value = cached_value.value + one_token_value * one_hot_indices
|
435 |
+
|
436 |
+
cached_key.value = key
|
437 |
+
cached_value.value = value
|
438 |
+
cache_index.value = cache_index.value + num_updated_cache_vectors
|
439 |
+
|
440 |
+
# Move the keys and values back to their original shapes.
|
441 |
+
key = jnp.moveaxis(key, -1, -3)
|
442 |
+
value = jnp.moveaxis(value, -1, -3)
|
443 |
+
|
444 |
+
# causal mask for cached decoder self-attention: our single query position should only
|
445 |
+
# attend to those key positions that have already been generated and cached, not the
|
446 |
+
# remaining zero elements.
|
447 |
+
pad_mask = jnp.broadcast_to(
|
448 |
+
jnp.arange(seq_length) < cur_index + num_updated_cache_vectors,
|
449 |
+
(batch_size,) + (1, num_updated_cache_vectors, seq_length),
|
450 |
+
)
|
451 |
+
attention_mask = combine_masks(pad_mask, attention_mask)
|
452 |
+
|
453 |
+
return key, value, attention_mask
|
454 |
+
|
455 |
+
|
456 |
+
class FlaxWhisperEncoderLayer(nn.Module):
|
457 |
+
config: WhisperConfig
|
458 |
+
dtype: jnp.dtype = jnp.float32
|
459 |
+
params_dtype: jnp.dtype = jnp.float32
|
460 |
+
use_scan: bool = False
|
461 |
+
|
462 |
+
def setup(self) -> None:
|
463 |
+
self.embed_dim = self.config.d_model
|
464 |
+
self.self_attn = FlaxWhisperAttention(
|
465 |
+
config=self.config,
|
466 |
+
embed_dim=self.embed_dim,
|
467 |
+
num_heads=self.config.encoder_attention_heads,
|
468 |
+
dropout=self.config.attention_dropout,
|
469 |
+
dtype=self.dtype,
|
470 |
+
params_dtype=self.params_dtype,
|
471 |
+
)
|
472 |
+
self.self_attn_layer_norm = LayerNorm(dtype=self.dtype, epsilon=1e-05, params_dtype=self.params_dtype)
|
473 |
+
self.dropout_layer = nn.Dropout(rate=self.config.dropout)
|
474 |
+
self.activation_fn = ACT2FN[self.config.activation_function]
|
475 |
+
self.activation_dropout_layer = nn.Dropout(rate=self.config.activation_dropout)
|
476 |
+
self.fc1 = DenseGeneral(
|
477 |
+
self.config.encoder_ffn_dim,
|
478 |
+
dtype=self.dtype,
|
479 |
+
params_dtype=self.params_dtype,
|
480 |
+
kernel_axes=("embed", "mlp"),
|
481 |
+
)
|
482 |
+
self.fc2 = DenseGeneral(
|
483 |
+
self.embed_dim,
|
484 |
+
dtype=self.dtype,
|
485 |
+
params_dtype=self.params_dtype,
|
486 |
+
kernel_axes=("mlp", "embed"),
|
487 |
+
)
|
488 |
+
self.final_layer_norm = LayerNorm(dtype=self.dtype, epsilon=1e-05, params_dtype=self.params_dtype)
|
489 |
+
|
490 |
+
def __call__(
|
491 |
+
self,
|
492 |
+
hidden_states: jnp.ndarray,
|
493 |
+
attention_mask: jnp.ndarray,
|
494 |
+
output_attentions: bool = True,
|
495 |
+
deterministic: bool = True,
|
496 |
+
all_hidden_states=None, # only used when `use_scan=True` -> we have to fetch the hidden states from within the layer
|
497 |
+
) -> Tuple[jnp.ndarray]:
|
498 |
+
if self.use_scan:
|
499 |
+
hidden_states = hidden_states[0]
|
500 |
+
|
501 |
+
hidden_states = with_sharding_constraint(hidden_states, ("batch", "length", "embed"))
|
502 |
+
|
503 |
+
residual = hidden_states
|
504 |
+
|
505 |
+
layernorm_output = self.self_attn_layer_norm(hidden_states)
|
506 |
+
layernorm_output = with_sharding_constraint(layernorm_output, ("batch", "length", "embed"))
|
507 |
+
|
508 |
+
attn_output, attn_weights = self.self_attn(hidden_states=layernorm_output, attention_mask=attention_mask)
|
509 |
+
attn_output = self.dropout_layer(attn_output, deterministic=deterministic)
|
510 |
+
attn_output = residual + attn_output
|
511 |
+
attn_output = with_sharding_constraint(attn_output, ("batch", "length", "embed"))
|
512 |
+
|
513 |
+
residual = attn_output
|
514 |
+
|
515 |
+
post_layer_norm = self.final_layer_norm(attn_output)
|
516 |
+
post_layer_norm = with_sharding_constraint(post_layer_norm, ("batch", "length", "embed"))
|
517 |
+
|
518 |
+
fc1_output = self.activation_fn(self.fc1(post_layer_norm))
|
519 |
+
fc1_output = self.activation_dropout_layer(fc1_output, deterministic=deterministic)
|
520 |
+
fc1_output = with_sharding_constraint(fc1_output, ("batch", "length", "mlp"))
|
521 |
+
|
522 |
+
hidden_states = self.fc2(fc1_output)
|
523 |
+
hidden_states = self.dropout_layer(hidden_states, deterministic=deterministic)
|
524 |
+
hidden_states = residual + hidden_states
|
525 |
+
hidden_states = with_sharding_constraint(hidden_states, ("batch", "length", "embed"))
|
526 |
+
|
527 |
+
outputs = (hidden_states,)
|
528 |
+
|
529 |
+
if output_attentions:
|
530 |
+
outputs += (attn_weights,)
|
531 |
+
|
532 |
+
if self.use_scan:
|
533 |
+
if all_hidden_states is not None:
|
534 |
+
all_hidden_states = all_hidden_states + (hidden_states,)
|
535 |
+
outputs = (
|
536 |
+
outputs,
|
537 |
+
all_hidden_states,
|
538 |
+
)
|
539 |
+
|
540 |
+
return outputs
|
541 |
+
|
542 |
+
|
543 |
+
class FlaxWhisperEncoderLayerCollection(nn.Module):
|
544 |
+
config: WhisperConfig
|
545 |
+
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
|
546 |
+
params_dtype: jnp.dtype = jnp.float32
|
547 |
+
use_scan: bool = False
|
548 |
+
gradient_checkpointing: bool = False
|
549 |
+
|
550 |
+
@nn.compact
|
551 |
+
def __call__(
|
552 |
+
self,
|
553 |
+
hidden_states,
|
554 |
+
attention_mask,
|
555 |
+
deterministic: bool = True,
|
556 |
+
output_attentions: bool = False,
|
557 |
+
output_hidden_states: bool = False,
|
558 |
+
return_dict: bool = True,
|
559 |
+
):
|
560 |
+
all_attentions = () if output_attentions else None
|
561 |
+
all_hidden_states = () if output_hidden_states else None
|
562 |
+
|
563 |
+
FlaxWhisperEncoderCheckpointLayer = (
|
564 |
+
remat(
|
565 |
+
FlaxWhisperEncoderLayer,
|
566 |
+
static_argnums=(2, 3),
|
567 |
+
prevent_cse=not self.use_scan,
|
568 |
+
)
|
569 |
+
if self.gradient_checkpointing
|
570 |
+
else FlaxWhisperEncoderLayer
|
571 |
+
)
|
572 |
+
|
573 |
+
if self.use_scan:
|
574 |
+
if output_attentions:
|
575 |
+
raise ValueError("Cannot use `scan` with `output_attentions` set to True")
|
576 |
+
|
577 |
+
# nicest behaviour for scan is to let the compiler figure out the correct shapes for the hidden states
|
578 |
+
# so we'll just pass an empty tuple as the carry initializer and hold on to the first hidden states for later
|
579 |
+
input_hidden_states = hidden_states
|
580 |
+
hidden_states = (hidden_states,)
|
581 |
+
|
582 |
+
hidden_states, all_hidden_states = scan_with_axes(
|
583 |
+
FlaxWhisperEncoderCheckpointLayer,
|
584 |
+
variable_axes={"params": 0, "cache": 0},
|
585 |
+
split_rngs={"params": True, "dropout": True},
|
586 |
+
in_axes=(
|
587 |
+
nn.broadcast,
|
588 |
+
nn.broadcast,
|
589 |
+
nn.broadcast,
|
590 |
+
nn.broadcast,
|
591 |
+
),
|
592 |
+
variable_carry="all_hidden_states",
|
593 |
+
length=self.config.encoder_layers,
|
594 |
+
)(
|
595 |
+
self.config,
|
596 |
+
dtype=self.dtype,
|
597 |
+
params_dtype=self.params_dtype,
|
598 |
+
use_scan=True,
|
599 |
+
name="FlaxEncoderScanLayers",
|
600 |
+
)(
|
601 |
+
hidden_states,
|
602 |
+
attention_mask,
|
603 |
+
output_attentions,
|
604 |
+
deterministic,
|
605 |
+
all_hidden_states, # tuple intializer (or None if not using output_hidden_states)
|
606 |
+
)
|
607 |
+
|
608 |
+
# remove the scan dimension
|
609 |
+
hidden_states = hidden_states[0]
|
610 |
+
|
611 |
+
if output_hidden_states:
|
612 |
+
# if we're using scan we'll surely be training -> return hidden states as a tensor rather than tuple
|
613 |
+
all_hidden_states = jnp.vstack([input_hidden_states[None, ...], all_hidden_states[0]])
|
614 |
+
|
615 |
+
else:
|
616 |
+
for layer_idx in range(self.config.encoder_layers):
|
617 |
+
if output_hidden_states:
|
618 |
+
all_hidden_states = all_hidden_states + (hidden_states,)
|
619 |
+
# add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
|
620 |
+
dropout_probability = random.uniform(0, 1)
|
621 |
+
if not deterministic and (dropout_probability < self.config.encoder_layerdrop): # skip the layer
|
622 |
+
layer_outputs = (None, None)
|
623 |
+
else:
|
624 |
+
layer_outputs = FlaxWhisperEncoderCheckpointLayer(
|
625 |
+
self.config,
|
626 |
+
dtype=self.dtype,
|
627 |
+
params_dtype=self.params_dtype,
|
628 |
+
name=str(layer_idx),
|
629 |
+
)(
|
630 |
+
hidden_states,
|
631 |
+
attention_mask,
|
632 |
+
output_attentions,
|
633 |
+
deterministic,
|
634 |
+
)
|
635 |
+
hidden_states = layer_outputs[0]
|
636 |
+
if output_attentions:
|
637 |
+
all_attentions = all_attentions + (layer_outputs[1],)
|
638 |
+
|
639 |
+
if output_hidden_states:
|
640 |
+
all_hidden_states += (hidden_states,)
|
641 |
+
|
642 |
+
outputs = (hidden_states, all_hidden_states, all_attentions)
|
643 |
+
|
644 |
+
if not return_dict:
|
645 |
+
return tuple(v for v in outputs if v is not None)
|
646 |
+
|
647 |
+
return FlaxBaseModelOutput(
|
648 |
+
last_hidden_state=hidden_states,
|
649 |
+
hidden_states=all_hidden_states,
|
650 |
+
attentions=all_attentions,
|
651 |
+
)
|
652 |
+
|
653 |
+
|
654 |
+
class FlaxWhisperDecoderLayer(nn.Module):
|
655 |
+
config: WhisperConfig
|
656 |
+
dtype: jnp.dtype = jnp.float32
|
657 |
+
params_dtype: jnp.dtype = jnp.float32
|
658 |
+
use_scan: bool = False
|
659 |
+
|
660 |
+
def setup(self) -> None:
|
661 |
+
self.embed_dim = self.config.d_model
|
662 |
+
self.self_attn = FlaxWhisperAttention(
|
663 |
+
config=self.config,
|
664 |
+
embed_dim=self.embed_dim,
|
665 |
+
num_heads=self.config.decoder_attention_heads,
|
666 |
+
dropout=self.config.attention_dropout,
|
667 |
+
causal=True,
|
668 |
+
dtype=self.dtype,
|
669 |
+
params_dtype=self.params_dtype,
|
670 |
+
)
|
671 |
+
self.dropout_layer = nn.Dropout(rate=self.config.dropout)
|
672 |
+
self.activation_fn = ACT2FN[self.config.activation_function]
|
673 |
+
self.activation_dropout_layer = nn.Dropout(rate=self.config.activation_dropout)
|
674 |
+
|
675 |
+
self.self_attn_layer_norm = LayerNorm(dtype=self.dtype, epsilon=1e-05, params_dtype=self.params_dtype)
|
676 |
+
self.encoder_attn = FlaxWhisperAttention(
|
677 |
+
config=self.config,
|
678 |
+
embed_dim=self.embed_dim,
|
679 |
+
num_heads=self.config.decoder_attention_heads,
|
680 |
+
dropout=self.config.attention_dropout,
|
681 |
+
dtype=self.dtype,
|
682 |
+
params_dtype=self.params_dtype,
|
683 |
+
)
|
684 |
+
self.encoder_attn_layer_norm = LayerNorm(dtype=self.dtype, epsilon=1e-05, params_dtype=self.params_dtype)
|
685 |
+
self.fc1 = DenseGeneral(
|
686 |
+
self.config.decoder_ffn_dim,
|
687 |
+
dtype=self.dtype,
|
688 |
+
params_dtype=self.params_dtype,
|
689 |
+
kernel_axes=("embed", "mlp"),
|
690 |
+
)
|
691 |
+
self.fc2 = DenseGeneral(
|
692 |
+
self.embed_dim,
|
693 |
+
dtype=self.dtype,
|
694 |
+
params_dtype=self.params_dtype,
|
695 |
+
kernel_axes=("mlp", "embed"),
|
696 |
+
)
|
697 |
+
self.final_layer_norm = LayerNorm(dtype=self.dtype, epsilon=1e-05, params_dtype=self.params_dtype)
|
698 |
+
|
699 |
+
def __call__(
|
700 |
+
self,
|
701 |
+
hidden_states: jnp.ndarray,
|
702 |
+
attention_mask: jnp.ndarray,
|
703 |
+
encoder_hidden_states: Optional[jnp.ndarray] = None,
|
704 |
+
encoder_attention_mask: Optional[jnp.ndarray] = None,
|
705 |
+
init_cache: bool = False,
|
706 |
+
output_attentions: bool = True,
|
707 |
+
deterministic: bool = True,
|
708 |
+
all_hidden_states=None, # only used when `use_scan=True` -> we have to fetch the hidden states from within the layer
|
709 |
+
) -> Tuple[jnp.ndarray]:
|
710 |
+
if self.use_scan:
|
711 |
+
hidden_states = hidden_states[0]
|
712 |
+
|
713 |
+
hidden_states = with_sharding_constraint(hidden_states, ("batch", "length", "embed"))
|
714 |
+
|
715 |
+
residual = hidden_states
|
716 |
+
|
717 |
+
layer_norm_output = self.self_attn_layer_norm(hidden_states)
|
718 |
+
layer_norm_output = with_sharding_constraint(layer_norm_output, ("batch", "length", "embed"))
|
719 |
+
|
720 |
+
# Self Attention
|
721 |
+
self_attn_output, self_attn_weights = self.self_attn(
|
722 |
+
hidden_states=layer_norm_output,
|
723 |
+
attention_mask=attention_mask,
|
724 |
+
init_cache=init_cache,
|
725 |
+
)
|
726 |
+
self_attn_output = self.dropout_layer(self_attn_output, deterministic=deterministic)
|
727 |
+
self_attn_output = residual + self_attn_output
|
728 |
+
self_attn_output = with_sharding_constraint(self_attn_output, ("batch", "length", "embed"))
|
729 |
+
|
730 |
+
# Cross-Attention Block
|
731 |
+
cross_attn_weights = None
|
732 |
+
if encoder_hidden_states is not None:
|
733 |
+
residual = self_attn_output
|
734 |
+
|
735 |
+
encoder_layer_norm_output = self.encoder_attn_layer_norm(self_attn_output)
|
736 |
+
encoder_layer_norm_output = with_sharding_constraint(
|
737 |
+
encoder_layer_norm_output, ("batch", "length", "embed")
|
738 |
+
)
|
739 |
+
|
740 |
+
cross_attn_output, cross_attn_weights = self.encoder_attn(
|
741 |
+
hidden_states=encoder_layer_norm_output,
|
742 |
+
key_value_states=encoder_hidden_states,
|
743 |
+
attention_mask=encoder_attention_mask,
|
744 |
+
)
|
745 |
+
cross_attn_output = self.dropout_layer(cross_attn_output, deterministic=deterministic)
|
746 |
+
cross_attn_output = residual + cross_attn_output
|
747 |
+
cross_attn_output = with_sharding_constraint(cross_attn_output, ("batch", "length", "embed"))
|
748 |
+
|
749 |
+
# Fully Connected
|
750 |
+
residual = cross_attn_output
|
751 |
+
|
752 |
+
post_layer_norm = self.final_layer_norm(cross_attn_output)
|
753 |
+
post_layer_norm = with_sharding_constraint(post_layer_norm, ("batch", "length", "embed"))
|
754 |
+
|
755 |
+
fc1_output = self.activation_fn(self.fc1(post_layer_norm))
|
756 |
+
fc1_output = self.activation_dropout_layer(fc1_output, deterministic=deterministic)
|
757 |
+
fc1_output = with_sharding_constraint(fc1_output, ("batch", "length", "mlp"))
|
758 |
+
|
759 |
+
hidden_states = self.fc2(fc1_output)
|
760 |
+
hidden_states = self.dropout_layer(hidden_states, deterministic=deterministic)
|
761 |
+
hidden_states = residual + hidden_states
|
762 |
+
hidden_states = with_sharding_constraint(hidden_states, ("batch", "length", "embed"))
|
763 |
+
|
764 |
+
outputs = (hidden_states,)
|
765 |
+
|
766 |
+
if output_attentions:
|
767 |
+
outputs += (self_attn_weights, cross_attn_weights)
|
768 |
+
|
769 |
+
if self.use_scan:
|
770 |
+
if all_hidden_states is not None:
|
771 |
+
all_hidden_states = all_hidden_states + (hidden_states,)
|
772 |
+
outputs = (
|
773 |
+
outputs,
|
774 |
+
all_hidden_states,
|
775 |
+
)
|
776 |
+
|
777 |
+
return outputs
|
778 |
+
|
779 |
+
|
780 |
+
class FlaxWhisperDecoderLayerCollection(nn.Module):
|
781 |
+
config: WhisperConfig
|
782 |
+
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
|
783 |
+
params_dtype: jnp.dtype = jnp.float32
|
784 |
+
use_scan: bool = False
|
785 |
+
gradient_checkpointing: bool = False
|
786 |
+
|
787 |
+
@nn.compact
|
788 |
+
def __call__(
|
789 |
+
self,
|
790 |
+
hidden_states,
|
791 |
+
attention_mask,
|
792 |
+
encoder_hidden_states: Optional[jnp.ndarray] = None,
|
793 |
+
encoder_attention_mask: Optional[jnp.ndarray] = None,
|
794 |
+
deterministic: bool = True,
|
795 |
+
init_cache: bool = False,
|
796 |
+
output_attentions: bool = False,
|
797 |
+
output_hidden_states: bool = False,
|
798 |
+
return_dict: bool = True,
|
799 |
+
):
|
800 |
+
# decoder layers
|
801 |
+
all_hidden_states = () if output_hidden_states else None
|
802 |
+
all_self_attns = () if output_attentions else None
|
803 |
+
all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None
|
804 |
+
|
805 |
+
FlaxWhisperDecoderCheckpointLayer = (
|
806 |
+
remat(
|
807 |
+
FlaxWhisperDecoderLayer,
|
808 |
+
static_argnums=(4, 5, 6),
|
809 |
+
prevent_cse=not self.use_scan,
|
810 |
+
)
|
811 |
+
if self.gradient_checkpointing
|
812 |
+
else FlaxWhisperDecoderLayer
|
813 |
+
)
|
814 |
+
|
815 |
+
if self.use_scan:
|
816 |
+
if output_attentions:
|
817 |
+
raise ValueError("Cannot use `scan` with `output_attentions` set to True")
|
818 |
+
|
819 |
+
input_hidden_states = hidden_states
|
820 |
+
hidden_states = (hidden_states,)
|
821 |
+
|
822 |
+
hidden_states, all_hidden_states = scan_with_axes(
|
823 |
+
FlaxWhisperDecoderCheckpointLayer,
|
824 |
+
variable_axes={"params": 0, "cache": 0},
|
825 |
+
split_rngs={"params": True, "dropout": True},
|
826 |
+
in_axes=(
|
827 |
+
nn.broadcast,
|
828 |
+
nn.broadcast,
|
829 |
+
nn.broadcast,
|
830 |
+
nn.broadcast,
|
831 |
+
nn.broadcast,
|
832 |
+
nn.broadcast,
|
833 |
+
nn.broadcast,
|
834 |
+
),
|
835 |
+
variable_carry="all_hidden_states",
|
836 |
+
length=self.config.decoder_layers,
|
837 |
+
)(
|
838 |
+
self.config,
|
839 |
+
dtype=self.dtype,
|
840 |
+
params_dtype=self.params_dtype,
|
841 |
+
use_scan=True,
|
842 |
+
name="FlaxDecoderScanLayers",
|
843 |
+
)(
|
844 |
+
hidden_states,
|
845 |
+
attention_mask,
|
846 |
+
encoder_hidden_states,
|
847 |
+
encoder_attention_mask,
|
848 |
+
init_cache,
|
849 |
+
output_attentions,
|
850 |
+
deterministic,
|
851 |
+
all_hidden_states,
|
852 |
+
)
|
853 |
+
hidden_states = hidden_states[0]
|
854 |
+
|
855 |
+
if output_hidden_states:
|
856 |
+
# if we're using scan we'll surely be training -> return hidden states as a tensor rather than tuple
|
857 |
+
all_hidden_states = jnp.vstack([input_hidden_states[None, ...], all_hidden_states[0]])
|
858 |
+
|
859 |
+
else:
|
860 |
+
for layer_idx in range(self.config.decoder_layers):
|
861 |
+
if output_hidden_states:
|
862 |
+
all_hidden_states += (hidden_states,)
|
863 |
+
# add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
|
864 |
+
dropout_probability = random.uniform(0, 1)
|
865 |
+
if not deterministic and (dropout_probability < self.config.decoder_layerdrop):
|
866 |
+
layer_outputs = (None, None, None)
|
867 |
+
else:
|
868 |
+
layer_outputs = FlaxWhisperDecoderCheckpointLayer(
|
869 |
+
self.config,
|
870 |
+
dtype=self.dtype,
|
871 |
+
params_dtype=self.params_dtype,
|
872 |
+
name=str(layer_idx),
|
873 |
+
)(
|
874 |
+
hidden_states,
|
875 |
+
attention_mask,
|
876 |
+
encoder_hidden_states,
|
877 |
+
encoder_attention_mask,
|
878 |
+
init_cache,
|
879 |
+
output_attentions,
|
880 |
+
deterministic,
|
881 |
+
)
|
882 |
+
|
883 |
+
hidden_states = layer_outputs[0]
|
884 |
+
if output_attentions:
|
885 |
+
all_self_attns += (layer_outputs[1],)
|
886 |
+
|
887 |
+
if encoder_hidden_states is not None:
|
888 |
+
all_cross_attentions += (layer_outputs[2],)
|
889 |
+
|
890 |
+
# add hidden states from the last decoder layer
|
891 |
+
if output_hidden_states:
|
892 |
+
all_hidden_states += (hidden_states,)
|
893 |
+
|
894 |
+
outputs = [
|
895 |
+
hidden_states,
|
896 |
+
all_hidden_states,
|
897 |
+
all_self_attns,
|
898 |
+
all_cross_attentions,
|
899 |
+
]
|
900 |
+
|
901 |
+
if not return_dict:
|
902 |
+
return tuple(v for v in outputs if v is not None)
|
903 |
+
|
904 |
+
return FlaxBaseModelOutputWithPastAndCrossAttentions(
|
905 |
+
last_hidden_state=hidden_states,
|
906 |
+
hidden_states=all_hidden_states,
|
907 |
+
attentions=all_self_attns,
|
908 |
+
cross_attentions=all_cross_attentions,
|
909 |
+
)
|
910 |
+
|
911 |
+
|
912 |
+
class FlaxWhisperEncoder(nn.Module):
|
913 |
+
config: WhisperConfig
|
914 |
+
dtype: jnp.dtype = jnp.float32
|
915 |
+
params_dtype: jnp.dtype = jnp.float32
|
916 |
+
use_scan: bool = False
|
917 |
+
gradient_checkpointing: bool = False
|
918 |
+
|
919 |
+
def setup(self) -> None:
|
920 |
+
self.conv1 = Conv(
|
921 |
+
self.config.d_model,
|
922 |
+
kernel_size=(3,),
|
923 |
+
padding=1,
|
924 |
+
dtype=self.dtype,
|
925 |
+
params_dtype=self.params_dtype,
|
926 |
+
kernel_axes=("channels", "num_mel", "embed"),
|
927 |
+
)
|
928 |
+
self.conv2 = Conv(
|
929 |
+
self.config.d_model,
|
930 |
+
kernel_size=(3,),
|
931 |
+
strides=2,
|
932 |
+
padding=1,
|
933 |
+
dtype=self.dtype,
|
934 |
+
params_dtype=self.params_dtype,
|
935 |
+
kernel_axes=("channels", "embed", "num_mel"),
|
936 |
+
)
|
937 |
+
|
938 |
+
self.dropout_layer = nn.Dropout(rate=self.config.dropout)
|
939 |
+
|
940 |
+
self.layers = FlaxWhisperEncoderLayerCollection(
|
941 |
+
self.config,
|
942 |
+
dtype=self.dtype,
|
943 |
+
params_dtype=self.params_dtype,
|
944 |
+
use_scan=self.use_scan,
|
945 |
+
gradient_checkpointing=self.gradient_checkpointing,
|
946 |
+
)
|
947 |
+
self.embed_positions = Embed(
|
948 |
+
self.config.max_source_positions,
|
949 |
+
self.config.d_model,
|
950 |
+
dtype=self.dtype,
|
951 |
+
params_dtype=self.params_dtype,
|
952 |
+
)
|
953 |
+
|
954 |
+
self.layer_norm = LayerNorm(dtype=self.dtype, epsilon=1e-05, params_dtype=self.params_dtype)
|
955 |
+
|
956 |
+
def __call__(
|
957 |
+
self,
|
958 |
+
input_features: jnp.ndarray,
|
959 |
+
output_attentions: bool = False,
|
960 |
+
output_hidden_states: bool = False,
|
961 |
+
return_dict: bool = True,
|
962 |
+
deterministic: bool = True,
|
963 |
+
) -> Tuple[jnp.ndarray]:
|
964 |
+
if input_features.shape[1:] != (
|
965 |
+
self.config.num_mel_bins,
|
966 |
+
self.config.max_source_positions * 2,
|
967 |
+
):
|
968 |
+
raise ValueError(
|
969 |
+
"input_features.shape[1:], must be equal to (self.config.num_mel_bins,"
|
970 |
+
" self.config.max_source_positions * 2) (got"
|
971 |
+
f" {input_features.shape[1:]}, but should be"
|
972 |
+
f" ({self.config.num_mel_bins},"
|
973 |
+
f" {self.config.max_source_positions * 2}))"
|
974 |
+
)
|
975 |
+
|
976 |
+
input_features = input_features.transpose(0, 2, 1)
|
977 |
+
hidden_states = jax.nn.gelu(self.conv1(input_features), approximate=False)
|
978 |
+
hidden_states = with_sharding_constraint(hidden_states, ("batch", "embed", "num_mel"))
|
979 |
+
hidden_states = jax.nn.gelu(self.conv2(hidden_states), approximate=False)
|
980 |
+
hidden_states = with_sharding_constraint(hidden_states, ("batch", "length", "embed"))
|
981 |
+
|
982 |
+
embed_positions = self.embed_positions(jnp.arange(self.config.max_source_positions))
|
983 |
+
# sinusoidal positional embeddings should not be trained
|
984 |
+
embed_positions = jax.lax.stop_gradient(embed_positions)
|
985 |
+
hidden_states = hidden_states + embed_positions
|
986 |
+
|
987 |
+
hidden_states = self.dropout_layer(hidden_states, deterministic=deterministic)
|
988 |
+
|
989 |
+
outputs = self.layers(
|
990 |
+
hidden_states,
|
991 |
+
attention_mask=None,
|
992 |
+
deterministic=deterministic,
|
993 |
+
output_attentions=output_attentions,
|
994 |
+
output_hidden_states=output_hidden_states,
|
995 |
+
return_dict=return_dict,
|
996 |
+
)
|
997 |
+
|
998 |
+
last_hidden_states = outputs[0]
|
999 |
+
last_hidden_states = self.layer_norm(last_hidden_states)
|
1000 |
+
|
1001 |
+
# update the last element in `hidden_states` after applying `layernorm` above
|
1002 |
+
hidden_states = None
|
1003 |
+
if output_hidden_states:
|
1004 |
+
hidden_states = outputs[1]
|
1005 |
+
if self.use_scan:
|
1006 |
+
hidden_states = jnp.vstack([hidden_states[:-1], last_hidden_states[None, ...]])
|
1007 |
+
else:
|
1008 |
+
hidden_states = hidden_states[:-1] + (last_hidden_states,)
|
1009 |
+
|
1010 |
+
if not return_dict:
|
1011 |
+
outputs = (last_hidden_states, hidden_states) + (outputs[2:] if output_hidden_states else outputs[1:])
|
1012 |
+
return tuple(v for v in outputs if v is not None)
|
1013 |
+
|
1014 |
+
return FlaxBaseModelOutput(
|
1015 |
+
last_hidden_state=last_hidden_states,
|
1016 |
+
hidden_states=hidden_states,
|
1017 |
+
attentions=outputs.attentions,
|
1018 |
+
)
|
1019 |
+
|
1020 |
+
|
1021 |
+
class FlaxWhisperDecoder(nn.Module):
|
1022 |
+
config: WhisperConfig
|
1023 |
+
dtype: jnp.dtype = jnp.float32
|
1024 |
+
params_dtype: jnp.dtype = jnp.float32
|
1025 |
+
use_scan: bool = False
|
1026 |
+
gradient_checkpointing: bool = False
|
1027 |
+
|
1028 |
+
def setup(self) -> None:
|
1029 |
+
self.embed_tokens = Embed(
|
1030 |
+
self.config.vocab_size,
|
1031 |
+
self.config.d_model,
|
1032 |
+
dtype=self.dtype,
|
1033 |
+
params_dtype=self.params_dtype,
|
1034 |
+
)
|
1035 |
+
self.embed_positions = Embed(
|
1036 |
+
self.config.max_target_positions,
|
1037 |
+
self.config.d_model,
|
1038 |
+
dtype=self.dtype,
|
1039 |
+
params_dtype=self.params_dtype,
|
1040 |
+
)
|
1041 |
+
|
1042 |
+
self.layers = FlaxWhisperDecoderLayerCollection(
|
1043 |
+
self.config,
|
1044 |
+
dtype=self.dtype,
|
1045 |
+
params_dtype=self.params_dtype,
|
1046 |
+
use_scan=self.use_scan,
|
1047 |
+
gradient_checkpointing=self.gradient_checkpointing,
|
1048 |
+
)
|
1049 |
+
|
1050 |
+
self.dropout_layer = nn.Dropout(rate=self.config.dropout)
|
1051 |
+
|
1052 |
+
self.layer_norm = LayerNorm(dtype=self.dtype, epsilon=1e-5, params_dtype=self.params_dtype)
|
1053 |
+
|
1054 |
+
def __call__(
|
1055 |
+
self,
|
1056 |
+
input_ids: jnp.ndarray,
|
1057 |
+
attention_mask: jnp.ndarray,
|
1058 |
+
position_ids: jnp.ndarray,
|
1059 |
+
encoder_hidden_states: Optional[jnp.ndarray] = None,
|
1060 |
+
init_cache: bool = False,
|
1061 |
+
output_attentions: bool = False,
|
1062 |
+
output_hidden_states: bool = False,
|
1063 |
+
return_dict: bool = True,
|
1064 |
+
deterministic: bool = True,
|
1065 |
+
) -> Tuple[jnp.ndarray]:
|
1066 |
+
input_embeds = self.embed_tokens(input_ids)
|
1067 |
+
position_embeds = self.embed_positions(position_ids)
|
1068 |
+
|
1069 |
+
hidden_states = input_embeds + position_embeds
|
1070 |
+
hidden_states = self.dropout_layer(hidden_states, deterministic=deterministic)
|
1071 |
+
|
1072 |
+
outputs = self.layers(
|
1073 |
+
hidden_states,
|
1074 |
+
attention_mask=attention_mask,
|
1075 |
+
encoder_hidden_states=encoder_hidden_states,
|
1076 |
+
deterministic=deterministic,
|
1077 |
+
init_cache=init_cache,
|
1078 |
+
output_attentions=output_attentions,
|
1079 |
+
output_hidden_states=output_hidden_states,
|
1080 |
+
return_dict=return_dict,
|
1081 |
+
)
|
1082 |
+
|
1083 |
+
last_hidden_states = outputs[0]
|
1084 |
+
last_hidden_states = self.layer_norm(last_hidden_states)
|
1085 |
+
|
1086 |
+
# update the last element in `hidden_states` after applying `layernorm` above
|
1087 |
+
hidden_states = None
|
1088 |
+
if output_hidden_states:
|
1089 |
+
hidden_states = outputs[1]
|
1090 |
+
if self.use_scan:
|
1091 |
+
hidden_states = jnp.vstack([hidden_states[:-1], last_hidden_states[None, ...]])
|
1092 |
+
else:
|
1093 |
+
hidden_states = hidden_states[:-1] + (last_hidden_states,)
|
1094 |
+
|
1095 |
+
if not return_dict:
|
1096 |
+
outputs = (last_hidden_states, hidden_states) + (outputs[2:] if output_hidden_states else outputs[1:])
|
1097 |
+
return tuple(v for v in outputs if v is not None)
|
1098 |
+
|
1099 |
+
return FlaxBaseModelOutputWithPastAndCrossAttentions(
|
1100 |
+
last_hidden_state=last_hidden_states,
|
1101 |
+
hidden_states=hidden_states,
|
1102 |
+
attentions=outputs.attentions,
|
1103 |
+
cross_attentions=outputs.cross_attentions,
|
1104 |
+
)
|
1105 |
+
|
1106 |
+
|
1107 |
+
class FlaxWhisperModule(nn.Module):
|
1108 |
+
config: WhisperConfig
|
1109 |
+
dtype: jnp.dtype = jnp.float32
|
1110 |
+
params_dtype: jnp.dtype = jnp.float32
|
1111 |
+
use_scan: bool = False
|
1112 |
+
gradient_checkpointing: bool = False
|
1113 |
+
|
1114 |
+
def setup(self) -> None:
|
1115 |
+
self.encoder = FlaxWhisperEncoder(
|
1116 |
+
self.config,
|
1117 |
+
dtype=self.dtype,
|
1118 |
+
params_dtype=self.params_dtype,
|
1119 |
+
use_scan=self.use_scan,
|
1120 |
+
gradient_checkpointing=self.gradient_checkpointing,
|
1121 |
+
)
|
1122 |
+
self.decoder = FlaxWhisperDecoder(
|
1123 |
+
self.config,
|
1124 |
+
dtype=self.dtype,
|
1125 |
+
params_dtype=self.params_dtype,
|
1126 |
+
use_scan=self.use_scan,
|
1127 |
+
gradient_checkpointing=self.gradient_checkpointing,
|
1128 |
+
)
|
1129 |
+
|
1130 |
+
def __call__(
|
1131 |
+
self,
|
1132 |
+
input_features: jnp.ndarray,
|
1133 |
+
decoder_input_ids: jnp.ndarray,
|
1134 |
+
decoder_attention_mask: jnp.ndarray,
|
1135 |
+
decoder_position_ids: jnp.ndarray,
|
1136 |
+
output_attentions: bool = False,
|
1137 |
+
output_hidden_states: bool = False,
|
1138 |
+
freeze_encoder: bool = False,
|
1139 |
+
return_dict: bool = True,
|
1140 |
+
deterministic: bool = True,
|
1141 |
+
):
|
1142 |
+
encoder_outputs = self.encoder(
|
1143 |
+
input_features,
|
1144 |
+
output_attentions=output_attentions,
|
1145 |
+
output_hidden_states=output_hidden_states,
|
1146 |
+
return_dict=return_dict,
|
1147 |
+
deterministic=deterministic,
|
1148 |
+
)
|
1149 |
+
|
1150 |
+
encoder_hidden_states = encoder_outputs[0]
|
1151 |
+
|
1152 |
+
if freeze_encoder:
|
1153 |
+
encoder_hidden_states = jax.lax.stop_gradient(encoder_hidden_states)
|
1154 |
+
|
1155 |
+
decoder_outputs = self.decoder(
|
1156 |
+
input_ids=decoder_input_ids,
|
1157 |
+
attention_mask=decoder_attention_mask,
|
1158 |
+
position_ids=decoder_position_ids,
|
1159 |
+
encoder_hidden_states=encoder_hidden_states,
|
1160 |
+
output_attentions=output_attentions,
|
1161 |
+
output_hidden_states=output_hidden_states,
|
1162 |
+
return_dict=return_dict,
|
1163 |
+
deterministic=deterministic,
|
1164 |
+
)
|
1165 |
+
|
1166 |
+
if not return_dict:
|
1167 |
+
return decoder_outputs + encoder_outputs
|
1168 |
+
|
1169 |
+
return FlaxSeq2SeqModelOutput(
|
1170 |
+
last_hidden_state=decoder_outputs.last_hidden_state,
|
1171 |
+
decoder_hidden_states=decoder_outputs.hidden_states,
|
1172 |
+
decoder_attentions=decoder_outputs.attentions,
|
1173 |
+
cross_attentions=decoder_outputs.cross_attentions,
|
1174 |
+
encoder_last_hidden_state=encoder_outputs.last_hidden_state,
|
1175 |
+
encoder_hidden_states=encoder_outputs.hidden_states,
|
1176 |
+
encoder_attentions=encoder_outputs.attentions,
|
1177 |
+
)
|
1178 |
+
|
1179 |
+
def _get_encoder_module(self):
|
1180 |
+
return self.encoder
|
1181 |
+
|
1182 |
+
def _get_decoder_module(self):
|
1183 |
+
return self.decoder
|
1184 |
+
|
1185 |
+
|
1186 |
+
class FlaxWhisperPreTrainedModel(FlaxPreTrainedModel):
|
1187 |
+
config_class = WhisperConfig
|
1188 |
+
base_model_prefix: str = "model"
|
1189 |
+
main_input_name = "input_features"
|
1190 |
+
module_class: nn.Module = None
|
1191 |
+
|
1192 |
+
def __init__(
|
1193 |
+
self,
|
1194 |
+
config: WhisperConfig,
|
1195 |
+
input_shape: Tuple[int, int, int] = None,
|
1196 |
+
seed: int = 0,
|
1197 |
+
dtype: jnp.dtype = jnp.float32,
|
1198 |
+
params_dtype: jnp.dtype = jnp.float32,
|
1199 |
+
_do_init: bool = True,
|
1200 |
+
# Can only use_scan=True in init if loading scanned weights -> need to handle use_scan=True and unrolled weights
|
1201 |
+
use_scan: bool = False,
|
1202 |
+
gradient_checkpointing: bool = False,
|
1203 |
+
**kwargs,
|
1204 |
+
):
|
1205 |
+
self.use_scan = use_scan
|
1206 |
+
self.gradient_checkpointing = gradient_checkpointing
|
1207 |
+
|
1208 |
+
module = self.module_class(
|
1209 |
+
config=config,
|
1210 |
+
dtype=dtype,
|
1211 |
+
params_dtype=params_dtype,
|
1212 |
+
use_scan=use_scan,
|
1213 |
+
gradient_checkpointing=gradient_checkpointing,
|
1214 |
+
**kwargs,
|
1215 |
+
)
|
1216 |
+
|
1217 |
+
if input_shape is None:
|
1218 |
+
input_shape = (1, 80, 2 * config.max_source_positions)
|
1219 |
+
|
1220 |
+
super().__init__(
|
1221 |
+
config,
|
1222 |
+
module,
|
1223 |
+
input_shape=input_shape,
|
1224 |
+
seed=seed,
|
1225 |
+
dtype=dtype,
|
1226 |
+
_do_init=_do_init,
|
1227 |
+
)
|
1228 |
+
|
1229 |
+
def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict:
|
1230 |
+
# init input tensors
|
1231 |
+
input_features = jnp.zeros(input_shape, dtype="f4")
|
1232 |
+
input_features = input_features.at[(..., -1)].set(self.config.eos_token_id)
|
1233 |
+
|
1234 |
+
decoder_input_ids = jnp.zeros((input_shape[0], 1), dtype="i4")
|
1235 |
+
decoder_attention_mask = jnp.ones_like(decoder_input_ids)
|
1236 |
+
|
1237 |
+
batch_size, sequence_length = decoder_input_ids.shape
|
1238 |
+
decoder_position_ids = jnp.broadcast_to(jnp.arange(sequence_length)[None, :], (batch_size, sequence_length))
|
1239 |
+
|
1240 |
+
params_rng, dropout_rng = jax.random.split(rng)
|
1241 |
+
rngs = {"params": params_rng, "dropout": dropout_rng}
|
1242 |
+
|
1243 |
+
random_params = self.module.init(
|
1244 |
+
rngs,
|
1245 |
+
input_features=input_features,
|
1246 |
+
decoder_input_ids=decoder_input_ids,
|
1247 |
+
decoder_attention_mask=decoder_attention_mask,
|
1248 |
+
decoder_position_ids=decoder_position_ids,
|
1249 |
+
)["params"]
|
1250 |
+
|
1251 |
+
if params is not None:
|
1252 |
+
random_params = flatten_dict(unfreeze(random_params))
|
1253 |
+
params = flatten_dict(unfreeze(params))
|
1254 |
+
for missing_key in self._missing_keys:
|
1255 |
+
params[missing_key] = random_params[missing_key]
|
1256 |
+
self._missing_keys = set()
|
1257 |
+
return freeze(unflatten_dict(params))
|
1258 |
+
else:
|
1259 |
+
return random_params
|
1260 |
+
|
1261 |
+
def enable_gradient_checkpointing(self):
|
1262 |
+
self.gradient_checkpointing = True
|
1263 |
+
self._module = self.module_class(
|
1264 |
+
config=self.config,
|
1265 |
+
dtype=self.dtype,
|
1266 |
+
use_scan=self.use_scan,
|
1267 |
+
gradient_checkpointing=self.gradient_checkpointing,
|
1268 |
+
)
|
1269 |
+
|
1270 |
+
def enable_scan(self):
|
1271 |
+
self.use_scan = True
|
1272 |
+
self._module = self.module_class(
|
1273 |
+
config=self.config,
|
1274 |
+
dtype=self.dtype,
|
1275 |
+
use_scan=self.use_scan,
|
1276 |
+
gradient_checkpointing=self.gradient_checkpointing,
|
1277 |
+
)
|
1278 |
+
init_fn = partial(self.init_weights, input_shape=self.input_shape)
|
1279 |
+
params_shape_tree = jax.eval_shape(init_fn, self.key)
|
1280 |
+
|
1281 |
+
# get the shape of the parameters
|
1282 |
+
self._params_shape_tree = params_shape_tree
|
1283 |
+
|
1284 |
+
# save required_params as set
|
1285 |
+
self._required_params = set(flatten_dict(unfreeze(params_shape_tree)).keys())
|
1286 |
+
|
1287 |
+
# initialize the parameters
|
1288 |
+
if self._is_initialized:
|
1289 |
+
self.params = self.convert_unroll_to_scan(self.params)
|
1290 |
+
|
1291 |
+
def disable_scan(self):
|
1292 |
+
self.use_scan = False
|
1293 |
+
self._module = self.module_class(
|
1294 |
+
config=self.config,
|
1295 |
+
dtype=self.dtype,
|
1296 |
+
use_scan=self.use_scan,
|
1297 |
+
gradient_checkpointing=self.gradient_checkpointing,
|
1298 |
+
)
|
1299 |
+
init_fn = partial(self.init_weights, input_shape=self.input_shape)
|
1300 |
+
params_shape_tree = jax.eval_shape(init_fn, self.key)
|
1301 |
+
|
1302 |
+
# get the shape of the parameters
|
1303 |
+
self._params_shape_tree = params_shape_tree
|
1304 |
+
|
1305 |
+
# save required_params as set
|
1306 |
+
self._required_params = set(flatten_dict(unfreeze(params_shape_tree)).keys())
|
1307 |
+
|
1308 |
+
# initialize the parameters
|
1309 |
+
if self._is_initialized:
|
1310 |
+
self.params = self.convert_scan_to_unroll(self.params)
|
1311 |
+
|
1312 |
+
def convert_unroll_to_scan(self, params: Union[Dict, FrozenDict]):
|
1313 |
+
r"""
|
1314 |
+
Convert a `PyTree` of unrolled model parameters to a scanned block of model parameters. This method can be used
|
1315 |
+
to explicitly convert the model parameters to scanned format. This returns a new `params` tree and does not
|
1316 |
+
convert the `params` in place.
|
1317 |
+
|
1318 |
+
To illustrate the workings of this method, take the Flax BERT model. The unrolled structure for the query
|
1319 |
+
projection params is as follows:
|
1320 |
+
('bert', 'encoder', 'layer', '0', 'self_attn', 'q_proj') ('bert', 'encoder', 'layer', '1', 'self_attn',
|
1321 |
+
'q_proj') ... ('bert', 'encoder', 'layer', '23', 'self_attn', 'q_proj')
|
1322 |
+
This method takes each of the `q_proj` matrices for layers (0, ..., 23) and stacks them into a single 'super'
|
1323 |
+
matrix, giving a *single* block of weights for all 24 layers compatible with the scanned model:
|
1324 |
+
('bert', 'encoder', 'layer', 'ScanLayers', 'self_attn', 'q_proj')
|
1325 |
+
|
1326 |
+
When enabling scan with _do_init=True (default), this method will be called automatically under the hood. With
|
1327 |
+
_do_init=False, it will have to be called explicitly (see example below).
|
1328 |
+
|
1329 |
+
Arguments:
|
1330 |
+
params (`Union[Dict, FrozenDict]`):
|
1331 |
+
A `PyTree` of model parameters.
|
1332 |
+
|
1333 |
+
Examples:
|
1334 |
+
|
1335 |
+
```python
|
1336 |
+
>>> from distil_whisper import FlaxWhisperForConditionalGeneration
|
1337 |
+
|
1338 |
+
>>> # Download model and configuration from huggingface.co
|
1339 |
+
>>> model, params = FlaxWhisperModel.from_pretrained("openai/whisper-tiny.en", _do_init=False)
|
1340 |
+
>>> # By default, the model params will be in unrolled format. To illustrate the use of this method,
|
1341 |
+
>>> # we'll first convert to scan format and then back to unrolled
|
1342 |
+
>>> model.enable_scan()
|
1343 |
+
>>> params = model.convert_unroll_to_scan(params)
|
1344 |
+
>>> # now convert back to unrolled
|
1345 |
+
>>> model.disable_scan()
|
1346 |
+
>>> params = model.convert_scan_to_unroll(params)
|
1347 |
+
```"""
|
1348 |
+
if isinstance(params, FrozenDict):
|
1349 |
+
params = unfreeze(params)
|
1350 |
+
|
1351 |
+
params = flatten_dict(params, sep="/")
|
1352 |
+
keys = list(params.keys())
|
1353 |
+
|
1354 |
+
for k in keys:
|
1355 |
+
# Identify all "unrolled" layers formed as part of the FlaxBertLayerCollection
|
1356 |
+
# These params contain the identifier `layer` in their key
|
1357 |
+
if "layers/0" in k:
|
1358 |
+
if "decoder" in k:
|
1359 |
+
block_prefix = "Decoder"
|
1360 |
+
num_hidden_layers = self.config.decoder_layers
|
1361 |
+
else:
|
1362 |
+
block_prefix = "Encoder"
|
1363 |
+
num_hidden_layers = self.config.encoder_layers
|
1364 |
+
|
1365 |
+
# Squash the keys for the N unrolled layers into one single key:
|
1366 |
+
# (layer/0, ..., layer/N) -> layer/FlaxScanLayers
|
1367 |
+
scan_key = k.replace("0", f"Flax{block_prefix}ScanLayers")
|
1368 |
+
stacked_params = []
|
1369 |
+
|
1370 |
+
# Iterate over the unrolled layers (1,...,N)
|
1371 |
+
for i in range(num_hidden_layers):
|
1372 |
+
# Stack the params for the N layers into one super block
|
1373 |
+
# and remove the unrolled layer params on the fly
|
1374 |
+
# -> no memory overhead for conversion!
|
1375 |
+
unrolled_layer = params.pop(k.replace("0", str(i)))
|
1376 |
+
stacked_params.append(unrolled_layer)
|
1377 |
+
|
1378 |
+
params[scan_key] = jnp.stack(stacked_params)
|
1379 |
+
|
1380 |
+
# Finally, unflatten the dict to restore the nested pytree structure
|
1381 |
+
params = unflatten_dict(params, sep="/")
|
1382 |
+
return params
|
1383 |
+
|
1384 |
+
def convert_scan_to_unroll(self, params: Union[Dict, FrozenDict]):
|
1385 |
+
r"""
|
1386 |
+
Convert a `PyTree` of scanned model parameters to an unrolled stack of model parameters. This method can be
|
1387 |
+
used to explicitly convert the model parameters to unrolled format. This returns a new `params` tree and does
|
1388 |
+
not convert the `params` in place.
|
1389 |
+
|
1390 |
+
To illustrate the workings of this method, take the Flax BERT model. The scanned structure for the query
|
1391 |
+
projection (`q_proj`) params is a single, stacked matrix of parameters over all N layers:
|
1392 |
+
('bert', 'encoder', 'layer', 'FlaxScanLayers', 'self_attn', 'q_proj')
|
1393 |
+
|
1394 |
+
This method slices each layer of the `q_proj` scanned matrix into single, standalone layers, and replaces the
|
1395 |
+
scanned matrix of parameteres on the fly:
|
1396 |
+
('bert', 'encoder', 'layer', '0', 'self_attn', 'q_proj') ('bert', 'encoder', 'layer', '1', 'self_attn',
|
1397 |
+
'q_proj') ... ('bert', 'encoder', 'layer', 'N', 'self_attn', 'q_proj')
|
1398 |
+
|
1399 |
+
When enabling scan with _do_init=True (default), this method will be called automatically under the hood. With
|
1400 |
+
_do_init=False, it will have to be called explicitly (see example below).
|
1401 |
+
|
1402 |
+
Arguments:
|
1403 |
+
params (`Union[Dict, FrozenDict]`):
|
1404 |
+
A `PyTree` of model parameters.
|
1405 |
+
|
1406 |
+
Examples:
|
1407 |
+
|
1408 |
+
```python
|
1409 |
+
>>> from distil_whisper import FlaxWhisperForConditionalGeneration
|
1410 |
+
|
1411 |
+
>>> # Download model and configuration from huggingface.co
|
1412 |
+
>>> model, params = FlaxWhisperModel.from_pretrained("openai/whisper-tiny.en", _do_init=False)
|
1413 |
+
>>> # By default, the model params will be in unrolled format. To illustrate the use of this method,
|
1414 |
+
>>> # we'll first convert to scan format and then back to unrolled
|
1415 |
+
>>> model.enable_scan()
|
1416 |
+
>>> params = model.convert_unroll_to_scan(params)
|
1417 |
+
>>> # now convert back to unrolled
|
1418 |
+
>>> model.disable_scan()
|
1419 |
+
>>> params = model.convert_scan_to_unroll(params)
|
1420 |
+
```"""
|
1421 |
+
|
1422 |
+
if isinstance(params, FrozenDict):
|
1423 |
+
params = unfreeze(params)
|
1424 |
+
|
1425 |
+
params = flatten_dict(params, sep="/")
|
1426 |
+
keys = list(params.keys())
|
1427 |
+
|
1428 |
+
for k in keys:
|
1429 |
+
# Identify all "scan" layers formed as part of the FlaxBertLayerCollection
|
1430 |
+
# These params contain the identifier `FlaxScanLayers` in their key
|
1431 |
+
if "FlaxEncoderScanLayers" in k:
|
1432 |
+
# Remove the scan layer from the PyTree of params
|
1433 |
+
scan_layer = params.pop(k)
|
1434 |
+
|
1435 |
+
# Unroll the key for the stacked scan matrix into N separate keys, indexed by layer number
|
1436 |
+
# layer/FlaxScanLayers -> (layer/0, ..., layer/N)
|
1437 |
+
for i in range(self.config.encoder_layers):
|
1438 |
+
# Unstack the params for the i-th scan layer to unrolled
|
1439 |
+
# and remove corresponding scan params on the fly
|
1440 |
+
# -> no memory overhead for conversion!
|
1441 |
+
unrolled_key = k.replace("FlaxEncoderScanLayers", str(i))
|
1442 |
+
params[unrolled_key], scan_layer = scan_layer[0], scan_layer[1:]
|
1443 |
+
|
1444 |
+
elif "FlaxDecoderScanLayers" in k:
|
1445 |
+
# Remove the scan layer from the PyTree of params
|
1446 |
+
scan_layer = params.pop(k)
|
1447 |
+
|
1448 |
+
# Unroll the key for the stacked scan matrix into N separate keys, indexed by layer number
|
1449 |
+
# layer/FlaxScanLayers -> (layer/0, ..., layer/N)
|
1450 |
+
for i in range(self.config.decoder_layers):
|
1451 |
+
# Unstack the params for the i-th scan layer to unrolled
|
1452 |
+
# and remove corresponding scan params on the fly
|
1453 |
+
# -> no memory overhead for conversion!
|
1454 |
+
unrolled_key = k.replace("FlaxDecoderScanLayers", str(i))
|
1455 |
+
params[unrolled_key], scan_layer = scan_layer[0], scan_layer[1:]
|
1456 |
+
|
1457 |
+
params = unflatten_dict(params, sep="/")
|
1458 |
+
return params
|
1459 |
+
|
1460 |
+
# Copied from transformers.models.whisper.modeling_flax_whisper.FlaxWhisperPreTrainedModel.init_cache
|
1461 |
+
def init_cache(self, batch_size, max_length, encoder_outputs):
|
1462 |
+
r"""
|
1463 |
+
Args:
|
1464 |
+
batch_size (`int`):
|
1465 |
+
batch_size used for fast auto-regressive decoding. Defines the batch size of the initialized cache.
|
1466 |
+
max_length (`int`):
|
1467 |
+
maximum possible length for auto-regressive decoding. Defines the sequence length of the initialized
|
1468 |
+
cache.
|
1469 |
+
encoder_outputs (`Union[FlaxBaseModelOutput, tuple(tuple(jnp.ndarray)]`):
|
1470 |
+
`encoder_outputs` consists of (`last_hidden_state`, *optional*: `hidden_states`, *optional*:
|
1471 |
+
`attentions`). `last_hidden_state` of shape `(batch_size, sequence_length, hidden_size)`, *optional*)
|
1472 |
+
is a sequence of hidden-states at the output of the last layer of the encoder. Used in the
|
1473 |
+
cross-attention of the decoder.
|
1474 |
+
"""
|
1475 |
+
# init input variables to retrieve cache
|
1476 |
+
decoder_input_ids = jnp.ones((batch_size, max_length), dtype="i4")
|
1477 |
+
decoder_attention_mask = jnp.ones_like(decoder_input_ids)
|
1478 |
+
decoder_position_ids = jnp.broadcast_to(
|
1479 |
+
jnp.arange(jnp.atleast_2d(decoder_input_ids).shape[-1]),
|
1480 |
+
decoder_input_ids.shape,
|
1481 |
+
)
|
1482 |
+
|
1483 |
+
def _decoder_forward(
|
1484 |
+
module,
|
1485 |
+
decoder_input_ids,
|
1486 |
+
decoder_attention_mask,
|
1487 |
+
decoder_position_ids,
|
1488 |
+
**kwargs,
|
1489 |
+
):
|
1490 |
+
decoder_module = module._get_decoder_module()
|
1491 |
+
return decoder_module(
|
1492 |
+
decoder_input_ids,
|
1493 |
+
decoder_attention_mask,
|
1494 |
+
decoder_position_ids,
|
1495 |
+
**kwargs,
|
1496 |
+
)
|
1497 |
+
|
1498 |
+
init_variables = self.module.init(
|
1499 |
+
jax.random.PRNGKey(0),
|
1500 |
+
decoder_input_ids=decoder_input_ids,
|
1501 |
+
decoder_attention_mask=decoder_attention_mask,
|
1502 |
+
decoder_position_ids=decoder_position_ids,
|
1503 |
+
encoder_hidden_states=encoder_outputs[0],
|
1504 |
+
init_cache=True,
|
1505 |
+
method=_decoder_forward, # we only need to call the decoder to init the cache
|
1506 |
+
)
|
1507 |
+
return unfreeze(init_variables["cache"])
|
1508 |
+
|
1509 |
+
@add_start_docstrings(WHISPER_ENCODE_INPUTS_DOCSTRING)
|
1510 |
+
@replace_return_docstrings(output_type=FlaxBaseModelOutput, config_class=WhisperConfig)
|
1511 |
+
def encode(
|
1512 |
+
self,
|
1513 |
+
input_features: jnp.ndarray,
|
1514 |
+
attention_mask: Optional[jnp.ndarray] = None,
|
1515 |
+
output_attentions: Optional[bool] = None,
|
1516 |
+
output_hidden_states: Optional[bool] = None,
|
1517 |
+
return_dict: Optional[bool] = None,
|
1518 |
+
train: bool = False,
|
1519 |
+
params: dict = None,
|
1520 |
+
dropout_rng: PRNGKey = None,
|
1521 |
+
**kwargs,
|
1522 |
+
):
|
1523 |
+
r"""
|
1524 |
+
Returns:
|
1525 |
+
|
1526 |
+
Example:
|
1527 |
+
|
1528 |
+
```python
|
1529 |
+
>>> from transformers import WhisperProcessor, FlaxWhisperForConditionalGeneration
|
1530 |
+
>>> from datasets import load_dataset
|
1531 |
+
|
1532 |
+
>>> processor = WhisperProcessor.from_pretrained("openai/whisper-tiny.en")
|
1533 |
+
>>> model = FlaxWhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny.en", from_pt=True)
|
1534 |
+
>>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
|
1535 |
+
>>> inputs = processor(ds[0]["audio"]["array"], return_tensors="np")
|
1536 |
+
>>> input_features = inputs.input_features
|
1537 |
+
>>> encoder_outputs = model.encode(input_features=input_features)
|
1538 |
+
```"""
|
1539 |
+
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
1540 |
+
output_hidden_states = (
|
1541 |
+
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
1542 |
+
)
|
1543 |
+
return_dict = return_dict if return_dict is not None else self.config.return_dict
|
1544 |
+
|
1545 |
+
# Handle any PRNG if needed
|
1546 |
+
rngs = {}
|
1547 |
+
if dropout_rng is not None:
|
1548 |
+
rngs["dropout"] = dropout_rng
|
1549 |
+
|
1550 |
+
def _encoder_forward(module, input_features, **kwargs):
|
1551 |
+
encode_module = module._get_encoder_module()
|
1552 |
+
return encode_module(input_features, **kwargs)
|
1553 |
+
|
1554 |
+
return self.module.apply(
|
1555 |
+
{"params": params or self.params},
|
1556 |
+
input_features=jnp.array(input_features, dtype="f4"),
|
1557 |
+
output_attentions=output_attentions,
|
1558 |
+
output_hidden_states=output_hidden_states,
|
1559 |
+
return_dict=return_dict,
|
1560 |
+
deterministic=not train,
|
1561 |
+
rngs=rngs,
|
1562 |
+
method=_encoder_forward,
|
1563 |
+
)
|
1564 |
+
|
1565 |
+
@add_start_docstrings(WHISPER_DECODE_INPUTS_DOCSTRING)
|
1566 |
+
@replace_return_docstrings(
|
1567 |
+
output_type=FlaxBaseModelOutputWithPastAndCrossAttentions,
|
1568 |
+
config_class=WhisperConfig,
|
1569 |
+
)
|
1570 |
+
def decode(
|
1571 |
+
self,
|
1572 |
+
decoder_input_ids,
|
1573 |
+
encoder_outputs,
|
1574 |
+
encoder_attention_mask: Optional[jnp.ndarray] = None,
|
1575 |
+
decoder_attention_mask: Optional[jnp.ndarray] = None,
|
1576 |
+
decoder_position_ids: Optional[jnp.ndarray] = None,
|
1577 |
+
past_key_values: dict = None,
|
1578 |
+
output_attentions: Optional[bool] = None,
|
1579 |
+
output_hidden_states: Optional[bool] = None,
|
1580 |
+
return_dict: Optional[bool] = None,
|
1581 |
+
train: bool = False,
|
1582 |
+
params: dict = None,
|
1583 |
+
dropout_rng: PRNGKey = None,
|
1584 |
+
):
|
1585 |
+
r"""
|
1586 |
+
Returns:
|
1587 |
+
|
1588 |
+
Example:
|
1589 |
+
|
1590 |
+
```python
|
1591 |
+
>>> from transformers import WhisperProcessor, FlaxWhisperForConditionalGeneration
|
1592 |
+
>>> from datasets import load_dataset
|
1593 |
+
|
1594 |
+
>>> processor = WhisperProcessor.from_pretrained("openai/whisper-tiny.en")
|
1595 |
+
>>> model = FlaxWhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny.en", from_pt=True)
|
1596 |
+
>>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
|
1597 |
+
>>> inputs = processor(ds[0]["audio"]["array"], return_tensors="np")
|
1598 |
+
>>> input_features = inputs.input_features
|
1599 |
+
>>> encoder_outputs = model.encode(input_features=input_features)
|
1600 |
+
>>> decoder_start_token_id = model.config.decoder_start_token_id
|
1601 |
+
|
1602 |
+
>>> decoder_input_ids = jnp.ones((inputs.input_ids.shape[0], 1), dtype="i4") * decoder_start_token_id
|
1603 |
+
|
1604 |
+
>>> outputs = model.decode(decoder_input_ids, encoder_outputs)
|
1605 |
+
>>> last_decoder_hidden_states = outputs.last_hidden_state
|
1606 |
+
```"""
|
1607 |
+
|
1608 |
+
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
1609 |
+
output_hidden_states = (
|
1610 |
+
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
1611 |
+
)
|
1612 |
+
return_dict = return_dict if return_dict is not None else self.config.return_dict
|
1613 |
+
|
1614 |
+
encoder_hidden_states = encoder_outputs[0]
|
1615 |
+
|
1616 |
+
batch_size, sequence_length = decoder_input_ids.shape
|
1617 |
+
if decoder_position_ids is None:
|
1618 |
+
if past_key_values is not None:
|
1619 |
+
raise ValueError("Make sure to provide `decoder_position_ids` when passing" " `past_key_values`.")
|
1620 |
+
|
1621 |
+
if decoder_attention_mask is not None:
|
1622 |
+
decoder_position_ids = (decoder_attention_mask.cumsum(-1) * decoder_attention_mask) - 1
|
1623 |
+
else:
|
1624 |
+
decoder_position_ids = jnp.broadcast_to(
|
1625 |
+
jnp.arange(sequence_length)[None, :], (batch_size, sequence_length)
|
1626 |
+
)
|
1627 |
+
|
1628 |
+
if decoder_attention_mask is None:
|
1629 |
+
decoder_attention_mask = jnp.ones((batch_size, sequence_length))
|
1630 |
+
|
1631 |
+
# Handle any PRNG if needed
|
1632 |
+
rngs = {}
|
1633 |
+
if dropout_rng is not None:
|
1634 |
+
rngs["dropout"] = dropout_rng
|
1635 |
+
|
1636 |
+
inputs = {"params": params or self.params}
|
1637 |
+
|
1638 |
+
# if past_key_values are passed then cache is already initialized a private flag init_cache has to be
|
1639 |
+
# passed down to ensure cache is used. It has to be made sure that cache is marked as mutable so that
|
1640 |
+
# it can be changed by FlaxWhisperAttention module
|
1641 |
+
if past_key_values:
|
1642 |
+
inputs["cache"] = past_key_values
|
1643 |
+
mutable = ["cache"]
|
1644 |
+
else:
|
1645 |
+
mutable = False
|
1646 |
+
|
1647 |
+
def _decoder_forward(
|
1648 |
+
module,
|
1649 |
+
decoder_input_ids,
|
1650 |
+
decoder_attention_mask,
|
1651 |
+
decoder_position_ids,
|
1652 |
+
**kwargs,
|
1653 |
+
):
|
1654 |
+
decoder_module = module._get_decoder_module()
|
1655 |
+
return decoder_module(
|
1656 |
+
input_ids=decoder_input_ids,
|
1657 |
+
attention_mask=decoder_attention_mask,
|
1658 |
+
position_ids=decoder_position_ids,
|
1659 |
+
**kwargs,
|
1660 |
+
)
|
1661 |
+
|
1662 |
+
outputs = self.module.apply(
|
1663 |
+
inputs,
|
1664 |
+
decoder_input_ids=jnp.array(decoder_input_ids, dtype="i4"),
|
1665 |
+
decoder_attention_mask=jnp.array(decoder_attention_mask, dtype="i4"),
|
1666 |
+
decoder_position_ids=jnp.array(decoder_position_ids, dtype="i4"),
|
1667 |
+
encoder_hidden_states=encoder_hidden_states,
|
1668 |
+
output_attentions=output_attentions,
|
1669 |
+
output_hidden_states=output_hidden_states,
|
1670 |
+
return_dict=return_dict,
|
1671 |
+
deterministic=not train,
|
1672 |
+
rngs=rngs,
|
1673 |
+
mutable=mutable,
|
1674 |
+
method=_decoder_forward,
|
1675 |
+
)
|
1676 |
+
|
1677 |
+
# add updated cache to model output
|
1678 |
+
if past_key_values is not None and return_dict:
|
1679 |
+
outputs, past = outputs
|
1680 |
+
outputs["past_key_values"] = unfreeze(past["cache"])
|
1681 |
+
return outputs
|
1682 |
+
elif past_key_values is not None and not return_dict:
|
1683 |
+
outputs, past = outputs
|
1684 |
+
outputs = outputs[:1] + (unfreeze(past["cache"]),) + outputs[1:]
|
1685 |
+
|
1686 |
+
return outputs
|
1687 |
+
|
1688 |
+
@add_start_docstrings_to_model_forward(WHISPER_INPUTS_DOCSTRING)
|
1689 |
+
def __call__(
|
1690 |
+
self,
|
1691 |
+
input_features: jnp.ndarray,
|
1692 |
+
decoder_input_ids: jnp.ndarray,
|
1693 |
+
attention_mask: Optional[jnp.ndarray] = None,
|
1694 |
+
decoder_attention_mask: Optional[jnp.ndarray] = None,
|
1695 |
+
position_ids: Optional[jnp.ndarray] = None,
|
1696 |
+
decoder_position_ids: Optional[jnp.ndarray] = None,
|
1697 |
+
output_attentions: Optional[bool] = None,
|
1698 |
+
output_hidden_states: Optional[bool] = None,
|
1699 |
+
freeze_encoder: Optional[bool] = None,
|
1700 |
+
return_dict: Optional[bool] = None,
|
1701 |
+
train: bool = False,
|
1702 |
+
params: dict = None,
|
1703 |
+
dropout_rng: PRNGKey = None,
|
1704 |
+
):
|
1705 |
+
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
1706 |
+
output_hidden_states = (
|
1707 |
+
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
1708 |
+
)
|
1709 |
+
return_dict = return_dict if return_dict is not None else self.config.return_dict
|
1710 |
+
|
1711 |
+
# prepare decoder inputs
|
1712 |
+
if decoder_position_ids is None:
|
1713 |
+
if decoder_attention_mask is not None:
|
1714 |
+
decoder_position_ids = (decoder_attention_mask.cumsum(-1) * decoder_attention_mask) - 1
|
1715 |
+
else:
|
1716 |
+
batch_size, sequence_length = decoder_input_ids.shape
|
1717 |
+
decoder_position_ids = jnp.broadcast_to(
|
1718 |
+
jnp.arange(sequence_length)[None, :], (batch_size, sequence_length)
|
1719 |
+
)
|
1720 |
+
if decoder_attention_mask is None:
|
1721 |
+
decoder_attention_mask = jnp.ones_like(decoder_input_ids)
|
1722 |
+
|
1723 |
+
# Handle any PRNG if needed
|
1724 |
+
rngs = {"dropout": dropout_rng} if dropout_rng is not None else {}
|
1725 |
+
|
1726 |
+
return self.module.apply(
|
1727 |
+
{"params": params or self.params},
|
1728 |
+
input_features=jnp.array(input_features, dtype="f4"),
|
1729 |
+
decoder_input_ids=jnp.array(decoder_input_ids, dtype="i4"),
|
1730 |
+
decoder_attention_mask=jnp.array(decoder_attention_mask, dtype="i4"),
|
1731 |
+
decoder_position_ids=jnp.array(decoder_position_ids, dtype="i4"),
|
1732 |
+
output_attentions=output_attentions,
|
1733 |
+
output_hidden_states=output_hidden_states,
|
1734 |
+
freeze_encoder=freeze_encoder,
|
1735 |
+
return_dict=return_dict,
|
1736 |
+
deterministic=not train,
|
1737 |
+
rngs=rngs,
|
1738 |
+
)
|
1739 |
+
|
1740 |
+
|
1741 |
+
@add_start_docstrings(
|
1742 |
+
("The bare Whisper Model transformer outputting raw hidden-states without any" " specific head on top."),
|
1743 |
+
WHISPER_START_DOCSTRING,
|
1744 |
+
)
|
1745 |
+
class FlaxWhisperModel(FlaxWhisperPreTrainedModel):
|
1746 |
+
config: WhisperConfig
|
1747 |
+
dtype: jnp.dtype = jnp.float32 # the dtype of the computation
|
1748 |
+
params_dtype: jnp.dtype = jnp.float32
|
1749 |
+
module_class = FlaxWhisperModule
|
1750 |
+
|
1751 |
+
|
1752 |
+
append_call_sample_docstring(FlaxWhisperModel, _CHECKPOINT_FOR_DOC, FlaxSeq2SeqModelOutput, _CONFIG_FOR_DOC)
|
1753 |
+
|
1754 |
+
|
1755 |
+
class FlaxWhisperForConditionalGenerationModule(nn.Module):
|
1756 |
+
config: WhisperConfig
|
1757 |
+
dtype: jnp.dtype = jnp.float32
|
1758 |
+
params_dtype: jnp.dtype = jnp.float32
|
1759 |
+
use_scan: bool = False
|
1760 |
+
gradient_checkpointing: bool = False
|
1761 |
+
|
1762 |
+
def setup(self) -> None:
|
1763 |
+
self.model = FlaxWhisperModule(
|
1764 |
+
config=self.config,
|
1765 |
+
dtype=self.dtype,
|
1766 |
+
params_dtype=self.params_dtype,
|
1767 |
+
use_scan=self.use_scan,
|
1768 |
+
gradient_checkpointing=self.gradient_checkpointing,
|
1769 |
+
)
|
1770 |
+
self.lm_head = DenseGeneral(
|
1771 |
+
self.config.vocab_size,
|
1772 |
+
use_bias=False,
|
1773 |
+
dtype=self.dtype,
|
1774 |
+
params_dtype=self.params_dtype,
|
1775 |
+
kernel_axes=("embed", "vocab"),
|
1776 |
+
)
|
1777 |
+
|
1778 |
+
def _get_encoder_module(self):
|
1779 |
+
return self.model.encoder
|
1780 |
+
|
1781 |
+
def _get_decoder_module(self):
|
1782 |
+
return self.model.decoder
|
1783 |
+
|
1784 |
+
def __call__(
|
1785 |
+
self,
|
1786 |
+
input_features,
|
1787 |
+
decoder_input_ids,
|
1788 |
+
decoder_attention_mask: jnp.ndarray = None,
|
1789 |
+
decoder_position_ids: jnp.ndarray = None,
|
1790 |
+
position_ids: jnp.ndarray = None,
|
1791 |
+
attention_mask: jnp.ndarray = None,
|
1792 |
+
output_attentions: bool = False,
|
1793 |
+
output_hidden_states: bool = False,
|
1794 |
+
freeze_encoder: bool = False,
|
1795 |
+
return_dict: bool = True,
|
1796 |
+
deterministic: bool = True,
|
1797 |
+
):
|
1798 |
+
outputs = self.model(
|
1799 |
+
input_features=input_features,
|
1800 |
+
decoder_input_ids=decoder_input_ids,
|
1801 |
+
decoder_attention_mask=decoder_attention_mask,
|
1802 |
+
decoder_position_ids=decoder_position_ids,
|
1803 |
+
output_attentions=output_attentions,
|
1804 |
+
output_hidden_states=output_hidden_states,
|
1805 |
+
freeze_encoder=freeze_encoder,
|
1806 |
+
return_dict=return_dict,
|
1807 |
+
deterministic=deterministic,
|
1808 |
+
)
|
1809 |
+
|
1810 |
+
hidden_states = outputs[0]
|
1811 |
+
|
1812 |
+
if self.config.tie_word_embeddings:
|
1813 |
+
shared_embedding = self.model.decoder.embed_tokens.variables["params"]["embedding"]
|
1814 |
+
lm_logits = self.lm_head.apply({"params": {"kernel": shared_embedding.T}}, hidden_states)
|
1815 |
+
else:
|
1816 |
+
lm_logits = self.lm_head(hidden_states)
|
1817 |
+
|
1818 |
+
if not return_dict:
|
1819 |
+
output = (lm_logits,) + outputs[1:]
|
1820 |
+
return output
|
1821 |
+
|
1822 |
+
return FlaxSeq2SeqLMOutput(
|
1823 |
+
logits=lm_logits,
|
1824 |
+
decoder_hidden_states=outputs.decoder_hidden_states,
|
1825 |
+
decoder_attentions=outputs.decoder_attentions,
|
1826 |
+
cross_attentions=outputs.cross_attentions,
|
1827 |
+
encoder_last_hidden_state=outputs.encoder_last_hidden_state,
|
1828 |
+
encoder_hidden_states=outputs.encoder_hidden_states,
|
1829 |
+
encoder_attentions=outputs.encoder_attentions,
|
1830 |
+
)
|
1831 |
+
|
1832 |
+
|
1833 |
+
@add_start_docstrings("The Whisper Model with a language modeling head.", WHISPER_START_DOCSTRING)
|
1834 |
+
class FlaxWhisperForConditionalGeneration(FlaxWhisperPreTrainedModel):
|
1835 |
+
module_class = FlaxWhisperForConditionalGenerationModule
|
1836 |
+
|
1837 |
+
@add_start_docstrings(WHISPER_DECODE_INPUTS_DOCSTRING)
|
1838 |
+
@replace_return_docstrings(output_type=FlaxCausalLMOutputWithCrossAttentions, config_class=WhisperConfig)
|
1839 |
+
def decode(
|
1840 |
+
self,
|
1841 |
+
decoder_input_ids,
|
1842 |
+
encoder_outputs,
|
1843 |
+
encoder_attention_mask: Optional[jnp.ndarray] = None,
|
1844 |
+
decoder_attention_mask: Optional[jnp.ndarray] = None,
|
1845 |
+
decoder_position_ids: Optional[jnp.ndarray] = None,
|
1846 |
+
past_key_values: dict = None,
|
1847 |
+
output_attentions: Optional[bool] = None,
|
1848 |
+
output_hidden_states: Optional[bool] = None,
|
1849 |
+
return_dict: Optional[bool] = None,
|
1850 |
+
train: bool = False,
|
1851 |
+
params: dict = None,
|
1852 |
+
dropout_rng: PRNGKey = None,
|
1853 |
+
):
|
1854 |
+
r"""
|
1855 |
+
Returns:
|
1856 |
+
|
1857 |
+
Example:
|
1858 |
+
|
1859 |
+
```python
|
1860 |
+
>>> from transformers import WhisperProcessor, FlaxWhisperForConditionalGeneration
|
1861 |
+
>>> from datasets import load_dataset
|
1862 |
+
|
1863 |
+
>>> processor = WhisperProcessor.from_pretrained("openai/whisper-tiny.en")
|
1864 |
+
>>> model = FlaxWhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny.en", from_pt=True)
|
1865 |
+
>>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
|
1866 |
+
>>> inputs = processor(ds[0]["audio"]["array"], return_tensors="np")
|
1867 |
+
>>> input_features = inputs.input_features
|
1868 |
+
>>> encoder_outputs = model.encode(input_features=input_features)
|
1869 |
+
>>> decoder_start_token_id = model.config.decoder_start_token_id
|
1870 |
+
|
1871 |
+
>>> decoder_input_ids = jnp.ones((inputs.input_ids.shape[0], 1), dtype="i4") * decoder_start_token_id
|
1872 |
+
|
1873 |
+
>>> outputs = model.decode(decoder_input_ids, encoder_outputs)
|
1874 |
+
>>> last_decoder_hidden_states = outputs.last_hidden_state
|
1875 |
+
```"""
|
1876 |
+
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
1877 |
+
output_hidden_states = (
|
1878 |
+
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
1879 |
+
)
|
1880 |
+
return_dict = return_dict if return_dict is not None else self.config.return_dict
|
1881 |
+
|
1882 |
+
encoder_hidden_states = encoder_outputs[0]
|
1883 |
+
|
1884 |
+
batch_size, sequence_length = decoder_input_ids.shape
|
1885 |
+
if decoder_position_ids is None:
|
1886 |
+
if past_key_values is not None:
|
1887 |
+
raise ValueError("Make sure to provide `decoder_position_ids` when passing" " `past_key_values`.")
|
1888 |
+
|
1889 |
+
if decoder_attention_mask is not None:
|
1890 |
+
decoder_position_ids = (decoder_attention_mask.cumsum(-1) * decoder_attention_mask) - 1
|
1891 |
+
else:
|
1892 |
+
decoder_position_ids = jnp.broadcast_to(
|
1893 |
+
jnp.arange(sequence_length)[None, :], (batch_size, sequence_length)
|
1894 |
+
)
|
1895 |
+
if decoder_attention_mask is None:
|
1896 |
+
decoder_attention_mask = jnp.ones((batch_size, sequence_length), dtype="i4")
|
1897 |
+
|
1898 |
+
# Handle any PRNG if needed
|
1899 |
+
rngs = {}
|
1900 |
+
if dropout_rng is not None:
|
1901 |
+
rngs["dropout"] = dropout_rng
|
1902 |
+
|
1903 |
+
inputs = {"params": params or self.params}
|
1904 |
+
|
1905 |
+
# if past_key_values are passed then cache is already initialized a private flag init_cache has to be
|
1906 |
+
# passed down to ensure cache is used. It has to be made sure that cache is marked as mutable so that
|
1907 |
+
# it can be changed by FlaxWhisperAttention module
|
1908 |
+
if past_key_values:
|
1909 |
+
inputs["cache"] = past_key_values
|
1910 |
+
mutable = ["cache"]
|
1911 |
+
else:
|
1912 |
+
mutable = False
|
1913 |
+
|
1914 |
+
def _decoder_forward(
|
1915 |
+
module,
|
1916 |
+
decoder_input_ids,
|
1917 |
+
decoder_attention_mask,
|
1918 |
+
decoder_position_ids,
|
1919 |
+
**kwargs,
|
1920 |
+
):
|
1921 |
+
decoder_module = module._get_decoder_module()
|
1922 |
+
outputs = decoder_module(
|
1923 |
+
input_ids=decoder_input_ids,
|
1924 |
+
attention_mask=decoder_attention_mask,
|
1925 |
+
position_ids=decoder_position_ids,
|
1926 |
+
**kwargs,
|
1927 |
+
)
|
1928 |
+
hidden_states = outputs[0]
|
1929 |
+
|
1930 |
+
if self.config.tie_word_embeddings:
|
1931 |
+
shared_embedding = module.model.decoder.embed_tokens.variables["params"]["embedding"]
|
1932 |
+
lm_logits = module.lm_head.apply({"params": {"kernel": shared_embedding.T}}, hidden_states)
|
1933 |
+
else:
|
1934 |
+
lm_logits = module.lm_head(hidden_states)
|
1935 |
+
|
1936 |
+
return lm_logits, outputs
|
1937 |
+
|
1938 |
+
outputs = self.module.apply(
|
1939 |
+
inputs,
|
1940 |
+
decoder_input_ids=jnp.array(decoder_input_ids, dtype="i4"),
|
1941 |
+
decoder_attention_mask=jnp.array(decoder_attention_mask, dtype="i4"),
|
1942 |
+
decoder_position_ids=jnp.array(decoder_position_ids, dtype="i4"),
|
1943 |
+
encoder_hidden_states=encoder_hidden_states,
|
1944 |
+
output_attentions=output_attentions,
|
1945 |
+
output_hidden_states=output_hidden_states,
|
1946 |
+
return_dict=return_dict,
|
1947 |
+
deterministic=not train,
|
1948 |
+
rngs=rngs,
|
1949 |
+
mutable=mutable,
|
1950 |
+
method=_decoder_forward,
|
1951 |
+
)
|
1952 |
+
|
1953 |
+
if past_key_values is None:
|
1954 |
+
lm_logits, decoder_outputs = outputs
|
1955 |
+
else:
|
1956 |
+
(lm_logits, decoder_outputs), past = outputs
|
1957 |
+
|
1958 |
+
if return_dict:
|
1959 |
+
outputs = FlaxCausalLMOutputWithCrossAttentions(
|
1960 |
+
logits=lm_logits,
|
1961 |
+
hidden_states=decoder_outputs.hidden_states,
|
1962 |
+
attentions=decoder_outputs.attentions,
|
1963 |
+
cross_attentions=decoder_outputs.cross_attentions,
|
1964 |
+
)
|
1965 |
+
else:
|
1966 |
+
outputs = (lm_logits,) + decoder_outputs[1:]
|
1967 |
+
|
1968 |
+
# add updated cache to model output
|
1969 |
+
if past_key_values is not None and return_dict:
|
1970 |
+
outputs["past_key_values"] = unfreeze(past["cache"])
|
1971 |
+
return outputs
|
1972 |
+
elif past_key_values is not None and not return_dict:
|
1973 |
+
outputs = outputs[:1] + (unfreeze(past["cache"]),) + outputs[1:]
|
1974 |
+
|
1975 |
+
return outputs
|
1976 |
+
|
1977 |
+
def generate(
|
1978 |
+
self,
|
1979 |
+
input_features,
|
1980 |
+
generation_config=None,
|
1981 |
+
logits_processor=None,
|
1982 |
+
return_timestamps=None,
|
1983 |
+
task=None,
|
1984 |
+
language=None,
|
1985 |
+
is_multilingual=None,
|
1986 |
+
**kwargs,
|
1987 |
+
):
|
1988 |
+
if generation_config is None:
|
1989 |
+
generation_config = self.generation_config
|
1990 |
+
|
1991 |
+
if return_timestamps is not None:
|
1992 |
+
generation_config.return_timestamps = return_timestamps
|
1993 |
+
|
1994 |
+
if task is not None:
|
1995 |
+
generation_config.task = task
|
1996 |
+
|
1997 |
+
if is_multilingual is not None:
|
1998 |
+
generation_config.is_multilingual = is_multilingual
|
1999 |
+
|
2000 |
+
if language is not None:
|
2001 |
+
generation_config.language = language
|
2002 |
+
|
2003 |
+
if kwargs is not None and "decoder_input_ids" in kwargs:
|
2004 |
+
decoder_input_length = len(kwargs["decoder_input_ids"])
|
2005 |
+
else:
|
2006 |
+
decoder_input_length = 1
|
2007 |
+
|
2008 |
+
forced_decoder_ids = []
|
2009 |
+
|
2010 |
+
if hasattr(generation_config, "is_multilingual") and generation_config.is_multilingual:
|
2011 |
+
if hasattr(generation_config, "language"):
|
2012 |
+
forced_decoder_ids.append((1, generation_config.lang_to_id[generation_config.language]))
|
2013 |
+
else:
|
2014 |
+
forced_decoder_ids.append((1, None))
|
2015 |
+
|
2016 |
+
if hasattr(generation_config, "task"):
|
2017 |
+
forced_decoder_ids.append((2, generation_config.task_to_id[generation_config.task]))
|
2018 |
+
else:
|
2019 |
+
forced_decoder_ids.append((2, generation_config.task_to_id["transcribe"]))
|
2020 |
+
|
2021 |
+
if (
|
2022 |
+
hasattr(generation_config, "return_timestamps") and generation_config.return_timestamps
|
2023 |
+
) or return_timestamps:
|
2024 |
+
logits_processor = [
|
2025 |
+
FlaxWhisperTimeStampLogitsProcessor(generation_config, self.config, decoder_input_length)
|
2026 |
+
]
|
2027 |
+
else:
|
2028 |
+
if forced_decoder_ids and forced_decoder_ids[-1][0] != generation_config.no_timestamps_token_id:
|
2029 |
+
idx = forced_decoder_ids[-1][0] + 1 if forced_decoder_ids else 1
|
2030 |
+
forced_decoder_ids.append((idx, generation_config.no_timestamps_token_id))
|
2031 |
+
|
2032 |
+
if len(forced_decoder_ids) > 0:
|
2033 |
+
generation_config.forced_decoder_ids = forced_decoder_ids
|
2034 |
+
|
2035 |
+
return super().generate(
|
2036 |
+
input_features,
|
2037 |
+
generation_config,
|
2038 |
+
logits_processor=logits_processor,
|
2039 |
+
**kwargs,
|
2040 |
+
)
|
2041 |
+
|
2042 |
+
def pipeline_generate(
|
2043 |
+
self,
|
2044 |
+
input_features,
|
2045 |
+
forced_decoder_ids,
|
2046 |
+
return_timestamps=False,
|
2047 |
+
generation_config=None,
|
2048 |
+
**kwargs,
|
2049 |
+
):
|
2050 |
+
if generation_config is None:
|
2051 |
+
generation_config = self.generation_config
|
2052 |
+
|
2053 |
+
# override the generation config forced decoder ids in preference of the ones we have set
|
2054 |
+
generation_config.forced_decoder_ids = None
|
2055 |
+
|
2056 |
+
logits_processor = FlaxLogitsProcessorList()
|
2057 |
+
|
2058 |
+
logits_processor.append(FlaxStaticForceTokensLogitsProcessor(forced_decoder_ids))
|
2059 |
+
|
2060 |
+
if hasattr(generation_config, "return_timestamps") and return_timestamps:
|
2061 |
+
logits_processor.append(FlaxWhisperTimeStampLogitsProcessor(generation_config, self.config, 1))
|
2062 |
+
|
2063 |
+
return super().generate(
|
2064 |
+
input_features,
|
2065 |
+
generation_config,
|
2066 |
+
logits_processor=logits_processor,
|
2067 |
+
**kwargs,
|
2068 |
+
)
|
2069 |
+
|
2070 |
+
def prepare_inputs_for_generation(
|
2071 |
+
self,
|
2072 |
+
decoder_input_ids,
|
2073 |
+
max_length,
|
2074 |
+
attention_mask: Optional[jax.Array] = None,
|
2075 |
+
decoder_attention_mask: Optional[jax.Array] = None,
|
2076 |
+
encoder_outputs=None,
|
2077 |
+
**kwargs,
|
2078 |
+
):
|
2079 |
+
# initializing the cache
|
2080 |
+
batch_size, seq_length = decoder_input_ids.shape
|
2081 |
+
|
2082 |
+
past_key_values = self.init_cache(batch_size, max_length, encoder_outputs)
|
2083 |
+
# Note that usually one would have to put 0's in the attention_mask for x > input_ids.shape[-1] and x < cache_length.
|
2084 |
+
# But since the decoder uses a causal mask, those positions are masked anyways.
|
2085 |
+
# Thus we can create a single static attention_mask here, which is more efficient for compilation
|
2086 |
+
extended_attention_mask = jnp.ones((batch_size, max_length), dtype="i4")
|
2087 |
+
if decoder_attention_mask is not None:
|
2088 |
+
position_ids = decoder_attention_mask.cumsum(-1) - 1
|
2089 |
+
extended_attention_mask = lax.dynamic_update_slice(extended_attention_mask, decoder_attention_mask, (0, 0))
|
2090 |
+
else:
|
2091 |
+
position_ids = jnp.broadcast_to(jnp.arange(seq_length, dtype="i4")[None, :], (batch_size, seq_length))
|
2092 |
+
|
2093 |
+
return {
|
2094 |
+
"past_key_values": past_key_values,
|
2095 |
+
"encoder_outputs": encoder_outputs,
|
2096 |
+
"encoder_attention_mask": attention_mask,
|
2097 |
+
"decoder_attention_mask": extended_attention_mask,
|
2098 |
+
"decoder_position_ids": position_ids,
|
2099 |
+
}
|
2100 |
+
|
2101 |
+
def update_inputs_for_generation(self, model_outputs, model_kwargs):
|
2102 |
+
model_kwargs["past_key_values"] = model_outputs.past_key_values
|
2103 |
+
model_kwargs["decoder_position_ids"] = model_kwargs["decoder_position_ids"][:, -1:] + 1
|
2104 |
+
return model_kwargs
|
2105 |
+
|
2106 |
+
|
2107 |
+
FLAX_WHISPER_CONDITIONAL_GENERATION_DOCSTRING = r"""
|
2108 |
+
Returns:
|
2109 |
+
|
2110 |
+
Transcription example:
|
2111 |
+
|
2112 |
+
```python
|
2113 |
+
>>> from transformers import WhisperProcessor, FlaxWhisperForConditionalGeneration
|
2114 |
+
>>> from datasets import load_dataset
|
2115 |
+
|
2116 |
+
>>> processor = WhisperProcessor.from_pretrained("openai/whisper-tiny.en")
|
2117 |
+
>>> model = FlaxWhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny.en", from_pt=True)
|
2118 |
+
>>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
|
2119 |
+
>>> inputs = processor(ds[0]["audio"]["array"], return_tensors="np")
|
2120 |
+
>>> input_features = inputs.input_features
|
2121 |
+
>>> generated_ids = model.generate(input_ids=input_features)
|
2122 |
+
>>> transcription = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
|
2123 |
+
>>> transcription
|
2124 |
+
' Mr. Quilter is the apostle of the middle classes, and we are glad to welcome his gospel.'
|
2125 |
+
```
|
2126 |
+
"""
|
2127 |
+
|
2128 |
+
overwrite_call_docstring(
|
2129 |
+
FlaxWhisperForConditionalGeneration,
|
2130 |
+
WHISPER_INPUTS_DOCSTRING + FLAX_WHISPER_CONDITIONAL_GENERATION_DOCSTRING,
|
2131 |
+
)
|
2132 |
+
append_replace_return_docstrings(
|
2133 |
+
FlaxWhisperForConditionalGeneration,
|
2134 |
+
output_type=FlaxSeq2SeqLMOutput,
|
2135 |
+
config_class=_CONFIG_FOR_DOC,
|
2136 |
+
)
|
distil_whisper/partitioner.py
ADDED
@@ -0,0 +1,965 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2022 The T5X Authors.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
"""Utilities for partitioning."""
|
16 |
+
|
17 |
+
import abc
|
18 |
+
import collections
|
19 |
+
import dataclasses
|
20 |
+
import typing
|
21 |
+
from typing import Any, Callable, Optional, Sequence, Tuple, Union
|
22 |
+
|
23 |
+
import cached_property
|
24 |
+
import jax
|
25 |
+
import numpy as np
|
26 |
+
from absl import logging
|
27 |
+
from flax import traverse_util
|
28 |
+
from flax.linen import partitioning as flax_partitioning
|
29 |
+
from jax import numpy as jnp
|
30 |
+
from jax import random
|
31 |
+
from jax.experimental import multihost_utils
|
32 |
+
from jax.experimental.mesh_utils import create_hybrid_device_mesh
|
33 |
+
from jax.experimental.pjit import pjit as jax_pjit
|
34 |
+
from jax.sharding import Mesh, PartitionSpec
|
35 |
+
|
36 |
+
|
37 |
+
JaxDevice = Any
|
38 |
+
TpuMesh = Tuple[int, int, int, int] # (x, y, z, num_cores).
|
39 |
+
OtherMesh = Tuple[int, int]
|
40 |
+
HardwareMesh = Union[TpuMesh, OtherMesh]
|
41 |
+
PyTreeDef = type(jax.tree_util.tree_structure(None))
|
42 |
+
TrainState = Any
|
43 |
+
LogicalAxisRules = Sequence[Tuple[str, Optional[str]]]
|
44 |
+
|
45 |
+
if typing.TYPE_CHECKING: # See b/163639353
|
46 |
+
cached_property = property # pylint: disable=invalid-name
|
47 |
+
else:
|
48 |
+
cached_property = cached_property.cached_property
|
49 |
+
|
50 |
+
|
51 |
+
class AxisNames(tuple):
|
52 |
+
"""Tuple of strings specifying name for each axis.
|
53 |
+
|
54 |
+
We create a separate class for this so JAX's pytree utilities can distinguish
|
55 |
+
it from a tuple that should be treated as a pytree, instead treating it as a
|
56 |
+
leaf.
|
57 |
+
"""
|
58 |
+
|
59 |
+
def __new__(cls, *names):
|
60 |
+
return tuple.__new__(AxisNames, names)
|
61 |
+
|
62 |
+
def __repr__(self):
|
63 |
+
return "AxisNames%s" % tuple.__repr__(self)
|
64 |
+
|
65 |
+
|
66 |
+
# pjit wrappers for cpu fallback.
|
67 |
+
# ----------------------------------------------------------------------------
|
68 |
+
# TODO(levskaya): This function is now no different than jax_pjit, but callers
|
69 |
+
# currently depend on `backend` argument
|
70 |
+
def pjit(
|
71 |
+
fun: Callable, # pylint: disable=g-bare-generic
|
72 |
+
in_axis_resources,
|
73 |
+
out_axis_resources,
|
74 |
+
static_argnums: Union[int, Sequence[int]] = (),
|
75 |
+
donate_argnums: Union[int, Sequence[int]] = (),
|
76 |
+
backend: Optional[str] = None,
|
77 |
+
):
|
78 |
+
"""Wrapper for pjit."""
|
79 |
+
del backend
|
80 |
+
return jax_pjit(
|
81 |
+
fun,
|
82 |
+
in_axis_resources,
|
83 |
+
out_axis_resources,
|
84 |
+
static_argnums=static_argnums,
|
85 |
+
donate_argnums=donate_argnums,
|
86 |
+
)
|
87 |
+
|
88 |
+
|
89 |
+
# pjit wrappers for cpu fallback.
|
90 |
+
# -----------------------------------------------------------------------------
|
91 |
+
# TODO(levskaya): upstream this fallback behavior to jax pjit.
|
92 |
+
def pjit_with_cpu_fallback(
|
93 |
+
fun: Callable, # pylint: disable=g-bare-generic
|
94 |
+
in_axis_resources,
|
95 |
+
out_axis_resources,
|
96 |
+
static_argnums: Union[int, Sequence[int]] = (),
|
97 |
+
donate_argnums: Union[int, Sequence[int]] = (),
|
98 |
+
backend: Optional[str] = None,
|
99 |
+
):
|
100 |
+
"""Wrapper for pjit that calls normal jit on cpu."""
|
101 |
+
if jax.devices(backend)[0].platform == "cpu":
|
102 |
+
return jax.jit(fun, static_argnums=static_argnums, donate_argnums=donate_argnums)
|
103 |
+
else:
|
104 |
+
return jax_pjit(
|
105 |
+
fun,
|
106 |
+
in_axis_resources,
|
107 |
+
out_axis_resources,
|
108 |
+
static_argnums=static_argnums,
|
109 |
+
donate_argnums=donate_argnums,
|
110 |
+
)
|
111 |
+
|
112 |
+
|
113 |
+
def with_sharding_constraint(x, axis_resources):
|
114 |
+
"""Wrapper for pjit with_sharding_constraint, no-op on cpu or outside pjit."""
|
115 |
+
if jax.devices()[0].platform == "cpu" or not global_mesh_defined():
|
116 |
+
return x
|
117 |
+
else:
|
118 |
+
return jax.experimental.pjit.with_sharding_constraint(x, axis_resources)
|
119 |
+
|
120 |
+
|
121 |
+
# pjit Mesh creation functions.
|
122 |
+
# -----------------------------------------------------------------------------
|
123 |
+
def bounds_from_last_device(last_device: JaxDevice) -> HardwareMesh:
|
124 |
+
"""Get the bound from the given last device."""
|
125 |
+
# Must be passed the device at the highest-coordinate corner of the
|
126 |
+
# relevant mesh, which is a requirement we know is satisfied by the last
|
127 |
+
# device in jax.devices().
|
128 |
+
if hasattr(last_device, "coords"):
|
129 |
+
x, y, z = last_device.coords
|
130 |
+
return x + 1, y + 1, z + 1, last_device.core_on_chip + 1
|
131 |
+
else:
|
132 |
+
# On non-TPU platforms, the "mesh" is hosts x devices per host in order
|
133 |
+
# to take advantage of faster within-host interconnect.
|
134 |
+
return jax.host_count(), jax.local_device_count()
|
135 |
+
|
136 |
+
|
137 |
+
def get_coords(device: JaxDevice) -> HardwareMesh:
|
138 |
+
"""Returns the coordinates of the given device."""
|
139 |
+
if hasattr(device, "coords"):
|
140 |
+
return (*device.coords, device.core_on_chip)
|
141 |
+
return (device.process_index, device.id % jax.local_device_count())
|
142 |
+
|
143 |
+
|
144 |
+
def global_mesh_defined():
|
145 |
+
"""Checks if global xmap/pjit mesh resource environment is defined."""
|
146 |
+
maps_env = jax.experimental.maps.thread_resources.env
|
147 |
+
return maps_env.physical_mesh.devices.shape != () # pylint: disable=g-explicit-bool-comparison
|
148 |
+
|
149 |
+
|
150 |
+
def get_mesh(
|
151 |
+
model_parallel_submesh: HardwareMesh,
|
152 |
+
input_devices: Sequence[JaxDevice] = (),
|
153 |
+
input_local_devices: Sequence[JaxDevice] = (),
|
154 |
+
tile_by_host_if_needed: bool = True,
|
155 |
+
backend: Optional[str] = None,
|
156 |
+
) -> Mesh:
|
157 |
+
"""Construct an xmap/pjit Mesh for the given model-parallel submesh.
|
158 |
+
|
159 |
+
The resulting mesh has two resource axes: 'model', with the provided submesh
|
160 |
+
shape, and 'data', which covers the rest of the mesh.
|
161 |
+
|
162 |
+
Args:
|
163 |
+
model_parallel_submesh: a HardwareMesh spec, namely (x,y,z,core) on TPU for
|
164 |
+
a single model-parallel replica's "tile" in the physical device mesh. The
|
165 |
+
first three elements (`x`, `y`, and `z`) should be factors of the pod
|
166 |
+
slice; e.g., if you are using df_4x8, then `x` should be a factor of 4
|
167 |
+
(one of 1, 2, 4), `y` should be a factor of 8 (one of 1, 2, 4, 8), and `z`
|
168 |
+
must be 1, because TPU v3 slices are only 2D. `z` can be >1 for TPU v4
|
169 |
+
(and maybe later TPUs) that allow 3D slices. `core` is the number of cores
|
170 |
+
to use from each TPU node. As communication is usually fastest inside the
|
171 |
+
same node, if you need a tile of more than 1 core, then
|
172 |
+
you should first increase `core`: e.g., for TPU v3, (1,1,1,2) is better
|
173 |
+
than (2,1,1,1). To pick a good spec, try a few possible values until you
|
174 |
+
get high TPU utilization.
|
175 |
+
input_devices: the devices to use, will use jax.devices() if this is not
|
176 |
+
set.
|
177 |
+
input_local_devices: the local devices to use, will use jax.local_devices()
|
178 |
+
if this is not set.
|
179 |
+
tile_by_host_if_needed: JAX currently requires that the parts of any sharded
|
180 |
+
array that are located on one host's local devices form a single
|
181 |
+
contiguous slice. A best effort will be made to achieve this without
|
182 |
+
"tiling" the device assignment over hosts (which can reduce XLA collective
|
183 |
+
performance). If this flag is True, then the device assignment will be
|
184 |
+
tiled over hosts if necessary to satisfy this constraint and create a
|
185 |
+
buildable mesh; if false, mesh construction will fail instead.
|
186 |
+
backend: get devices from the pinned backend, if specified. This is
|
187 |
+
useful for explicitly specifying the devices other than relying on
|
188 |
+
jax_platform_name.
|
189 |
+
|
190 |
+
Returns:
|
191 |
+
A xmap / pjit Mesh containing the virtual device mesh with data, model axes.
|
192 |
+
"""
|
193 |
+
input_devices = input_devices or jax.devices(backend)
|
194 |
+
input_local_devices = input_local_devices or jax.local_devices(0, backend)
|
195 |
+
# Sort input_devices based on coords, as backends might not return devices
|
196 |
+
# in order.
|
197 |
+
last_device = sorted(input_devices, key=get_coords)[-1]
|
198 |
+
last_input_local_devices = sorted(input_local_devices, key=get_coords)[-1]
|
199 |
+
logging.info(
|
200 |
+
"last device coords : %r\nlast local device coords: %r",
|
201 |
+
get_coords(last_device),
|
202 |
+
get_coords(last_input_local_devices),
|
203 |
+
)
|
204 |
+
global_hardware_mesh = bounds_from_last_device(last_device)
|
205 |
+
mesh_ndim = len(global_hardware_mesh)
|
206 |
+
local_hardware_mesh = bounds_from_last_device(last_input_local_devices)
|
207 |
+
mesh_err = (
|
208 |
+
f"each dimension of the model parallel submesh {model_parallel_submesh} "
|
209 |
+
"must be a factor of the corresponding dimension of the global device "
|
210 |
+
f"mesh {global_hardware_mesh}"
|
211 |
+
)
|
212 |
+
assert not any(g % m for g, m in zip(global_hardware_mesh, model_parallel_submesh)), mesh_err
|
213 |
+
assert not any(g % l for g, l in zip(global_hardware_mesh, local_hardware_mesh))
|
214 |
+
devices = np.empty(global_hardware_mesh, dtype=object)
|
215 |
+
for device in input_devices:
|
216 |
+
device_coords = get_coords(device)
|
217 |
+
devices[device_coords] = device
|
218 |
+
tile_by_host = tile_by_host_if_needed
|
219 |
+
if len(global_hardware_mesh) == 4:
|
220 |
+
# enable contiguous local chunks without host tiling by making Z major
|
221 |
+
global_hardware_mesh = typing.cast(Tuple[int, int, int, int], global_hardware_mesh)
|
222 |
+
model_parallel_submesh = typing.cast(Tuple[int, int, int, int], model_parallel_submesh)
|
223 |
+
gx, gy, gz, gc = global_hardware_mesh
|
224 |
+
mx, my, mz, mc = model_parallel_submesh
|
225 |
+
if (mx == gx > 1 and my == mz == 1) or (mx == 1 and my == gy > 1 and mz == gz > 1):
|
226 |
+
logging.info("ensuring YZ plane has a Z-major device order")
|
227 |
+
# YZ should be ZY
|
228 |
+
assert mc == gc, (mc, gc)
|
229 |
+
global_hardware_mesh = gx, gz, gy, gc
|
230 |
+
model_parallel_submesh = mx, mz, my, mc
|
231 |
+
devices = devices.swapaxes(1, 2)
|
232 |
+
tile_by_host = False
|
233 |
+
if (my == gy > 1 and mx == mz == 1) or (my == 1 and mx == gx > 1 and mz == gz > 1):
|
234 |
+
logging.info("ensuring XZ plane has a Z-major device order")
|
235 |
+
# XZ should be ZX
|
236 |
+
assert mc == gc, (mc, gc)
|
237 |
+
global_hardware_mesh = gz, gy, gx, gc
|
238 |
+
model_parallel_submesh = mz, my, mx, mc
|
239 |
+
devices = devices.swapaxes(0, 2)
|
240 |
+
tile_by_host = False
|
241 |
+
if tile_by_host:
|
242 |
+
logging.warning(
|
243 |
+
"Tiling device assignment mesh by hosts, which may lead to "
|
244 |
+
"reduced XLA collective performance. To avoid this, modify "
|
245 |
+
"the model parallel submesh or run with more tasks per host."
|
246 |
+
)
|
247 |
+
tile_err = (
|
248 |
+
"to tile the mesh by hosts, each dimension of the model parallel "
|
249 |
+
"submesh must be either a factor or a multiple of the corresponding "
|
250 |
+
"dimension of the per-host submesh"
|
251 |
+
)
|
252 |
+
|
253 |
+
def dh_dd_mh_md(g: int, m: int, l: int) -> Tuple[int, int, int, int]:
|
254 |
+
"""Split a global mesh dimension into four tiling components.
|
255 |
+
|
256 |
+
Args:
|
257 |
+
g: global mesh bounds dimension size
|
258 |
+
m: model-parallel submesh bounds dimension size
|
259 |
+
l: local submesh bounds dimension size
|
260 |
+
|
261 |
+
Returns:
|
262 |
+
The resulting tuple divides the dimension into the hosts component of
|
263 |
+
the data-parallel submesh, the devices component of the data-parallel
|
264 |
+
submesh, the hosts component of the model-parallel submesh, and the
|
265 |
+
devices component of the model-parallel submesh.
|
266 |
+
"""
|
267 |
+
d = g // m
|
268 |
+
if m >= l:
|
269 |
+
assert not m % l, tile_err
|
270 |
+
return (d, 1, m // l, l)
|
271 |
+
else:
|
272 |
+
assert not l % m, tile_err
|
273 |
+
return (d // (l // m), l // m, 1, m)
|
274 |
+
|
275 |
+
# e.g. [(x_data_hosts, x_data_devs, x_model_hosts, x_model_devs), ...]
|
276 |
+
dh_dd_mh_md_tups = map(
|
277 |
+
dh_dd_mh_md,
|
278 |
+
global_hardware_mesh,
|
279 |
+
model_parallel_submesh,
|
280 |
+
local_hardware_mesh,
|
281 |
+
)
|
282 |
+
# reshape to e.g. (x_dh, x_dd, x_mh, x_md, y_dh, ...)
|
283 |
+
devices = devices.reshape(*(s for t in dh_dd_mh_md_tups for s in t)) # pylint: disable=g-complex-comprehension
|
284 |
+
# TODO(jekbradbury): reorder local subgroups for ring locality
|
285 |
+
# Transpose to [data_host], [data_device], [model_host], [model_device]
|
286 |
+
# block ordering e.g. (x_dh, y_dh, ..., x_dd, y_dd, ...)
|
287 |
+
devices = devices.transpose(
|
288 |
+
*(4 * i for i in range(mesh_ndim)),
|
289 |
+
*(4 * i + 1 for i in range(mesh_ndim)),
|
290 |
+
*(4 * i + 2 for i in range(mesh_ndim)),
|
291 |
+
*(4 * i + 3 for i in range(mesh_ndim)),
|
292 |
+
)
|
293 |
+
else:
|
294 |
+
# e.g. [(x_data, x_model), (y_data, y_model), ...]
|
295 |
+
model_data_tups = [(g // m, m) for g, m in zip(global_hardware_mesh, model_parallel_submesh)]
|
296 |
+
# reshape to e.g. (x_data, x_model, y_data, y_model...)
|
297 |
+
devices = devices.reshape(*(s for t in model_data_tups for s in t)) # pylint: disable=g-complex-comprehension
|
298 |
+
# TODO(jekbradbury): reorder small subgroups for ring locality
|
299 |
+
# transpose to e.g. (x_data, y_data, ..., x_model, ...)
|
300 |
+
devices = devices.transpose(*(2 * i for i in range(mesh_ndim)), *(2 * i + 1 for i in range(mesh_ndim)))
|
301 |
+
# reshape to (data, model)
|
302 |
+
devices = devices.reshape(-1, np.prod(model_parallel_submesh))
|
303 |
+
global_mesh = Mesh(devices, ["data", "model"])
|
304 |
+
logging.info("global_mesh axis_names: %s", global_mesh.axis_names)
|
305 |
+
logging.info("global_mesh devices: %s", global_mesh.devices)
|
306 |
+
logging.info("global_mesh devices shape: %s", global_mesh.devices.shape)
|
307 |
+
return global_mesh
|
308 |
+
|
309 |
+
|
310 |
+
def get_cpu_mesh() -> Mesh:
|
311 |
+
"""Trivial mesh for CPU Testing."""
|
312 |
+
devices = np.empty((jax.host_count(), jax.local_device_count()), dtype=object)
|
313 |
+
for device in jax.devices():
|
314 |
+
devices[device.process_index, device.id % jax.local_device_count()] = device
|
315 |
+
return Mesh(devices, ["data", "model"])
|
316 |
+
|
317 |
+
|
318 |
+
def get_gpu_mesh(num_partitions: int) -> Mesh:
|
319 |
+
"""Mesh for GPUs that preferentially places 'model' on NVLink."""
|
320 |
+
nvlink_size = jax.local_device_count()
|
321 |
+
dcn_size = jax.process_count()
|
322 |
+
nvlink_mp = min(num_partitions, nvlink_size)
|
323 |
+
nvlink_dp, extra1 = divmod(nvlink_size, nvlink_mp)
|
324 |
+
dcn_mp, extra2 = divmod(num_partitions, nvlink_mp)
|
325 |
+
assert not (extra1 or extra2), (
|
326 |
+
"number of partitions on GPU must be a factor" " or multiple of the number of local devices"
|
327 |
+
)
|
328 |
+
dcn_dp = dcn_size // dcn_mp
|
329 |
+
|
330 |
+
devices = create_hybrid_device_mesh(
|
331 |
+
mesh_shape=[nvlink_dp, nvlink_mp],
|
332 |
+
dcn_mesh_shape=[dcn_dp, dcn_mp],
|
333 |
+
process_is_granule=True,
|
334 |
+
)
|
335 |
+
|
336 |
+
global_mesh = Mesh(devices, ["data", "model"])
|
337 |
+
logging.info("global_mesh axis_names: %s", global_mesh.axis_names)
|
338 |
+
logging.info("global_mesh devices: %s", global_mesh.devices)
|
339 |
+
return global_mesh
|
340 |
+
|
341 |
+
|
342 |
+
def default_mesh(
|
343 |
+
num_partitions: int,
|
344 |
+
model_parallel_submesh: Optional[HardwareMesh] = None,
|
345 |
+
backend: Optional[str] = None,
|
346 |
+
) -> Mesh:
|
347 |
+
"""Attempt to return a default mesh for simple cases.
|
348 |
+
|
349 |
+
Args:
|
350 |
+
num_partitions: number of partitions to use, will be ignored if
|
351 |
+
model_parallel_submesh is provided.
|
352 |
+
model_parallel_submesh: 4-tuple that specifies the x,y,z,c submesh to use as
|
353 |
+
the model-parallel device tile.
|
354 |
+
backend: get devices from the pinned backend, if specified. This is useful
|
355 |
+
for explicitly specifying the devices other than relying on
|
356 |
+
jax_platform_name.
|
357 |
+
|
358 |
+
Returns:
|
359 |
+
xmap/pjit 2D Mesh with 'data', 'model' mesh axes.
|
360 |
+
"""
|
361 |
+
last_device = jax.devices(backend)[-1]
|
362 |
+
platform = last_device.platform
|
363 |
+
device_kind = last_device.device_kind
|
364 |
+
bounds = bounds_from_last_device(last_device)
|
365 |
+
|
366 |
+
if model_parallel_submesh:
|
367 |
+
return get_mesh(model_parallel_submesh, backend=backend)
|
368 |
+
|
369 |
+
if platform == "cpu":
|
370 |
+
return get_cpu_mesh()
|
371 |
+
elif platform == "gpu":
|
372 |
+
return get_gpu_mesh(num_partitions)
|
373 |
+
|
374 |
+
mps = None
|
375 |
+
if device_kind in ("TPU v2", "TPU v3"):
|
376 |
+
if num_partitions == 1:
|
377 |
+
mps = (1, 1, 1, 1)
|
378 |
+
elif num_partitions == 2:
|
379 |
+
mps = (1, 1, 1, 2)
|
380 |
+
elif num_partitions == 4:
|
381 |
+
mps = (2, 1, 1, 2)
|
382 |
+
elif num_partitions == 8:
|
383 |
+
mps = (2, 2, 1, 2)
|
384 |
+
elif num_partitions == 16:
|
385 |
+
mps = (4, 2, 1, 2)
|
386 |
+
# assume the use of megacore on TPU v4
|
387 |
+
elif (device_kind == "TPU v4" or device_kind == "TPU v4 lite") and bounds[3] == 1:
|
388 |
+
if num_partitions == 1:
|
389 |
+
mps = (1, 1, 1, 1)
|
390 |
+
elif num_partitions == 2:
|
391 |
+
mps = (1, 2, 1, 1)
|
392 |
+
elif num_partitions == 4:
|
393 |
+
if bounds[0] >= 4:
|
394 |
+
mps = (4, 1, 1, 1)
|
395 |
+
else:
|
396 |
+
mps = (2, 2, 1, 1)
|
397 |
+
elif num_partitions == 8:
|
398 |
+
if bounds[2] >= 8:
|
399 |
+
mps = (1, 1, 8, 1)
|
400 |
+
else:
|
401 |
+
mps = (4, 2, 1, 1)
|
402 |
+
elif num_partitions == 16:
|
403 |
+
if bounds[2] >= 16:
|
404 |
+
mps = (1, 1, 16, 1)
|
405 |
+
elif bounds[0] >= 8:
|
406 |
+
mps = (8, 2, 1, 1)
|
407 |
+
elif bounds[0] >= 4:
|
408 |
+
mps = (4, 4, 1, 1)
|
409 |
+
else:
|
410 |
+
mps = (2, 2, 4, 1)
|
411 |
+
|
412 |
+
if mps is None:
|
413 |
+
raise ValueError(
|
414 |
+
"No default mesh for this configuration: specify " "config.model_parallel_submesh explicitly."
|
415 |
+
)
|
416 |
+
return get_mesh(mps, backend=backend)
|
417 |
+
|
418 |
+
|
419 |
+
# Data chunking helper.
|
420 |
+
# -----------------------------------------------------------------------------
|
421 |
+
@dataclasses.dataclass
|
422 |
+
class LocalChunkInfo:
|
423 |
+
# The logical slice of an array located on this host's local devices.
|
424 |
+
slice: Tuple[slice, ...]
|
425 |
+
# A unique index for this host/local chunk among chunks with the same slice.
|
426 |
+
replica_id: int
|
427 |
+
|
428 |
+
|
429 |
+
class LocalChunker:
|
430 |
+
"""Utility class to aid chunking of sharded arrays in multihost settings."""
|
431 |
+
|
432 |
+
def __init__(self, global_mesh: Mesh):
|
433 |
+
self.global_mesh = global_mesh
|
434 |
+
local_mesh = global_mesh.local_mesh
|
435 |
+
first_local_device = local_mesh.devices.reshape(-1)[0]
|
436 |
+
host_location = collections.OrderedDict(
|
437 |
+
zip(
|
438 |
+
global_mesh.shape.keys(),
|
439 |
+
list(zip(*np.nonzero(global_mesh.devices == first_local_device)))[0],
|
440 |
+
)
|
441 |
+
)
|
442 |
+
self.num_chunks = collections.OrderedDict()
|
443 |
+
self.chunk_ids = collections.OrderedDict()
|
444 |
+
self.mesh_axes = list(global_mesh.shape.keys())
|
445 |
+
for mesh_axis in self.mesh_axes:
|
446 |
+
num_devices_per_chunk = local_mesh.shape[mesh_axis]
|
447 |
+
self.num_chunks[mesh_axis] = global_mesh.shape[mesh_axis] // num_devices_per_chunk
|
448 |
+
self.chunk_ids[mesh_axis] = host_location[mesh_axis] // num_devices_per_chunk
|
449 |
+
|
450 |
+
def get_local_chunk_info(
|
451 |
+
self, global_shape: Tuple[int, ...], mesh_axes: Sequence[Optional[str]]
|
452 |
+
) -> LocalChunkInfo:
|
453 |
+
"""Get the local chunk info for a given array shape and sharded axes.
|
454 |
+
|
455 |
+
Args:
|
456 |
+
global_shape: the global, unsharded shape of the array to chunk.
|
457 |
+
mesh_axes: a sequence of names (or None) of equal rank to `global_shape`
|
458 |
+
that specifies which mesh dimensions the array is sharded along.
|
459 |
+
|
460 |
+
Returns:
|
461 |
+
LocalChunkInfo containing the logical slices of the array found on this
|
462 |
+
host's local devices, as well as the replica index for this chunk among
|
463 |
+
chunks with the same slice. The latter is used to determine which
|
464 |
+
host should write this chunk during checkpointing.
|
465 |
+
"""
|
466 |
+
local_slice = [slice(None) for dim in global_shape]
|
467 |
+
sharded_mesh_axes = set()
|
468 |
+
for i, (mesh_axis, size) in enumerate(zip(mesh_axes, global_shape)):
|
469 |
+
if not mesh_axis:
|
470 |
+
continue
|
471 |
+
sharded_mesh_axes.add(mesh_axis)
|
472 |
+
if not isinstance(mesh_axis, str):
|
473 |
+
raise NotImplementedError("TODO(jekbradbury)")
|
474 |
+
chunk_id = self.chunk_ids[mesh_axis]
|
475 |
+
chunk_size = size // self.num_chunks[mesh_axis]
|
476 |
+
local_slice[i] = slice(chunk_id * chunk_size, (chunk_id + 1) * chunk_size)
|
477 |
+
|
478 |
+
replicated_mesh_axes = [mesh_axis for mesh_axis in self.mesh_axes if mesh_axis not in sharded_mesh_axes]
|
479 |
+
replica_id = 0
|
480 |
+
for mesh_axis in replicated_mesh_axes:
|
481 |
+
chunk_id = self.chunk_ids[mesh_axis]
|
482 |
+
replica_id = replica_id * self.num_chunks[mesh_axis] + chunk_id
|
483 |
+
|
484 |
+
return LocalChunkInfo(tuple(local_slice), replica_id)
|
485 |
+
|
486 |
+
|
487 |
+
def standard_logical_axis_rules(
|
488 |
+
activation_partitioning_dims: int = 1,
|
489 |
+
parameter_partitioning_dims: int = 1,
|
490 |
+
additional_rules: Optional[LogicalAxisRules] = None,
|
491 |
+
) -> LogicalAxisRules:
|
492 |
+
"""Default sharding rules for T5X model in terms of logical axis names.
|
493 |
+
|
494 |
+
Args:
|
495 |
+
activation_partitioning_dims: enables 2-D activation sharding when set to 2.
|
496 |
+
parameter_partitioning_dims: enables 2-D parameter sharding when set to 2.
|
497 |
+
additional_rules: additional rules (a sequence of tuples) that will be
|
498 |
+
appended to the standard rules.
|
499 |
+
|
500 |
+
Returns:
|
501 |
+
Sequence of logical axis rules
|
502 |
+
"""
|
503 |
+
logging.info(
|
504 |
+
"`activation_partitioning_dims` = %d, `parameter_partitioning_dims` = %d",
|
505 |
+
activation_partitioning_dims,
|
506 |
+
parameter_partitioning_dims,
|
507 |
+
)
|
508 |
+
|
509 |
+
if activation_partitioning_dims == 1 and parameter_partitioning_dims == 1:
|
510 |
+
rules = [
|
511 |
+
("batch", "data"),
|
512 |
+
("vocab", "model"),
|
513 |
+
("embed", None),
|
514 |
+
("mlp", "model"),
|
515 |
+
("heads", "model"),
|
516 |
+
("kv", None),
|
517 |
+
("joined_kv", "model"), # joined heads+kv dim in 2D attn param layouts
|
518 |
+
]
|
519 |
+
elif activation_partitioning_dims == 2 and parameter_partitioning_dims == 1:
|
520 |
+
rules = [
|
521 |
+
("batch", "data"),
|
522 |
+
("vocab", "model"),
|
523 |
+
("mlp", "model"),
|
524 |
+
("heads", "model"),
|
525 |
+
("kv", None),
|
526 |
+
("joined_kv", "model"),
|
527 |
+
("embed", "model"),
|
528 |
+
]
|
529 |
+
elif activation_partitioning_dims == 1 and parameter_partitioning_dims == 2:
|
530 |
+
rules = [
|
531 |
+
("batch", "data"),
|
532 |
+
("vocab", "model"),
|
533 |
+
("mlp", "model"),
|
534 |
+
("heads", "model"),
|
535 |
+
("kv", None),
|
536 |
+
("joined_kv", "model"),
|
537 |
+
("embed", "data"),
|
538 |
+
]
|
539 |
+
elif activation_partitioning_dims == 2 and parameter_partitioning_dims == 2:
|
540 |
+
rules = [
|
541 |
+
("batch", "data"),
|
542 |
+
("vocab", "model"),
|
543 |
+
("mlp", "model"),
|
544 |
+
("heads", "model"),
|
545 |
+
("kv", None),
|
546 |
+
("joined_kv", "model"),
|
547 |
+
("embed", "model"),
|
548 |
+
("embed", "data"),
|
549 |
+
]
|
550 |
+
else:
|
551 |
+
raise ValueError(
|
552 |
+
f"`activation_partitioning_dims` = {activation_partitioning_dims} "
|
553 |
+
f"`parameter_partitioning_dims` = {parameter_partitioning_dims} "
|
554 |
+
"is not supported."
|
555 |
+
)
|
556 |
+
|
557 |
+
# Add the common rules for the replicated logical axes names.
|
558 |
+
replicated_rules = [
|
559 |
+
("relpos_buckets", None),
|
560 |
+
("abspos_buckets", None),
|
561 |
+
("length", None),
|
562 |
+
("layers", None),
|
563 |
+
("stack", None),
|
564 |
+
("mlp_activations", None),
|
565 |
+
]
|
566 |
+
rules.extend(replicated_rules)
|
567 |
+
|
568 |
+
if additional_rules:
|
569 |
+
rules.extend(additional_rules)
|
570 |
+
|
571 |
+
return rules
|
572 |
+
|
573 |
+
|
574 |
+
# NB: This needs to be top-level for the jax compilation cache.
|
575 |
+
def _id_fn(x, ix):
|
576 |
+
"""Identity function for copying parameters to the devices, sharded."""
|
577 |
+
# A pure identity such as `lambda x, *: x` can get optimized away, so we
|
578 |
+
# include a random.split as a cheap function that cannot be optimized away.
|
579 |
+
y = random.split(random.PRNGKey(jnp.array(ix, dtype=jnp.uint32)))
|
580 |
+
return x, y
|
581 |
+
|
582 |
+
|
583 |
+
@dataclasses.dataclass
|
584 |
+
class DataLayout:
|
585 |
+
"""Represents data layout for the partitioned model."""
|
586 |
+
|
587 |
+
batch_size: int
|
588 |
+
shard_id: int
|
589 |
+
num_shards: int
|
590 |
+
is_first_host_in_replica_set: bool
|
591 |
+
|
592 |
+
|
593 |
+
PartitionedCallable = Callable[..., Any]
|
594 |
+
CompiledPartitionedCallable = Callable[..., Any]
|
595 |
+
|
596 |
+
|
597 |
+
class BasePartitioner(metaclass=abc.ABCMeta):
|
598 |
+
"""Interface for partitioning computations across hardware devices."""
|
599 |
+
|
600 |
+
def __init__(
|
601 |
+
self,
|
602 |
+
num_partitions: Optional[int] = None,
|
603 |
+
model_parallel_submesh: Optional[HardwareMesh] = None,
|
604 |
+
params_on_devices: bool = True,
|
605 |
+
backend: Optional[str] = None,
|
606 |
+
):
|
607 |
+
"""Configures the partitioner.
|
608 |
+
|
609 |
+
Args:
|
610 |
+
num_partitions: the number of partitions to use. Ignored if
|
611 |
+
`model_parallel_submesh` is provided.
|
612 |
+
model_parallel_submesh: 4-tuple that specifies the x,y,z,c submesh to use
|
613 |
+
as the model-parallel device tile. This submesh is used for the larger
|
614 |
+
of the two parameter dimensions, and, if 2-D activation sharding is
|
615 |
+
enabled, for the model dimension of activations. The rest of the mesh is
|
616 |
+
used for data parallelism and, if 2-D parameter sharding is enabled, the
|
617 |
+
other parameter dimension.
|
618 |
+
params_on_devices: whether to keep the params on devices, if False -
|
619 |
+
params stay in the host memory. Note that some partitioners might ignore
|
620 |
+
this setting, for example if they don't support storing all params on
|
621 |
+
device memory.
|
622 |
+
backend: get devices from the pinned backend, if specified. This is useful
|
623 |
+
for explicitly specifying the devices other than relying on
|
624 |
+
jax_platform_name.
|
625 |
+
"""
|
626 |
+
|
627 |
+
if not num_partitions and not model_parallel_submesh:
|
628 |
+
raise ValueError("At least one of `num_partitions` or " "`model_parallel_submesh` must be set.")
|
629 |
+
|
630 |
+
if model_parallel_submesh is not None and len(model_parallel_submesh) != 4:
|
631 |
+
logging.error(
|
632 |
+
(
|
633 |
+
"`model_parallel_submesh` must be either None or a 4-tuple. Got"
|
634 |
+
" `model_parallel_submesh`=%s. A ValueError will be raised"
|
635 |
+
" beginning March 1, 2022."
|
636 |
+
),
|
637 |
+
model_parallel_submesh,
|
638 |
+
)
|
639 |
+
|
640 |
+
if bool(num_partitions) and bool(model_parallel_submesh):
|
641 |
+
logging.error(
|
642 |
+
(
|
643 |
+
"At most one of `num_partitions` or `model_parallel_submesh` can be"
|
644 |
+
" set. Got `num_partitions=%s` and `model_parallel_submesh`=%s. A"
|
645 |
+
" ValueError will be raised beginning March 21, 2022."
|
646 |
+
),
|
647 |
+
num_partitions,
|
648 |
+
model_parallel_submesh,
|
649 |
+
)
|
650 |
+
|
651 |
+
self._num_partitions = num_partitions
|
652 |
+
self._model_parallel_submesh = model_parallel_submesh
|
653 |
+
self._params_on_devices = params_on_devices
|
654 |
+
self._data_axis = "data"
|
655 |
+
self._backend = backend
|
656 |
+
|
657 |
+
@property
|
658 |
+
def mesh(self) -> Mesh:
|
659 |
+
raise NotImplementedError
|
660 |
+
|
661 |
+
@property
|
662 |
+
def data_partition_spec(self) -> PartitionSpec:
|
663 |
+
return PartitionSpec(self._data_axis)
|
664 |
+
|
665 |
+
def get_data_layout(self, batch_size: Optional[int] = None, host_index: Optional[int] = None) -> DataLayout:
|
666 |
+
"""Returns filled `DataLayout` based on the partitioned model layout.
|
667 |
+
|
668 |
+
Args:
|
669 |
+
batch_size: if set, indicates the requested batch size. The exception will
|
670 |
+
be raised if this batch size is not compatible with the layout. If not
|
671 |
+
set, the batch size is inferred from the layout.
|
672 |
+
host_index: indicates the host index to use for the calculations, if not
|
673 |
+
set - use JAX-provided one. Should be in [0, num_hosts) interval and the
|
674 |
+
order should match the order of corresponding CPU devices in
|
675 |
+
`jax.devices()`.
|
676 |
+
|
677 |
+
Returns:
|
678 |
+
Filled `DataLayout` structure.
|
679 |
+
"""
|
680 |
+
if host_index is not None:
|
681 |
+
raise NotImplementedError("Explicit host_index is not yet implemented.")
|
682 |
+
if self._data_axis is None:
|
683 |
+
return DataLayout(
|
684 |
+
batch_size=batch_size,
|
685 |
+
shard_id=0,
|
686 |
+
num_shards=1,
|
687 |
+
is_first_host_in_replica_set=(jax.process_index() == 0),
|
688 |
+
)
|
689 |
+
mesh_size = self._local_chunker.global_mesh.shape[self._data_axis]
|
690 |
+
batch_size = batch_size or mesh_size
|
691 |
+
if batch_size % mesh_size:
|
692 |
+
raise ValueError(
|
693 |
+
f"Batch size ({batch_size}) must be divisible by corresponding " f"mesh size ({mesh_size})."
|
694 |
+
)
|
695 |
+
num_shards = self._local_chunker.num_chunks[self._data_axis]
|
696 |
+
if batch_size % num_shards:
|
697 |
+
raise ValueError(f"Batch size ({batch_size}) must be divisible by number of " f"replicas ({num_shards}).")
|
698 |
+
replica_id = self._local_chunker.get_local_chunk_info((batch_size,), [self._data_axis]).replica_id
|
699 |
+
return DataLayout(
|
700 |
+
batch_size=int(batch_size),
|
701 |
+
shard_id=int(self._local_chunker.chunk_ids[self._data_axis]),
|
702 |
+
num_shards=int(num_shards),
|
703 |
+
is_first_host_in_replica_set=(replica_id == 0),
|
704 |
+
)
|
705 |
+
|
706 |
+
def get_local_chunk_info(
|
707 |
+
self, global_shape: Tuple[int, ...], mesh_axes: Sequence[Optional[str]]
|
708 |
+
) -> LocalChunkInfo:
|
709 |
+
"""Returns the local chunk info for a given array shape and sharded axes."""
|
710 |
+
return self._local_chunker.get_local_chunk_info(global_shape, mesh_axes)
|
711 |
+
|
712 |
+
@property
|
713 |
+
def params_on_devices(self):
|
714 |
+
return self._params_on_devices
|
715 |
+
|
716 |
+
def move_params_to_devices(self, train_state: TrainState, train_state_axes: TrainState) -> TrainState:
|
717 |
+
"""Moves the optimizer parameters to devices."""
|
718 |
+
p_id_fn = self.partition(
|
719 |
+
_id_fn,
|
720 |
+
in_axis_resources=(train_state_axes, None),
|
721 |
+
out_axis_resources=(train_state_axes, None),
|
722 |
+
donate_argnums=(0,),
|
723 |
+
)
|
724 |
+
if jax.config.jax_array and jax.process_count() > 1:
|
725 |
+
train_state = multihost_utils.host_local_array_to_global_array(train_state, self.mesh, train_state_axes)
|
726 |
+
train_state, _ = p_id_fn(train_state, jnp.ones((), dtype=jnp.uint32))
|
727 |
+
return train_state
|
728 |
+
|
729 |
+
@property
|
730 |
+
@abc.abstractmethod
|
731 |
+
def _local_chunker(self):
|
732 |
+
"""Returns the chunker that matches the parameters of this partitioner."""
|
733 |
+
raise NotImplementedError
|
734 |
+
|
735 |
+
def get_logical_axes(self, train_state: TrainState) -> TrainState:
|
736 |
+
"""Returns a copy of TrainState with Optional[AxisNames] as leaves."""
|
737 |
+
# By default, return None for the logical axes.
|
738 |
+
return train_state.restore_state(jax.tree_map(lambda x: None, train_state.state_dict()))
|
739 |
+
|
740 |
+
def get_mesh_axes(self, train_state: TrainState) -> TrainState:
|
741 |
+
"""Returns a copy of TrainState with Optional[PartitionSpecs] as leaves."""
|
742 |
+
raise NotImplementedError
|
743 |
+
|
744 |
+
@abc.abstractmethod
|
745 |
+
def partition(
|
746 |
+
self,
|
747 |
+
fn: Callable, # pylint: disable=g-bare-generic
|
748 |
+
in_axis_resources,
|
749 |
+
out_axis_resources,
|
750 |
+
static_argnums: Union[int, Sequence[int]] = (),
|
751 |
+
donate_argnums: Union[int, Sequence[int]] = (),
|
752 |
+
) -> PartitionedCallable:
|
753 |
+
"""Partitions the computation using partitioner-specific implementation.
|
754 |
+
|
755 |
+
Args:
|
756 |
+
fn: the function to partition.
|
757 |
+
in_axis_resources: Pytree of structure matching that of arguments to `fn`,
|
758 |
+
with all actual arguments replaced by resource assignment
|
759 |
+
specifications. It is also valid to specify a pytree prefix (e.g. one
|
760 |
+
value in place of a whole subtree), in which case the leaves get
|
761 |
+
broadcast to all values in that subtree.
|
762 |
+
The valid resource assignment specifications are:
|
763 |
+
`None`: in which case the value will be replicated on all devices
|
764 |
+
`PartitionSpec`: a tuple of length at most equal to the rank of the
|
765 |
+
partitioned value. Each element can be a `None`, a mesh axis or a
|
766 |
+
tuple of mesh axes, and specifies the set of resources assigned to
|
767 |
+
partition the value's dimension matching its position in the spec.
|
768 |
+
out_axis_resources: Like `in_axis_resources`, but specifies resource
|
769 |
+
assignment for function outputs.
|
770 |
+
static_argnums: an optional int or collection of ints that specify which
|
771 |
+
positional arguments to treat as static (compile-time constant) in the
|
772 |
+
partitioned function.
|
773 |
+
donate_argnums: an optional int or collection of ints that specify which
|
774 |
+
argument buffers are "donated" to the computation. It is safe to donate
|
775 |
+
argument buffers if you no longer need them once the computation has
|
776 |
+
finished.
|
777 |
+
|
778 |
+
Returns:
|
779 |
+
A partitioned version of the input function.
|
780 |
+
"""
|
781 |
+
raise NotImplementedError
|
782 |
+
|
783 |
+
@abc.abstractmethod
|
784 |
+
def compile(self, partitioned_fn: PartitionedCallable, *args) -> CompiledPartitionedCallable:
|
785 |
+
"""Compiles and returns the partitioned function, or the original.
|
786 |
+
|
787 |
+
Args:
|
788 |
+
partitioned_fn: The partitioned function.
|
789 |
+
*args: Sample arguments to the partitioned function matching the input
|
790 |
+
shapes that will be passed to the compiled function.
|
791 |
+
|
792 |
+
Returns:
|
793 |
+
The compiled function, or the original if this partitioner does not
|
794 |
+
support compilation.
|
795 |
+
"""
|
796 |
+
raise NotImplementedError
|
797 |
+
|
798 |
+
|
799 |
+
class PjittedFnWithContext(PartitionedCallable):
|
800 |
+
"""Wraps pjitted function to apply the appropriate contexts."""
|
801 |
+
|
802 |
+
def __init__(
|
803 |
+
self,
|
804 |
+
pjitted_fn,
|
805 |
+
partition_mesh: Mesh,
|
806 |
+
logical_axis_rules: flax_partitioning.LogicalRules = (),
|
807 |
+
):
|
808 |
+
self._pjitted_fn = pjitted_fn
|
809 |
+
self._mesh = partition_mesh
|
810 |
+
self._logical_axis_rules = logical_axis_rules
|
811 |
+
|
812 |
+
def __call__(self, *args):
|
813 |
+
with Mesh(self._mesh.devices, self._mesh.axis_names), flax_partitioning.axis_rules(self._logical_axis_rules):
|
814 |
+
return self._pjitted_fn(*args)
|
815 |
+
|
816 |
+
def lower(self, *args):
|
817 |
+
with Mesh(self._mesh.devices, self._mesh.axis_names), flax_partitioning.axis_rules(self._logical_axis_rules):
|
818 |
+
return self._pjitted_fn.lower(*args)
|
819 |
+
|
820 |
+
|
821 |
+
class BasePjitPartitioner(BasePartitioner):
|
822 |
+
"""Partitioner that uses T5X version of jax.pjit."""
|
823 |
+
|
824 |
+
@cached_property
|
825 |
+
def _local_chunker(self) -> LocalChunker:
|
826 |
+
return LocalChunker(self.mesh)
|
827 |
+
|
828 |
+
@cached_property
|
829 |
+
def mesh(self) -> Mesh:
|
830 |
+
return default_mesh(self._num_partitions, self._model_parallel_submesh, self._backend)
|
831 |
+
|
832 |
+
def partition(
|
833 |
+
self,
|
834 |
+
fn: Callable, # pylint: disable=g-bare-generic
|
835 |
+
in_axis_resources,
|
836 |
+
out_axis_resources,
|
837 |
+
static_argnums: Union[int, Sequence[int]] = (),
|
838 |
+
donate_argnums: Union[int, Sequence[int]] = (),
|
839 |
+
) -> PjittedFnWithContext:
|
840 |
+
pjitted = pjit(
|
841 |
+
fn,
|
842 |
+
in_axis_resources=in_axis_resources,
|
843 |
+
out_axis_resources=out_axis_resources,
|
844 |
+
static_argnums=static_argnums,
|
845 |
+
donate_argnums=donate_argnums,
|
846 |
+
backend=self._backend,
|
847 |
+
)
|
848 |
+
|
849 |
+
return PjittedFnWithContext(pjitted, self.mesh)
|
850 |
+
|
851 |
+
def compile(self, partitioned_fn: PjittedFnWithContext, *args) -> CompiledPartitionedCallable:
|
852 |
+
return partitioned_fn.lower(*args).compile()
|
853 |
+
|
854 |
+
|
855 |
+
class PjitPartitioner(BasePjitPartitioner):
|
856 |
+
"""Partitioner that uses named axes and jax.pjit."""
|
857 |
+
|
858 |
+
def __init__(
|
859 |
+
self,
|
860 |
+
num_partitions: Optional[int] = None,
|
861 |
+
model_parallel_submesh: Optional[HardwareMesh] = None,
|
862 |
+
params_on_devices: bool = True,
|
863 |
+
backend: Optional[str] = None,
|
864 |
+
logical_axis_rules: Optional[LogicalAxisRules] = None,
|
865 |
+
use_cpu_pjit: Optional[bool] = False,
|
866 |
+
):
|
867 |
+
"""PjitPartitioner constructor.
|
868 |
+
|
869 |
+
See https://github.com/google-research/text-to-text-transfer-transformer/blob/main/README.mdx/usage/partitioning for details.
|
870 |
+
|
871 |
+
Args:
|
872 |
+
num_partitions: an integer that specifies the size of the model parallel
|
873 |
+
submesh to be automatically selected for the current topology. See
|
874 |
+
`model_parallel_submesh` for details on how this submesh is used.
|
875 |
+
Mutually exlusive with `model_parallel_submesh`.
|
876 |
+
model_parallel_submesh: is a 4-tuple that specifies the `(x, y, z, c)`
|
877 |
+
submesh model-parallel device tile, an axis of accelerator parallelism
|
878 |
+
orthogonal to data parallelism. Array axes in a model's parameters or
|
879 |
+
activations can be sharded over this submesh using axis rules (see
|
880 |
+
`logical_axis_rules`) that map them to 'model'. The effective number of
|
881 |
+
model sub-partitions is equal to `np.prod(model_parallel_submesh)` and
|
882 |
+
must evenly divide the total number of devices (i.e.,
|
883 |
+
`jax.device_count() % np.prod(model_parallel_submesh) == 0`). The rest
|
884 |
+
of the TPU mesh is the data parallel submesh, providing
|
885 |
+
`jax.device_count() // np.prod(model_parallel_submesh)` partitions. It
|
886 |
+
is used for data (batch) parallelism and to shard other array axes that
|
887 |
+
are mapped to 'data'. This argument is mutually exclusive with
|
888 |
+
`num_partitions`.
|
889 |
+
params_on_devices: whether to keep the params on devices, if False -
|
890 |
+
params stay in the host memory. Note that some partitioners might ignore
|
891 |
+
this setting, for example if they don't support storing all params on
|
892 |
+
device memory.
|
893 |
+
backend: get devices from the pinned backend, if specified. This is
|
894 |
+
useful for explicitly specifying the devices other than relying on
|
895 |
+
jax_platform_name.
|
896 |
+
logical_axis_rules: a priority-ordered sequence of KV tuples that maps
|
897 |
+
logical axis names to either `None` (not sharded), 'model' (to shard
|
898 |
+
across the model-parallel submesh), or 'data' (to shard across the
|
899 |
+
data-parallel submesh).
|
900 |
+
use_cpu_pjit: enables wrapper function for pjit which just jits the
|
901 |
+
function if using CPU backend.
|
902 |
+
"""
|
903 |
+
super().__init__(
|
904 |
+
num_partitions=num_partitions,
|
905 |
+
model_parallel_submesh=model_parallel_submesh,
|
906 |
+
params_on_devices=params_on_devices,
|
907 |
+
backend=backend,
|
908 |
+
)
|
909 |
+
if logical_axis_rules is None:
|
910 |
+
logical_axis_rules = standard_logical_axis_rules()
|
911 |
+
self._logical_axis_rules = tuple(logical_axis_rules)
|
912 |
+
(self._data_axis,) = flax_partitioning.logical_to_mesh_axes(["batch"], logical_axis_rules)
|
913 |
+
self._use_cpu_pjit = use_cpu_pjit
|
914 |
+
|
915 |
+
def partition(
|
916 |
+
self,
|
917 |
+
fn: Callable, # pylint: disable=g-bare-generic
|
918 |
+
in_axis_resources,
|
919 |
+
out_axis_resources,
|
920 |
+
static_argnums: Union[int, Sequence[int]] = (),
|
921 |
+
donate_argnums: Union[int, Sequence[int]] = (),
|
922 |
+
) -> PjittedFnWithContext:
|
923 |
+
"""Partitions the function using jax.pjit."""
|
924 |
+
if self._use_cpu_pjit:
|
925 |
+
pjit_fn = pjit_with_cpu_fallback
|
926 |
+
else:
|
927 |
+
pjit_fn = pjit
|
928 |
+
pjitted = pjit_fn(
|
929 |
+
fn,
|
930 |
+
in_axis_resources=in_axis_resources,
|
931 |
+
out_axis_resources=out_axis_resources,
|
932 |
+
static_argnums=static_argnums,
|
933 |
+
donate_argnums=donate_argnums,
|
934 |
+
backend=self._backend,
|
935 |
+
)
|
936 |
+
|
937 |
+
return PjittedFnWithContext(pjitted, self.mesh, self._logical_axis_rules)
|
938 |
+
|
939 |
+
@property
|
940 |
+
def logical_axis_rules(self):
|
941 |
+
"""Returns the logical axis rules."""
|
942 |
+
return self._logical_axis_rules
|
943 |
+
|
944 |
+
def get_logical_axes(self, train_state: TrainState) -> TrainState:
|
945 |
+
"""Returns a copy of TrainState with Optional[AxisNames] as leaves."""
|
946 |
+
return train_state.as_logical_axes()
|
947 |
+
|
948 |
+
def get_mesh_axes(self, train_state: TrainState) -> TrainState:
|
949 |
+
"""Returns a copy of TrainState with Optional[PartitionSpecs] as leaves."""
|
950 |
+
logical_axes = self.get_logical_axes(train_state)
|
951 |
+
|
952 |
+
def _logical_to_mesh_axes(param_name, logical_axes):
|
953 |
+
if logical_axes is None:
|
954 |
+
return None
|
955 |
+
elif logical_axes is traverse_util.empty_node:
|
956 |
+
return traverse_util.empty_node
|
957 |
+
try:
|
958 |
+
return flax_partitioning.logical_to_mesh_axes(logical_axes, self._logical_axis_rules)
|
959 |
+
except ValueError as e:
|
960 |
+
raise ValueError(f"Failed to map logical axes for {param_name}") from e
|
961 |
+
|
962 |
+
flat_logical_axes = traverse_util.flatten_dict(logical_axes.state_dict(), keep_empty_nodes=True, sep="/")
|
963 |
+
flat_mesh_axes = {k: _logical_to_mesh_axes(k, v) for k, v in flat_logical_axes.items()}
|
964 |
+
|
965 |
+
return logical_axes.restore_state(traverse_util.unflatten_dict(flat_mesh_axes, sep="/"))
|
distil_whisper/pipeline.py
ADDED
@@ -0,0 +1,525 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# coding=utf-8
|
2 |
+
# Copyright 2023 The HuggingFace Inc. team.
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
|
16 |
+
"""Whisper JAX pipeline compatible with Distil Whisper checkpoints. Copied from https://github.com/sanchit-gandhi/whisper-jax/blob/main/whisper_jax/pipeline.py"""
|
17 |
+
|
18 |
+
import math
|
19 |
+
|
20 |
+
import jax
|
21 |
+
import jax.numpy as jnp
|
22 |
+
import numpy as np
|
23 |
+
import requests
|
24 |
+
import torch
|
25 |
+
from flax import jax_utils
|
26 |
+
from flax.core.frozen_dict import freeze
|
27 |
+
from flax.training.common_utils import shard
|
28 |
+
from transformers import WhisperFeatureExtractor, WhisperTokenizerFast
|
29 |
+
from transformers.models.whisper.tokenization_whisper import TO_LANGUAGE_CODE
|
30 |
+
from transformers.pipelines.audio_utils import ffmpeg_read
|
31 |
+
from transformers.utils import logging
|
32 |
+
|
33 |
+
from .modeling_flax_whisper import FlaxWhisperForConditionalGeneration
|
34 |
+
|
35 |
+
|
36 |
+
logger = logging.get_logger(__name__)
|
37 |
+
|
38 |
+
|
39 |
+
class FlaxWhisperFeatureExtractor(WhisperFeatureExtractor):
|
40 |
+
def _np_extract_fbank_features(self, waveform: np.array) -> np.ndarray:
|
41 |
+
"""
|
42 |
+
Compute the log-mel spectrogram of the provided audio using torch filters. Using the torch implementation
|
43 |
+
computes stft filter banks approx 5x faster than its numpy counterpart, which is the native implementation
|
44 |
+
in transformers, and matches to within 1e-5 abs tolerance.
|
45 |
+
"""
|
46 |
+
waveform = torch.from_numpy(waveform).type(torch.float32)
|
47 |
+
|
48 |
+
window = torch.hann_window(self.n_fft)
|
49 |
+
stft = torch.stft(waveform, self.n_fft, self.hop_length, window=window, return_complex=True)
|
50 |
+
magnitudes = stft[..., :-1].abs() ** 2
|
51 |
+
|
52 |
+
mel_filters = torch.from_numpy(self.mel_filters).type(torch.float32)
|
53 |
+
mel_spec = mel_filters.T @ magnitudes
|
54 |
+
|
55 |
+
log_spec = torch.clamp(mel_spec, min=1e-10).log10()
|
56 |
+
log_spec = torch.maximum(log_spec, log_spec.max() - 8.0)
|
57 |
+
log_spec = (log_spec + 4.0) / 4.0
|
58 |
+
return log_spec.numpy()
|
59 |
+
|
60 |
+
|
61 |
+
class FlaxWhisperPipeline:
|
62 |
+
def __init__(
|
63 |
+
self,
|
64 |
+
checkpoint="openai/whisper-large-v2",
|
65 |
+
dtype=jnp.float32,
|
66 |
+
batch_size=None,
|
67 |
+
max_length=None,
|
68 |
+
**kwargs,
|
69 |
+
):
|
70 |
+
"""
|
71 |
+
Args
|
72 |
+
checkpoint (`str`, *optional*, defaults to `"openai/whisper-large-v2"):
|
73 |
+
The Whisper checkpoint to use with the pipeline. Must be an available checkpoint on the Hugging Face Hub
|
74 |
+
with Flax weights.
|
75 |
+
dtype (`jax.numpy.dtype`, *optional*, defaults to `jax.numpy.float32`):
|
76 |
+
The data type of the computation. Can be one of `jax.numpy.float32`, `jax.numpy.float16` (on GPUs) and
|
77 |
+
`jax.numpy.bfloat16` (on TPUs). This can be used to enable half-precision inference on GPUs or TPUs.
|
78 |
+
If specified all the computation will be performed with the given `dtype`. **Note that this only
|
79 |
+
specifies the dtype of the computation and does not influence the dtype of model parameters.**
|
80 |
+
batch_size (`int`, *optional*, defaults to the minimum per-device batch size, i.e. `jax.local_device_count()`):
|
81 |
+
The batch size to be used in chunking transcription. Beneficial for transcribing long audio files. Passing
|
82 |
+
a batch size in the `__init__` method will be superseded by any batch size passed to the `__call__` method.
|
83 |
+
max_length (`int`, *optional*):
|
84 |
+
The maximum numbers of tokens to generate. Defaults to `model.config.max_length`.
|
85 |
+
"""
|
86 |
+
self.checkpoint = checkpoint
|
87 |
+
self.dtype = dtype
|
88 |
+
|
89 |
+
self.feature_extractor = FlaxWhisperFeatureExtractor.from_pretrained(self.checkpoint)
|
90 |
+
self.tokenizer = WhisperTokenizerFast.from_pretrained(self.checkpoint)
|
91 |
+
|
92 |
+
self.model, self.params = FlaxWhisperForConditionalGeneration.from_pretrained(
|
93 |
+
self.checkpoint,
|
94 |
+
_do_init=False,
|
95 |
+
dtype=self.dtype,
|
96 |
+
**kwargs,
|
97 |
+
)
|
98 |
+
|
99 |
+
self.max_length = max_length if max_length is not None else self.model.generation_config.max_length
|
100 |
+
self.min_batch_size = jax.local_device_count()
|
101 |
+
self.batch_size = (
|
102 |
+
batch_size if batch_size is not None else self.min_batch_size
|
103 |
+
) # we need a minimum of 1 batch per-device
|
104 |
+
|
105 |
+
def generate(
|
106 |
+
params,
|
107 |
+
input_features,
|
108 |
+
forced_decoder_ids,
|
109 |
+
return_timestamps,
|
110 |
+
num_beams,
|
111 |
+
length_penalty,
|
112 |
+
do_sample,
|
113 |
+
top_k,
|
114 |
+
temperature,
|
115 |
+
):
|
116 |
+
output_ids = self.model.pipeline_generate(
|
117 |
+
input_features,
|
118 |
+
params=params,
|
119 |
+
forced_decoder_ids=forced_decoder_ids,
|
120 |
+
return_timestamps=return_timestamps,
|
121 |
+
max_length=self.max_length,
|
122 |
+
num_beams=num_beams,
|
123 |
+
length_penalty=length_penalty,
|
124 |
+
do_sample=do_sample,
|
125 |
+
top_k=top_k,
|
126 |
+
temperature=temperature,
|
127 |
+
)
|
128 |
+
return output_ids
|
129 |
+
|
130 |
+
self.params = jax_utils.replicate(self.params)
|
131 |
+
self.p_generate = jax.pmap(
|
132 |
+
generate,
|
133 |
+
"input_features",
|
134 |
+
in_axes=(0, 0, None, None, None, None, None, None, None),
|
135 |
+
static_broadcasted_argnums=(
|
136 |
+
3,
|
137 |
+
4,
|
138 |
+
5,
|
139 |
+
6,
|
140 |
+
7,
|
141 |
+
8,
|
142 |
+
),
|
143 |
+
)
|
144 |
+
|
145 |
+
def generate(
|
146 |
+
self,
|
147 |
+
input_features,
|
148 |
+
language=None,
|
149 |
+
task=None,
|
150 |
+
return_timestamps=False,
|
151 |
+
num_beams=1,
|
152 |
+
length_penalty=1.0,
|
153 |
+
do_sample=False,
|
154 |
+
top_k=50,
|
155 |
+
temperature=1.0,
|
156 |
+
):
|
157 |
+
forced_decoder_ids = self.get_forced_decoder_ids(
|
158 |
+
language=language, task=task, return_timestamps=return_timestamps
|
159 |
+
)
|
160 |
+
# if we're using pmap we need to manually replicate the input data across devices and gather the output tokens
|
161 |
+
output_ids = self.p_generate(
|
162 |
+
freeze(self.params),
|
163 |
+
shard(input_features),
|
164 |
+
forced_decoder_ids,
|
165 |
+
return_timestamps,
|
166 |
+
num_beams,
|
167 |
+
length_penalty,
|
168 |
+
do_sample,
|
169 |
+
top_k,
|
170 |
+
temperature,
|
171 |
+
).sequences
|
172 |
+
output_ids = jax.device_get(output_ids.reshape(-1, self.max_length))
|
173 |
+
return output_ids
|
174 |
+
|
175 |
+
def get_forced_decoder_ids(self, generation_config=None, task=None, language=None, return_timestamps=False):
|
176 |
+
if generation_config is None:
|
177 |
+
generation_config = self.model.generation_config
|
178 |
+
|
179 |
+
if hasattr(generation_config, "is_multilingual"):
|
180 |
+
is_multilingual = generation_config.is_multilingual
|
181 |
+
else:
|
182 |
+
is_multilingual = None
|
183 |
+
|
184 |
+
forced_decoder_ids = []
|
185 |
+
|
186 |
+
if is_multilingual:
|
187 |
+
if language is not None:
|
188 |
+
language = language.lower()
|
189 |
+
if language in generation_config.lang_to_id.keys():
|
190 |
+
language_token = language
|
191 |
+
elif language in TO_LANGUAGE_CODE.values():
|
192 |
+
language_token = f"<|{language}|>"
|
193 |
+
elif language in TO_LANGUAGE_CODE.keys():
|
194 |
+
language_token = f"<|{TO_LANGUAGE_CODE[language]}|>"
|
195 |
+
else:
|
196 |
+
if len(language) == 2:
|
197 |
+
# ISO 639-1 language code
|
198 |
+
acceptable_languages = list(TO_LANGUAGE_CODE.values())
|
199 |
+
elif "<" in language or "|" in language or ">" in language:
|
200 |
+
# generation config language code
|
201 |
+
acceptable_languages = list(generation_config.lang_to_id.keys())
|
202 |
+
else:
|
203 |
+
# language passed as a string
|
204 |
+
acceptable_languages = list(TO_LANGUAGE_CODE.keys())
|
205 |
+
raise ValueError(
|
206 |
+
f"Unsupported language: {language}. Language should be one of:" f" {acceptable_languages}."
|
207 |
+
)
|
208 |
+
forced_decoder_ids.append((1, generation_config.lang_to_id[language_token]))
|
209 |
+
|
210 |
+
if task is not None:
|
211 |
+
forced_decoder_ids.append((2, generation_config.task_to_id[task]))
|
212 |
+
else:
|
213 |
+
forced_decoder_ids.append((2, generation_config.task_to_id["transcribe"]))
|
214 |
+
|
215 |
+
if not return_timestamps:
|
216 |
+
if forced_decoder_ids and forced_decoder_ids[-1][0] != generation_config.no_timestamps_token_id:
|
217 |
+
idx = forced_decoder_ids[-1][0] + 1 if forced_decoder_ids else 1
|
218 |
+
forced_decoder_ids.append((idx, generation_config.no_timestamps_token_id))
|
219 |
+
|
220 |
+
return forced_decoder_ids
|
221 |
+
|
222 |
+
def chunk_iter_with_batch(self, inputs, chunk_len, stride_left, stride_right, batch_size):
|
223 |
+
inputs_len = inputs.shape[0]
|
224 |
+
step = chunk_len - stride_left - stride_right
|
225 |
+
|
226 |
+
all_chunk_start_idx = np.arange(0, inputs_len, step)
|
227 |
+
num_samples = len(all_chunk_start_idx)
|
228 |
+
|
229 |
+
num_batches = math.ceil(num_samples / batch_size)
|
230 |
+
batch_idx = np.array_split(np.arange(num_samples), num_batches)
|
231 |
+
|
232 |
+
for idx in batch_idx:
|
233 |
+
chunk_start_idx = all_chunk_start_idx[idx]
|
234 |
+
|
235 |
+
chunk_end_idx = chunk_start_idx + chunk_len
|
236 |
+
|
237 |
+
chunks = [inputs[chunk_start:chunk_end] for chunk_start, chunk_end in zip(chunk_start_idx, chunk_end_idx)]
|
238 |
+
processed = self.feature_extractor(
|
239 |
+
chunks, sampling_rate=self.feature_extractor.sampling_rate, return_tensors="np"
|
240 |
+
)
|
241 |
+
|
242 |
+
_stride_left = np.where(chunk_start_idx == 0, 0, stride_left)
|
243 |
+
is_last = np.where(stride_right > 0, chunk_end_idx > inputs_len, chunk_end_idx >= inputs_len)
|
244 |
+
_stride_right = np.where(is_last, 0, stride_right)
|
245 |
+
|
246 |
+
chunk_lens = [chunk.shape[0] for chunk in chunks]
|
247 |
+
strides = [
|
248 |
+
(chunk_l, _stride_l, _stride_r)
|
249 |
+
for chunk_l, _stride_l, _stride_r in zip(chunk_lens, _stride_left, _stride_right)
|
250 |
+
]
|
251 |
+
|
252 |
+
yield {"stride": strides, **processed}
|
253 |
+
|
254 |
+
def preprocess_batch(self, inputs, chunk_length_s=30.0, stride_length_s=None, batch_size=None):
|
255 |
+
if isinstance(inputs, np.ndarray):
|
256 |
+
logger.warning(
|
257 |
+
"Numpy array passed as input - no sampling rate checks will be performed."
|
258 |
+
"It is strongly recommended to pass the input as a dictionary with an 'array' key "
|
259 |
+
"containing the numpy array representing the audio, and a 'sampling_rate' key "
|
260 |
+
"containing the sampling rate associated with the audio array."
|
261 |
+
"Failing to do so can result in silent errors that might be hard to debug."
|
262 |
+
)
|
263 |
+
|
264 |
+
if isinstance(inputs, str):
|
265 |
+
if inputs.startswith("http://") or inputs.startswith("https://"):
|
266 |
+
# We need to actually check for a real protocol, otherwise it's impossible to use a local file
|
267 |
+
# like http_huggingface_co.png
|
268 |
+
inputs = requests.get(inputs).content
|
269 |
+
else:
|
270 |
+
with open(inputs, "rb") as f:
|
271 |
+
inputs = f.read()
|
272 |
+
|
273 |
+
if isinstance(inputs, bytes):
|
274 |
+
inputs = ffmpeg_read(inputs, self.feature_extractor.sampling_rate)
|
275 |
+
|
276 |
+
stride = None
|
277 |
+
if isinstance(inputs, dict):
|
278 |
+
stride = inputs.get("stride", None)
|
279 |
+
# Accepting `"array"` which is the key defined in `datasets` for
|
280 |
+
# better integration
|
281 |
+
if not ("sampling_rate" in inputs and "array" in inputs):
|
282 |
+
raise ValueError(
|
283 |
+
"When passing a dictionary to FlaxWhisperPipline, the dict needs to contain an 'array' key "
|
284 |
+
"containing the numpy array representing the audio, and a 'sampling_rate' key "
|
285 |
+
"containing the sampling rate associated with the audio array."
|
286 |
+
)
|
287 |
+
|
288 |
+
in_sampling_rate = inputs.get("sampling_rate")
|
289 |
+
inputs = inputs.get("array", None)
|
290 |
+
|
291 |
+
if in_sampling_rate != self.feature_extractor.sampling_rate:
|
292 |
+
try:
|
293 |
+
import librosa
|
294 |
+
except ImportError as err:
|
295 |
+
raise ImportError(
|
296 |
+
"To support resampling audio files, please install 'librosa' and 'soundfile'."
|
297 |
+
) from err
|
298 |
+
|
299 |
+
inputs = librosa.resample(
|
300 |
+
inputs, orig_sr=in_sampling_rate, target_sr=self.feature_extractor.sampling_rate
|
301 |
+
)
|
302 |
+
ratio = self.feature_extractor.sampling_rate / in_sampling_rate
|
303 |
+
else:
|
304 |
+
ratio = 1
|
305 |
+
|
306 |
+
if not isinstance(inputs, np.ndarray):
|
307 |
+
raise ValueError(f"We expect a numpy ndarray as input, got `{type(inputs)}`")
|
308 |
+
if len(inputs.shape) != 1:
|
309 |
+
raise ValueError("We expect a single channel audio input for AutomaticSpeechRecognitionPipeline")
|
310 |
+
|
311 |
+
if stride is not None:
|
312 |
+
if stride[0] + stride[1] > inputs.shape[0]:
|
313 |
+
raise ValueError("Stride is too large for input")
|
314 |
+
|
315 |
+
# Stride needs to get the chunk length here, it's going to get
|
316 |
+
# swallowed by the `feature_extractor` later, and then batching
|
317 |
+
# can add extra data in the inputs, so we need to keep track
|
318 |
+
# of the original length in the stride so we can cut properly.
|
319 |
+
stride = (inputs.shape[0], int(round(stride[0] * ratio)), int(round(stride[1] * ratio)))
|
320 |
+
|
321 |
+
if chunk_length_s:
|
322 |
+
if stride_length_s is None:
|
323 |
+
stride_length_s = chunk_length_s / 6
|
324 |
+
|
325 |
+
if isinstance(stride_length_s, (int, float)):
|
326 |
+
stride_length_s = [stride_length_s, stride_length_s]
|
327 |
+
|
328 |
+
chunk_len = round(chunk_length_s * self.feature_extractor.sampling_rate)
|
329 |
+
stride_left = round(stride_length_s[0] * self.feature_extractor.sampling_rate)
|
330 |
+
stride_right = round(stride_length_s[1] * self.feature_extractor.sampling_rate)
|
331 |
+
|
332 |
+
if chunk_len < stride_left + stride_right:
|
333 |
+
raise ValueError("Chunk length must be superior to stride length")
|
334 |
+
|
335 |
+
for item in self.chunk_iter_with_batch(
|
336 |
+
inputs,
|
337 |
+
chunk_len,
|
338 |
+
stride_left,
|
339 |
+
stride_right,
|
340 |
+
batch_size,
|
341 |
+
):
|
342 |
+
yield item
|
343 |
+
else:
|
344 |
+
processed = self.feature_extractor(
|
345 |
+
inputs, sampling_rate=self.feature_extractor.sampling_rate, return_tensors="np"
|
346 |
+
)
|
347 |
+
if stride is not None:
|
348 |
+
processed["stride"] = stride
|
349 |
+
yield processed
|
350 |
+
|
351 |
+
def postprocess(self, model_outputs, return_timestamps=None, return_language=None):
|
352 |
+
# unpack the outputs from list(dict(list)) to list(dict)
|
353 |
+
model_outputs = [dict(zip(output, t)) for output in model_outputs for t in zip(*output.values())]
|
354 |
+
|
355 |
+
time_precision = self.feature_extractor.chunk_length / self.model.config.max_source_positions
|
356 |
+
# Send the chunking back to seconds, it's easier to handle in whisper
|
357 |
+
sampling_rate = self.feature_extractor.sampling_rate
|
358 |
+
for output in model_outputs:
|
359 |
+
if "stride" in output:
|
360 |
+
chunk_len, stride_left, stride_right = output["stride"]
|
361 |
+
# Go back in seconds
|
362 |
+
chunk_len /= sampling_rate
|
363 |
+
stride_left /= sampling_rate
|
364 |
+
stride_right /= sampling_rate
|
365 |
+
output["stride"] = chunk_len, stride_left, stride_right
|
366 |
+
|
367 |
+
text, optional = self.tokenizer._decode_asr(
|
368 |
+
model_outputs,
|
369 |
+
return_timestamps=return_timestamps,
|
370 |
+
return_language=return_language,
|
371 |
+
time_precision=time_precision,
|
372 |
+
)
|
373 |
+
return {"text": text, **optional}
|
374 |
+
|
375 |
+
def forward(
|
376 |
+
self,
|
377 |
+
model_inputs,
|
378 |
+
batch_size=None,
|
379 |
+
language=None,
|
380 |
+
task=None,
|
381 |
+
return_timestamps=False,
|
382 |
+
num_beams=1,
|
383 |
+
length_penalty=1.0,
|
384 |
+
do_sample=False,
|
385 |
+
top_k=50,
|
386 |
+
temperature=1.0,
|
387 |
+
):
|
388 |
+
# We need to keep track of some additional input arguments for post-processing so need to forward these on after running generation
|
389 |
+
input_features = model_inputs.pop("input_features")
|
390 |
+
input_batch_size = input_features.shape[0]
|
391 |
+
|
392 |
+
if input_batch_size != batch_size:
|
393 |
+
padding = np.zeros([batch_size - input_batch_size, *input_features.shape[1:]], input_features.dtype)
|
394 |
+
input_features = np.concatenate([input_features, padding])
|
395 |
+
|
396 |
+
pred_ids = self.generate(
|
397 |
+
input_features,
|
398 |
+
language=language,
|
399 |
+
task=task,
|
400 |
+
return_timestamps=return_timestamps,
|
401 |
+
num_beams=num_beams,
|
402 |
+
length_penalty=length_penalty,
|
403 |
+
do_sample=do_sample,
|
404 |
+
top_k=top_k,
|
405 |
+
temperature=temperature,
|
406 |
+
)[:input_batch_size]
|
407 |
+
|
408 |
+
# tokenizer's decode method expects an extra dim - we insert it here for convenience
|
409 |
+
out = {"tokens": pred_ids[:, None, :]}
|
410 |
+
|
411 |
+
stride = model_inputs.pop("stride", None)
|
412 |
+
if stride is not None:
|
413 |
+
out["stride"] = stride
|
414 |
+
|
415 |
+
return out
|
416 |
+
|
417 |
+
def __call__(
|
418 |
+
self,
|
419 |
+
inputs,
|
420 |
+
chunk_length_s=30.0,
|
421 |
+
stride_length_s=None,
|
422 |
+
batch_size=None,
|
423 |
+
language=None,
|
424 |
+
task=None,
|
425 |
+
return_timestamps=None,
|
426 |
+
num_beams=1,
|
427 |
+
length_penalty=1.0,
|
428 |
+
do_sample=False,
|
429 |
+
top_k=50,
|
430 |
+
temperature=1.0,
|
431 |
+
):
|
432 |
+
"""
|
433 |
+
Transcribe an audio input sequence to a text transcription, optionally with timestamps.
|
434 |
+
|
435 |
+
Args:
|
436 |
+
inputs (`np.ndarray` or `bytes` or `str` or `dict`):
|
437 |
+
The inputs is either:
|
438 |
+
- `str` that is the filename of the audio file, the file will be read at the correct sampling rate
|
439 |
+
to get the waveform using *ffmpeg*. This requires *ffmpeg* to be installed on the system.
|
440 |
+
- `bytes` is the byte content of an audio file and is interpreted by *ffmpeg* in the
|
441 |
+
same way.
|
442 |
+
- (`np.ndarray` of shape (n, ) of type `np.float32` or `np.float64`)
|
443 |
+
Raw audio assumed to be at the correct sampling rate (16kHz). Note that no further sampling
|
444 |
+
rate check will be done.
|
445 |
+
- `dict` form can be used to pass raw audio sampled at arbitrary `sampling_rate` and let this
|
446 |
+
pipeline do the resampling. The dict must be in the format `{"sampling_rate": int, "array":
|
447 |
+
np.array}`. Optionally an additional argument `"stride": (left: int, right: int)` can be used to
|
448 |
+
ask the pipeline to treat the first `left` samples and last `right` samples to be ignored in
|
449 |
+
decoding (but used at inference to provide more context to the model). In general, this additional
|
450 |
+
stride argument is not required.
|
451 |
+
chunk_length_s (`float`, *optional*, defaults to 30.0):
|
452 |
+
The input length for each chunk. If `chunk_length_s = 0` then chunking is disabled. By default, the chunk
|
453 |
+
length is set 30.0s, equal to Whisper's context window.
|
454 |
+
stride_length_s (`float`, *optional*, defaults to `chunk_length_s / 6`):
|
455 |
+
The length of stride on the left and right of each chunk. Used only with `chunk_length_s > 0`. This enables
|
456 |
+
the model to *see* more context and infer letters better than without this context but the pipeline
|
457 |
+
discards the stride bits at the end to make the final reconstitution as perfect as possible.
|
458 |
+
|
459 |
+
<Tip>
|
460 |
+
|
461 |
+
For more information on how to effectively use `stride_length_s`, refer to the [ASR chunking
|
462 |
+
blog post](https://huggingface.co/blog/asr-chunking).
|
463 |
+
|
464 |
+
</Tip>
|
465 |
+
batch_size (`int`, *optional*, defaults to the minimum per-device batch size, i.e. `jax.local_device_count()`):
|
466 |
+
The batch size to be used in chunking transcription. Beneficial for transcribing long audio files. Passing
|
467 |
+
a batch size in the `__call__` method will supersede any batch size passed to the `__init__`.
|
468 |
+
task (`str`, *optional*):
|
469 |
+
Task to use for generation, either `"transcribe"` or `"translate"`. Defaults to `"transcribe"`.
|
470 |
+
language (`str`, *optional*):
|
471 |
+
Language token to use for generation, can be either in the form of `"<|en|>"`, `"en"` or `"english"`.
|
472 |
+
Defaults to `None`, meaning the language is automatically inferred from the audio input.
|
473 |
+
return_timestamps (*optional*, `bool`):
|
474 |
+
Whether to return timestamps in the prediction. Defaults to False. If set to true, the pipeline
|
475 |
+
will return two keys in the output dictionary: `"text"` containing the text transcription, and `"chunks"`
|
476 |
+
containing the transcription segments chunked by their utterance-level timestamps.
|
477 |
+
length_penalty (*optional*, `float`):
|
478 |
+
Exponential penalty to the length that is used with beam-based generation. It is applied as an
|
479 |
+
exponent to the sequence length, which in turn is used to divide the score of the sequence. Since
|
480 |
+
the score is the log likelihood of the sequence (i.e. negative), length_penalty > 1.0 promotes
|
481 |
+
longer sequences, while length_penalty < 1.0 encourages shorter sequences.
|
482 |
+
do_sample (*optional*, `bool`):
|
483 |
+
Whether or not to use sampling ; use greedy decoding otherwise.
|
484 |
+
top_k (*optional*, `int`):
|
485 |
+
The number of the highest probability vocabulary tokens to keep for top-k-filtering.
|
486 |
+
temperature (*optional*, `float`):
|
487 |
+
The value used to modulate the next token probabilities if sampling.
|
488 |
+
|
489 |
+
Return:
|
490 |
+
`Dict`: A dictionary with the following keys:
|
491 |
+
- **text** (`str` ) -- The recognised text.
|
492 |
+
- **chunks** (*optional(, `List[Dict]`)
|
493 |
+
When using `return_timestamps`, the `chunks` will become a list containing all the various text
|
494 |
+
chunks identified by the model, *e.g.* `[{"text": "hi ", "timestamps": (0.5,0.9), {"text":
|
495 |
+
"there", "timestamps": (1.0, 1.5)}]`. The original full text can roughly be recovered by doing
|
496 |
+
`"".join(chunk["text"] for chunk in output["chunks"])`.
|
497 |
+
"""
|
498 |
+
batch_size = batch_size if batch_size is not None else self.batch_size
|
499 |
+
if batch_size % self.min_batch_size != 0:
|
500 |
+
raise ValueError(
|
501 |
+
f"Batch size must be a multiple of the number of JAX devices, but got batch size {batch_size} and num devices {self.min_batch_size}."
|
502 |
+
)
|
503 |
+
|
504 |
+
dataloader = self.preprocess_batch(
|
505 |
+
inputs, chunk_length_s=chunk_length_s, stride_length_s=stride_length_s, batch_size=batch_size
|
506 |
+
)
|
507 |
+
model_outputs = []
|
508 |
+
# iterate over our chunked audio samples
|
509 |
+
for batch in dataloader:
|
510 |
+
model_outputs.append(
|
511 |
+
self.forward(
|
512 |
+
batch,
|
513 |
+
batch_size=batch_size,
|
514 |
+
language=language,
|
515 |
+
task=task,
|
516 |
+
return_timestamps=return_timestamps,
|
517 |
+
num_beams=num_beams,
|
518 |
+
length_penalty=length_penalty,
|
519 |
+
do_sample=do_sample,
|
520 |
+
top_k=top_k,
|
521 |
+
temperature=temperature,
|
522 |
+
)
|
523 |
+
)
|
524 |
+
post_processed = self.postprocess(model_outputs, return_timestamps=return_timestamps)
|
525 |
+
return post_processed
|
distil_whisper/train_state.py
ADDED
@@ -0,0 +1,118 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Any, Mapping, MutableMapping, Optional, Tuple
|
2 |
+
|
3 |
+
import flax.core
|
4 |
+
import flax.serialization
|
5 |
+
import flax.struct
|
6 |
+
import jax.numpy as jnp
|
7 |
+
from flax import traverse_util
|
8 |
+
from flax.core import scope as flax_scope
|
9 |
+
from flax.linen import partitioning as flax_partitioning
|
10 |
+
|
11 |
+
|
12 |
+
EMPTY_DICT = flax.core.freeze({})
|
13 |
+
FrozenDict = flax_scope.FrozenDict
|
14 |
+
FrozenVariableDict = flax_scope.FrozenVariableDict
|
15 |
+
MutableVariableDict = flax_scope.MutableVariableDict
|
16 |
+
VariableDict = flax_scope.VariableDict
|
17 |
+
|
18 |
+
|
19 |
+
def _validate_params_axes(params_axes, params):
|
20 |
+
axis_names = flax_partitioning.get_axis_names(params_axes)
|
21 |
+
missing_params_axes = set(traverse_util.flatten_dict(params, sep="/")) - set(
|
22 |
+
traverse_util.flatten_dict(axis_names, sep="/")
|
23 |
+
)
|
24 |
+
if missing_params_axes:
|
25 |
+
raise ValueError(f"Missing axis names for parameters: {missing_params_axes}")
|
26 |
+
|
27 |
+
|
28 |
+
def _split_variables_and_axes(
|
29 |
+
variables_and_axes: FrozenVariableDict,
|
30 |
+
) -> Tuple[FrozenVariableDict, FrozenVariableDict]:
|
31 |
+
"""Splits `variables_and_axes` into two separate dicts with the same keys."""
|
32 |
+
# For each `key`, `key_axes` (if any) are its axes in `variables_and_axes`.
|
33 |
+
variables = {}
|
34 |
+
axes = {}
|
35 |
+
for k, v in variables_and_axes.items():
|
36 |
+
if k.endswith("_axes"):
|
37 |
+
axes[k[:-5]] = v # k without "_axes".
|
38 |
+
_validate_params_axes(v, variables_and_axes[k[:-5]]) # k without "_axes".
|
39 |
+
else:
|
40 |
+
variables[k] = v
|
41 |
+
return flax.core.freeze(variables), flax.core.freeze(axes)
|
42 |
+
|
43 |
+
|
44 |
+
class InferenceState(flax.struct.PyTreeNode):
|
45 |
+
"""State compatible with FlaxOptimTrainState without optimizer state."""
|
46 |
+
|
47 |
+
step: jnp.ndarray
|
48 |
+
params: flax_scope.FrozenVariableDict
|
49 |
+
params_axes: Optional[flax_scope.FrozenVariableDict] = None
|
50 |
+
flax_mutables: flax_scope.FrozenDict = EMPTY_DICT
|
51 |
+
flax_mutables_axes: Optional[flax_scope.FrozenVariableDict] = None
|
52 |
+
|
53 |
+
@classmethod
|
54 |
+
def create(cls, model_variables: FrozenVariableDict) -> "InferenceState":
|
55 |
+
other_variables, params = model_variables.pop("params")
|
56 |
+
if "params_axes" in other_variables:
|
57 |
+
other_variables, params_axes = other_variables.pop("params_axes")
|
58 |
+
_validate_params_axes(params_axes, params)
|
59 |
+
else:
|
60 |
+
params_axes = None
|
61 |
+
|
62 |
+
# Split other_variables into mutables and their corresponding axes.
|
63 |
+
flax_mutables, flax_mutables_axes = _split_variables_and_axes(other_variables)
|
64 |
+
flax_mutables_axes = flax_mutables_axes or None
|
65 |
+
return InferenceState(
|
66 |
+
step=jnp.array(0),
|
67 |
+
params=params,
|
68 |
+
params_axes=params_axes,
|
69 |
+
flax_mutables=flax_mutables,
|
70 |
+
flax_mutables_axes=flax_mutables_axes,
|
71 |
+
)
|
72 |
+
|
73 |
+
@property
|
74 |
+
def param_states(self) -> FrozenVariableDict:
|
75 |
+
"""The optimizer states of the parameters as a PyTree."""
|
76 |
+
raise NotImplementedError("InferenceState has no optimizer states.")
|
77 |
+
|
78 |
+
def apply_gradient(self, *args, **kwargs) -> "InferenceState":
|
79 |
+
raise NotImplementedError("InferenceState does not support `apply_gradient`.")
|
80 |
+
|
81 |
+
def state_dict(self) -> MutableMapping[str, Any]:
|
82 |
+
state_dict = {
|
83 |
+
"target": flax.core.unfreeze(self.params),
|
84 |
+
"state": {"step": self.step},
|
85 |
+
}
|
86 |
+
if self.flax_mutables:
|
87 |
+
state_dict["flax_mutables"] = flax.core.unfreeze(self.flax_mutables)
|
88 |
+
return state_dict
|
89 |
+
|
90 |
+
def replace_step(self, step: jnp.ndarray) -> "InferenceState":
|
91 |
+
return self.replace(step=step)
|
92 |
+
|
93 |
+
def replace_params(self, params: FrozenVariableDict) -> "InferenceState":
|
94 |
+
return self.replace(params=params)
|
95 |
+
|
96 |
+
def replace_flax_mutables(self, flax_mutables: FrozenDict) -> "InferenceState":
|
97 |
+
return self.replace(flax_mutables=flax_mutables)
|
98 |
+
|
99 |
+
def restore_state(self, state_dict: Mapping[str, Any]) -> "InferenceState":
|
100 |
+
return self.replace(
|
101 |
+
params=flax.core.freeze(state_dict["target"]),
|
102 |
+
step=state_dict["state"]["step"],
|
103 |
+
flax_mutables=(
|
104 |
+
flax.core.freeze(state_dict["flax_mutables"]) if "flax_mutables" in state_dict else EMPTY_DICT
|
105 |
+
),
|
106 |
+
)
|
107 |
+
|
108 |
+
def as_logical_axes(self) -> "InferenceState":
|
109 |
+
# Set step to None so that when the logical axes are processed by the
|
110 |
+
# flax.partitioning.logical_to_mesh_axes function, it will be skipped
|
111 |
+
# because jax.tree_map will short circut and never call the function on the
|
112 |
+
# step.
|
113 |
+
flax_mutables_axes = self.flax_mutables_axes or EMPTY_DICT
|
114 |
+
return InferenceState(
|
115 |
+
step=None,
|
116 |
+
params=flax_partitioning.get_axis_names(self.params_axes),
|
117 |
+
flax_mutables=flax_partitioning.get_axis_names(flax_mutables_axes),
|
118 |
+
)
|
events.out.tfevents.1696323477.t1v-n-4eccb2d4-w-0.2889203.0.v2
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:e50d1c4acaf53861664c29cb9b845856a6c4b304c1d41c2c187d54bbfd22076a
|
3 |
+
size 51942
|
flax_model.msgpack
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:843c32ed2be9c335890c72e0b7d859323f97e3dae078a2ddc93402300d3272a7
|
3 |
+
size 6173221863
|
generation_config.json
ADDED
@@ -0,0 +1,319 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"alignment_heads": [
|
3 |
+
[
|
4 |
+
10,
|
5 |
+
12
|
6 |
+
],
|
7 |
+
[
|
8 |
+
13,
|
9 |
+
17
|
10 |
+
],
|
11 |
+
[
|
12 |
+
16,
|
13 |
+
11
|
14 |
+
],
|
15 |
+
[
|
16 |
+
16,
|
17 |
+
12
|
18 |
+
],
|
19 |
+
[
|
20 |
+
16,
|
21 |
+
13
|
22 |
+
],
|
23 |
+
[
|
24 |
+
17,
|
25 |
+
15
|
26 |
+
],
|
27 |
+
[
|
28 |
+
17,
|
29 |
+
16
|
30 |
+
],
|
31 |
+
[
|
32 |
+
18,
|
33 |
+
4
|
34 |
+
],
|
35 |
+
[
|
36 |
+
18,
|
37 |
+
11
|
38 |
+
],
|
39 |
+
[
|
40 |
+
18,
|
41 |
+
19
|
42 |
+
],
|
43 |
+
[
|
44 |
+
19,
|
45 |
+
11
|
46 |
+
],
|
47 |
+
[
|
48 |
+
21,
|
49 |
+
2
|
50 |
+
],
|
51 |
+
[
|
52 |
+
21,
|
53 |
+
3
|
54 |
+
],
|
55 |
+
[
|
56 |
+
22,
|
57 |
+
3
|
58 |
+
],
|
59 |
+
[
|
60 |
+
22,
|
61 |
+
9
|
62 |
+
],
|
63 |
+
[
|
64 |
+
22,
|
65 |
+
12
|
66 |
+
],
|
67 |
+
[
|
68 |
+
23,
|
69 |
+
5
|
70 |
+
],
|
71 |
+
[
|
72 |
+
23,
|
73 |
+
7
|
74 |
+
],
|
75 |
+
[
|
76 |
+
23,
|
77 |
+
13
|
78 |
+
],
|
79 |
+
[
|
80 |
+
25,
|
81 |
+
5
|
82 |
+
],
|
83 |
+
[
|
84 |
+
26,
|
85 |
+
1
|
86 |
+
],
|
87 |
+
[
|
88 |
+
26,
|
89 |
+
12
|
90 |
+
],
|
91 |
+
[
|
92 |
+
27,
|
93 |
+
15
|
94 |
+
]
|
95 |
+
],
|
96 |
+
"begin_suppress_tokens": [
|
97 |
+
220,
|
98 |
+
50257
|
99 |
+
],
|
100 |
+
"bos_token_id": 50257,
|
101 |
+
"decoder_start_token_id": 50258,
|
102 |
+
"eos_token_id": 50257,
|
103 |
+
"forced_decoder_ids": [
|
104 |
+
[
|
105 |
+
1,
|
106 |
+
null
|
107 |
+
],
|
108 |
+
[
|
109 |
+
2,
|
110 |
+
50359
|
111 |
+
],
|
112 |
+
[
|
113 |
+
3,
|
114 |
+
50363
|
115 |
+
]
|
116 |
+
],
|
117 |
+
"is_multilingual": true,
|
118 |
+
"lang_to_id": {
|
119 |
+
"<|af|>": 50327,
|
120 |
+
"<|am|>": 50334,
|
121 |
+
"<|ar|>": 50272,
|
122 |
+
"<|as|>": 50350,
|
123 |
+
"<|az|>": 50304,
|
124 |
+
"<|ba|>": 50355,
|
125 |
+
"<|be|>": 50330,
|
126 |
+
"<|bg|>": 50292,
|
127 |
+
"<|bn|>": 50302,
|
128 |
+
"<|bo|>": 50347,
|
129 |
+
"<|br|>": 50309,
|
130 |
+
"<|bs|>": 50315,
|
131 |
+
"<|ca|>": 50270,
|
132 |
+
"<|cs|>": 50283,
|
133 |
+
"<|cy|>": 50297,
|
134 |
+
"<|da|>": 50285,
|
135 |
+
"<|de|>": 50261,
|
136 |
+
"<|el|>": 50281,
|
137 |
+
"<|en|>": 50259,
|
138 |
+
"<|es|>": 50262,
|
139 |
+
"<|et|>": 50307,
|
140 |
+
"<|eu|>": 50310,
|
141 |
+
"<|fa|>": 50300,
|
142 |
+
"<|fi|>": 50277,
|
143 |
+
"<|fo|>": 50338,
|
144 |
+
"<|fr|>": 50265,
|
145 |
+
"<|gl|>": 50319,
|
146 |
+
"<|gu|>": 50333,
|
147 |
+
"<|haw|>": 50352,
|
148 |
+
"<|ha|>": 50354,
|
149 |
+
"<|he|>": 50279,
|
150 |
+
"<|hi|>": 50276,
|
151 |
+
"<|hr|>": 50291,
|
152 |
+
"<|ht|>": 50339,
|
153 |
+
"<|hu|>": 50286,
|
154 |
+
"<|hy|>": 50312,
|
155 |
+
"<|id|>": 50275,
|
156 |
+
"<|is|>": 50311,
|
157 |
+
"<|it|>": 50274,
|
158 |
+
"<|ja|>": 50266,
|
159 |
+
"<|jw|>": 50356,
|
160 |
+
"<|ka|>": 50329,
|
161 |
+
"<|kk|>": 50316,
|
162 |
+
"<|km|>": 50323,
|
163 |
+
"<|kn|>": 50306,
|
164 |
+
"<|ko|>": 50264,
|
165 |
+
"<|la|>": 50294,
|
166 |
+
"<|lb|>": 50345,
|
167 |
+
"<|ln|>": 50353,
|
168 |
+
"<|lo|>": 50336,
|
169 |
+
"<|lt|>": 50293,
|
170 |
+
"<|lv|>": 50301,
|
171 |
+
"<|mg|>": 50349,
|
172 |
+
"<|mi|>": 50295,
|
173 |
+
"<|mk|>": 50308,
|
174 |
+
"<|ml|>": 50296,
|
175 |
+
"<|mn|>": 50314,
|
176 |
+
"<|mr|>": 50320,
|
177 |
+
"<|ms|>": 50282,
|
178 |
+
"<|mt|>": 50343,
|
179 |
+
"<|my|>": 50346,
|
180 |
+
"<|ne|>": 50313,
|
181 |
+
"<|nl|>": 50271,
|
182 |
+
"<|nn|>": 50342,
|
183 |
+
"<|no|>": 50288,
|
184 |
+
"<|oc|>": 50328,
|
185 |
+
"<|pa|>": 50321,
|
186 |
+
"<|pl|>": 50269,
|
187 |
+
"<|ps|>": 50340,
|
188 |
+
"<|pt|>": 50267,
|
189 |
+
"<|ro|>": 50284,
|
190 |
+
"<|ru|>": 50263,
|
191 |
+
"<|sa|>": 50344,
|
192 |
+
"<|sd|>": 50332,
|
193 |
+
"<|si|>": 50322,
|
194 |
+
"<|sk|>": 50298,
|
195 |
+
"<|sl|>": 50305,
|
196 |
+
"<|sn|>": 50324,
|
197 |
+
"<|so|>": 50326,
|
198 |
+
"<|sq|>": 50317,
|
199 |
+
"<|sr|>": 50303,
|
200 |
+
"<|su|>": 50357,
|
201 |
+
"<|sv|>": 50273,
|
202 |
+
"<|sw|>": 50318,
|
203 |
+
"<|ta|>": 50287,
|
204 |
+
"<|te|>": 50299,
|
205 |
+
"<|tg|>": 50331,
|
206 |
+
"<|th|>": 50289,
|
207 |
+
"<|tk|>": 50341,
|
208 |
+
"<|tl|>": 50348,
|
209 |
+
"<|tr|>": 50268,
|
210 |
+
"<|tt|>": 50351,
|
211 |
+
"<|uk|>": 50280,
|
212 |
+
"<|ur|>": 50290,
|
213 |
+
"<|uz|>": 50337,
|
214 |
+
"<|vi|>": 50278,
|
215 |
+
"<|yi|>": 50335,
|
216 |
+
"<|yo|>": 50325,
|
217 |
+
"<|zh|>": 50260
|
218 |
+
},
|
219 |
+
"max_initial_timestamp_index": 1,
|
220 |
+
"max_length": 448,
|
221 |
+
"no_timestamps_token_id": 50363,
|
222 |
+
"pad_token_id": 50257,
|
223 |
+
"return_timestamps": false,
|
224 |
+
"suppress_tokens": [
|
225 |
+
1,
|
226 |
+
2,
|
227 |
+
7,
|
228 |
+
8,
|
229 |
+
9,
|
230 |
+
10,
|
231 |
+
14,
|
232 |
+
25,
|
233 |
+
26,
|
234 |
+
27,
|
235 |
+
28,
|
236 |
+
29,
|
237 |
+
31,
|
238 |
+
58,
|
239 |
+
59,
|
240 |
+
60,
|
241 |
+
61,
|
242 |
+
62,
|
243 |
+
63,
|
244 |
+
90,
|
245 |
+
91,
|
246 |
+
92,
|
247 |
+
93,
|
248 |
+
359,
|
249 |
+
503,
|
250 |
+
522,
|
251 |
+
542,
|
252 |
+
873,
|
253 |
+
893,
|
254 |
+
902,
|
255 |
+
918,
|
256 |
+
922,
|
257 |
+
931,
|
258 |
+
1350,
|
259 |
+
1853,
|
260 |
+
1982,
|
261 |
+
2460,
|
262 |
+
2627,
|
263 |
+
3246,
|
264 |
+
3253,
|
265 |
+
3268,
|
266 |
+
3536,
|
267 |
+
3846,
|
268 |
+
3961,
|
269 |
+
4183,
|
270 |
+
4667,
|
271 |
+
6585,
|
272 |
+
6647,
|
273 |
+
7273,
|
274 |
+
9061,
|
275 |
+
9383,
|
276 |
+
10428,
|
277 |
+
10929,
|
278 |
+
11938,
|
279 |
+
12033,
|
280 |
+
12331,
|
281 |
+
12562,
|
282 |
+
13793,
|
283 |
+
14157,
|
284 |
+
14635,
|
285 |
+
15265,
|
286 |
+
15618,
|
287 |
+
16553,
|
288 |
+
16604,
|
289 |
+
18362,
|
290 |
+
18956,
|
291 |
+
20075,
|
292 |
+
21675,
|
293 |
+
22520,
|
294 |
+
26130,
|
295 |
+
26161,
|
296 |
+
26435,
|
297 |
+
28279,
|
298 |
+
29464,
|
299 |
+
31650,
|
300 |
+
32302,
|
301 |
+
32470,
|
302 |
+
36865,
|
303 |
+
42863,
|
304 |
+
47425,
|
305 |
+
49870,
|
306 |
+
50254,
|
307 |
+
50258,
|
308 |
+
50358,
|
309 |
+
50359,
|
310 |
+
50360,
|
311 |
+
50361,
|
312 |
+
50362
|
313 |
+
],
|
314 |
+
"task_to_id": {
|
315 |
+
"transcribe": 50359,
|
316 |
+
"translate": 50358
|
317 |
+
},
|
318 |
+
"transformers_version": "4.34.0.dev0"
|
319 |
+
}
|
merges.txt
ADDED
The diff for this file is too large to render.
See raw diff
|
|
normalizer.json
ADDED
@@ -0,0 +1,1742 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"accessorise": "accessorize",
|
3 |
+
"accessorised": "accessorized",
|
4 |
+
"accessorises": "accessorizes",
|
5 |
+
"accessorising": "accessorizing",
|
6 |
+
"acclimatisation": "acclimatization",
|
7 |
+
"acclimatise": "acclimatize",
|
8 |
+
"acclimatised": "acclimatized",
|
9 |
+
"acclimatises": "acclimatizes",
|
10 |
+
"acclimatising": "acclimatizing",
|
11 |
+
"accoutrements": "accouterments",
|
12 |
+
"aeon": "eon",
|
13 |
+
"aeons": "eons",
|
14 |
+
"aerogramme": "aerogram",
|
15 |
+
"aerogrammes": "aerograms",
|
16 |
+
"aeroplane": "airplane",
|
17 |
+
"aeroplanes": "airplanes",
|
18 |
+
"aesthete": "esthete",
|
19 |
+
"aesthetes": "esthetes",
|
20 |
+
"aesthetic": "esthetic",
|
21 |
+
"aesthetically": "esthetically",
|
22 |
+
"aesthetics": "esthetics",
|
23 |
+
"aetiology": "etiology",
|
24 |
+
"ageing": "aging",
|
25 |
+
"aggrandisement": "aggrandizement",
|
26 |
+
"agonise": "agonize",
|
27 |
+
"agonised": "agonized",
|
28 |
+
"agonises": "agonizes",
|
29 |
+
"agonising": "agonizing",
|
30 |
+
"agonisingly": "agonizingly",
|
31 |
+
"almanack": "almanac",
|
32 |
+
"almanacks": "almanacs",
|
33 |
+
"aluminium": "aluminum",
|
34 |
+
"amortisable": "amortizable",
|
35 |
+
"amortisation": "amortization",
|
36 |
+
"amortisations": "amortizations",
|
37 |
+
"amortise": "amortize",
|
38 |
+
"amortised": "amortized",
|
39 |
+
"amortises": "amortizes",
|
40 |
+
"amortising": "amortizing",
|
41 |
+
"amphitheatre": "amphitheater",
|
42 |
+
"amphitheatres": "amphitheaters",
|
43 |
+
"anaemia": "anemia",
|
44 |
+
"anaemic": "anemic",
|
45 |
+
"anaesthesia": "anesthesia",
|
46 |
+
"anaesthetic": "anesthetic",
|
47 |
+
"anaesthetics": "anesthetics",
|
48 |
+
"anaesthetise": "anesthetize",
|
49 |
+
"anaesthetised": "anesthetized",
|
50 |
+
"anaesthetises": "anesthetizes",
|
51 |
+
"anaesthetising": "anesthetizing",
|
52 |
+
"anaesthetist": "anesthetist",
|
53 |
+
"anaesthetists": "anesthetists",
|
54 |
+
"anaesthetize": "anesthetize",
|
55 |
+
"anaesthetized": "anesthetized",
|
56 |
+
"anaesthetizes": "anesthetizes",
|
57 |
+
"anaesthetizing": "anesthetizing",
|
58 |
+
"analogue": "analog",
|
59 |
+
"analogues": "analogs",
|
60 |
+
"analyse": "analyze",
|
61 |
+
"analysed": "analyzed",
|
62 |
+
"analyses": "analyzes",
|
63 |
+
"analysing": "analyzing",
|
64 |
+
"anglicise": "anglicize",
|
65 |
+
"anglicised": "anglicized",
|
66 |
+
"anglicises": "anglicizes",
|
67 |
+
"anglicising": "anglicizing",
|
68 |
+
"annualised": "annualized",
|
69 |
+
"antagonise": "antagonize",
|
70 |
+
"antagonised": "antagonized",
|
71 |
+
"antagonises": "antagonizes",
|
72 |
+
"antagonising": "antagonizing",
|
73 |
+
"apologise": "apologize",
|
74 |
+
"apologised": "apologized",
|
75 |
+
"apologises": "apologizes",
|
76 |
+
"apologising": "apologizing",
|
77 |
+
"appal": "appall",
|
78 |
+
"appals": "appalls",
|
79 |
+
"appetiser": "appetizer",
|
80 |
+
"appetisers": "appetizers",
|
81 |
+
"appetising": "appetizing",
|
82 |
+
"appetisingly": "appetizingly",
|
83 |
+
"arbour": "arbor",
|
84 |
+
"arbours": "arbors",
|
85 |
+
"archaeologically": "archeologically",
|
86 |
+
"archaeologist": "archeologist",
|
87 |
+
"archaeologists": "archeologists",
|
88 |
+
"archaeology": "archeology</span>",
|
89 |
+
"archeological": "archaeological",
|
90 |
+
"ardour": "ardor",
|
91 |
+
"armour": "armor",
|
92 |
+
"armoured": "armored",
|
93 |
+
"armourer": "armorer",
|
94 |
+
"armourers": "armorers",
|
95 |
+
"armouries": "armories",
|
96 |
+
"armoury": "armory",
|
97 |
+
"artefact": "artifact",
|
98 |
+
"artefacts": "artifacts",
|
99 |
+
"authorise": "authorize",
|
100 |
+
"authorised": "authorized",
|
101 |
+
"authorises": "authorizes",
|
102 |
+
"authorising": "authorizing",
|
103 |
+
"axe": "ax",
|
104 |
+
"backpedalled": "backpedaled",
|
105 |
+
"backpedalling": "backpedaling",
|
106 |
+
"bannister": "banister",
|
107 |
+
"bannisters": "banisters",
|
108 |
+
"baptise": "baptize",
|
109 |
+
"baptised": "baptized",
|
110 |
+
"baptises": "baptizes",
|
111 |
+
"baptising": "baptizing",
|
112 |
+
"bastardise": "bastardize",
|
113 |
+
"bastardised": "bastardized",
|
114 |
+
"bastardises": "bastardizes",
|
115 |
+
"bastardising": "bastardizing",
|
116 |
+
"battleax": "battleaxe",
|
117 |
+
"baulk": "balk",
|
118 |
+
"baulked": "balked",
|
119 |
+
"baulking": "balking",
|
120 |
+
"baulks": "balks",
|
121 |
+
"bedevilled": "bedeviled",
|
122 |
+
"bedevilling": "bedeviling",
|
123 |
+
"behaviour": "behavior",
|
124 |
+
"behavioural": "behavioral",
|
125 |
+
"behaviourism": "behaviorism",
|
126 |
+
"behaviourist": "behaviorist",
|
127 |
+
"behaviourists": "behaviorists",
|
128 |
+
"behaviours": "behaviors",
|
129 |
+
"behove": "behoove",
|
130 |
+
"behoved": "behooved",
|
131 |
+
"behoves": "behooves",
|
132 |
+
"bejewelled": "bejeweled",
|
133 |
+
"belabour": "belabor",
|
134 |
+
"belaboured": "belabored",
|
135 |
+
"belabouring": "belaboring",
|
136 |
+
"belabours": "belabors",
|
137 |
+
"bevelled": "beveled",
|
138 |
+
"bevvies": "bevies",
|
139 |
+
"bevvy": "bevy",
|
140 |
+
"biassed": "biased",
|
141 |
+
"biassing": "biasing",
|
142 |
+
"bingeing": "binging",
|
143 |
+
"bougainvillaea": "bougainvillea",
|
144 |
+
"bougainvillaeas": "bougainvilleas",
|
145 |
+
"bowdlerise": "bowdlerize",
|
146 |
+
"bowdlerised": "bowdlerized",
|
147 |
+
"bowdlerises": "bowdlerizes",
|
148 |
+
"bowdlerising": "bowdlerizing",
|
149 |
+
"breathalyse": "breathalyze",
|
150 |
+
"breathalysed": "breathalyzed",
|
151 |
+
"breathalyser": "breathalyzer",
|
152 |
+
"breathalysers": "breathalyzers",
|
153 |
+
"breathalyses": "breathalyzes",
|
154 |
+
"breathalysing": "breathalyzing",
|
155 |
+
"brutalise": "brutalize",
|
156 |
+
"brutalised": "brutalized",
|
157 |
+
"brutalises": "brutalizes",
|
158 |
+
"brutalising": "brutalizing",
|
159 |
+
"busses": "buses",
|
160 |
+
"bussing": "busing",
|
161 |
+
"caesarean": "cesarean",
|
162 |
+
"caesareans": "cesareans",
|
163 |
+
"calibre": "caliber",
|
164 |
+
"calibres": "calibers",
|
165 |
+
"calliper": "caliper",
|
166 |
+
"callipers": "calipers",
|
167 |
+
"callisthenics": "calisthenics",
|
168 |
+
"canalise": "canalize",
|
169 |
+
"canalised": "canalized",
|
170 |
+
"canalises": "canalizes",
|
171 |
+
"canalising": "canalizing",
|
172 |
+
"cancelation": "cancellation",
|
173 |
+
"cancelations": "cancellations",
|
174 |
+
"cancelled": "canceled",
|
175 |
+
"cancelling": "canceling",
|
176 |
+
"candour": "candor",
|
177 |
+
"cannibalise": "cannibalize",
|
178 |
+
"cannibalised": "cannibalized",
|
179 |
+
"cannibalises": "cannibalizes",
|
180 |
+
"cannibalising": "cannibalizing",
|
181 |
+
"canonise": "canonize",
|
182 |
+
"canonised": "canonized",
|
183 |
+
"canonises": "canonizes",
|
184 |
+
"canonising": "canonizing",
|
185 |
+
"capitalise": "capitalize",
|
186 |
+
"capitalised": "capitalized",
|
187 |
+
"capitalises": "capitalizes",
|
188 |
+
"capitalising": "capitalizing",
|
189 |
+
"caramelise": "caramelize",
|
190 |
+
"caramelised": "caramelized",
|
191 |
+
"caramelises": "caramelizes",
|
192 |
+
"caramelising": "caramelizing",
|
193 |
+
"carbonise": "carbonize",
|
194 |
+
"carbonised": "carbonized",
|
195 |
+
"carbonises": "carbonizes",
|
196 |
+
"carbonising": "carbonizing",
|
197 |
+
"carolled": "caroled",
|
198 |
+
"carolling": "caroling",
|
199 |
+
"catalogue": "catalog",
|
200 |
+
"catalogued": "cataloged",
|
201 |
+
"catalogues": "catalogs",
|
202 |
+
"cataloguing": "cataloging",
|
203 |
+
"catalyse": "catalyze",
|
204 |
+
"catalysed": "catalyzed",
|
205 |
+
"catalyses": "catalyzes",
|
206 |
+
"catalysing": "catalyzing",
|
207 |
+
"categorise": "categorize",
|
208 |
+
"categorised": "categorized",
|
209 |
+
"categorises": "categorizes",
|
210 |
+
"categorising": "categorizing",
|
211 |
+
"cauterise": "cauterize",
|
212 |
+
"cauterised": "cauterized",
|
213 |
+
"cauterises": "cauterizes",
|
214 |
+
"cauterising": "cauterizing",
|
215 |
+
"cavilled": "caviled",
|
216 |
+
"cavilling": "caviling",
|
217 |
+
"centigramme": "centigram",
|
218 |
+
"centigrammes": "centigrams",
|
219 |
+
"centilitre": "centiliter",
|
220 |
+
"centilitres": "centiliters",
|
221 |
+
"centimetre": "centimeter",
|
222 |
+
"centimetres": "centimeters",
|
223 |
+
"centralise": "centralize",
|
224 |
+
"centralised": "centralized",
|
225 |
+
"centralises": "centralizes",
|
226 |
+
"centralising": "centralizing",
|
227 |
+
"centre": "center",
|
228 |
+
"centred": "centered",
|
229 |
+
"centrefold": "centerfold",
|
230 |
+
"centrefolds": "centerfolds",
|
231 |
+
"centrepiece": "centerpiece",
|
232 |
+
"centrepieces": "centerpieces",
|
233 |
+
"centres": "centers",
|
234 |
+
"channelled": "channeled",
|
235 |
+
"channelling": "channeling",
|
236 |
+
"characterise": "characterize",
|
237 |
+
"characterised": "characterized",
|
238 |
+
"characterises": "characterizes",
|
239 |
+
"characterising": "characterizing",
|
240 |
+
"cheque": "check",
|
241 |
+
"chequebook": "checkbook",
|
242 |
+
"chequebooks": "checkbooks",
|
243 |
+
"chequered": "checkered",
|
244 |
+
"cheques": "checks",
|
245 |
+
"chilli": "chili",
|
246 |
+
"chimaera": "chimera",
|
247 |
+
"chimaeras": "chimeras",
|
248 |
+
"chiselled": "chiseled",
|
249 |
+
"chiselling": "chiseling",
|
250 |
+
"circularise": "circularize",
|
251 |
+
"circularised": "circularized",
|
252 |
+
"circularises": "circularizes",
|
253 |
+
"circularising": "circularizing",
|
254 |
+
"civilise": "civilize",
|
255 |
+
"civilised": "civilized",
|
256 |
+
"civilises": "civilizes",
|
257 |
+
"civilising": "civilizing",
|
258 |
+
"clamour": "clamor",
|
259 |
+
"clamoured": "clamored",
|
260 |
+
"clamouring": "clamoring",
|
261 |
+
"clamours": "clamors",
|
262 |
+
"clangour": "clangor",
|
263 |
+
"clarinettist": "clarinetist",
|
264 |
+
"clarinettists": "clarinetists",
|
265 |
+
"collectivise": "collectivize",
|
266 |
+
"collectivised": "collectivized",
|
267 |
+
"collectivises": "collectivizes",
|
268 |
+
"collectivising": "collectivizing",
|
269 |
+
"colonisation": "colonization",
|
270 |
+
"colonise": "colonize",
|
271 |
+
"colonised": "colonized",
|
272 |
+
"coloniser": "colonizer",
|
273 |
+
"colonisers": "colonizers",
|
274 |
+
"colonises": "colonizes",
|
275 |
+
"colonising": "colonizing",
|
276 |
+
"colour": "color",
|
277 |
+
"colourant": "colorant",
|
278 |
+
"colourants": "colorants",
|
279 |
+
"coloured": "colored",
|
280 |
+
"coloureds": "coloreds",
|
281 |
+
"colourful": "colorful",
|
282 |
+
"colourfully": "colorfully",
|
283 |
+
"colouring": "coloring",
|
284 |
+
"colourize": "colorize",
|
285 |
+
"colourized": "colorized",
|
286 |
+
"colourizes": "colorizes",
|
287 |
+
"colourizing": "colorizing",
|
288 |
+
"colourless": "colorless",
|
289 |
+
"colours": "colors",
|
290 |
+
"commercialise": "commercialize",
|
291 |
+
"commercialised": "commercialized",
|
292 |
+
"commercialises": "commercializes",
|
293 |
+
"commercialising": "commercializing",
|
294 |
+
"compartmentalise": "compartmentalize",
|
295 |
+
"compartmentalised": "compartmentalized",
|
296 |
+
"compartmentalises": "compartmentalizes",
|
297 |
+
"compartmentalising": "compartmentalizing",
|
298 |
+
"computerise": "computerize",
|
299 |
+
"computerised": "computerized",
|
300 |
+
"computerises": "computerizes",
|
301 |
+
"computerising": "computerizing",
|
302 |
+
"conceptualise": "conceptualize",
|
303 |
+
"conceptualised": "conceptualized",
|
304 |
+
"conceptualises": "conceptualizes",
|
305 |
+
"conceptualising": "conceptualizing",
|
306 |
+
"connexion": "connection",
|
307 |
+
"connexions": "connections",
|
308 |
+
"contextualise": "contextualize",
|
309 |
+
"contextualised": "contextualized",
|
310 |
+
"contextualises": "contextualizes",
|
311 |
+
"contextualising": "contextualizing",
|
312 |
+
"cosier": "cozier",
|
313 |
+
"cosies": "cozies",
|
314 |
+
"cosiest": "coziest",
|
315 |
+
"cosily": "cozily",
|
316 |
+
"cosiness": "coziness",
|
317 |
+
"cosy": "cozy",
|
318 |
+
"councillor": "councilor",
|
319 |
+
"councillors": "councilors",
|
320 |
+
"counselled": "counseled",
|
321 |
+
"counselling": "counseling",
|
322 |
+
"counsellor": "counselor",
|
323 |
+
"counsellors": "counselors",
|
324 |
+
"crenelated": "crenellated",
|
325 |
+
"criminalise": "criminalize",
|
326 |
+
"criminalised": "criminalized",
|
327 |
+
"criminalises": "criminalizes",
|
328 |
+
"criminalising": "criminalizing",
|
329 |
+
"criticise": "criticize",
|
330 |
+
"criticised": "criticized",
|
331 |
+
"criticises": "criticizes",
|
332 |
+
"criticising": "criticizing",
|
333 |
+
"crueller": "crueler",
|
334 |
+
"cruellest": "cruelest",
|
335 |
+
"crystallisation": "crystallization",
|
336 |
+
"crystallise": "crystallize",
|
337 |
+
"crystallised": "crystallized",
|
338 |
+
"crystallises": "crystallizes",
|
339 |
+
"crystallising": "crystallizing",
|
340 |
+
"cudgelled": "cudgeled",
|
341 |
+
"cudgelling": "cudgeling",
|
342 |
+
"customise": "customize",
|
343 |
+
"customised": "customized",
|
344 |
+
"customises": "customizes",
|
345 |
+
"customising": "customizing",
|
346 |
+
"cypher": "cipher",
|
347 |
+
"cyphers": "ciphers",
|
348 |
+
"decentralisation": "decentralization",
|
349 |
+
"decentralise": "decentralize",
|
350 |
+
"decentralised": "decentralized",
|
351 |
+
"decentralises": "decentralizes",
|
352 |
+
"decentralising": "decentralizing",
|
353 |
+
"decriminalisation": "decriminalization",
|
354 |
+
"decriminalise": "decriminalize",
|
355 |
+
"decriminalised": "decriminalized",
|
356 |
+
"decriminalises": "decriminalizes",
|
357 |
+
"decriminalising": "decriminalizing",
|
358 |
+
"defence": "defense",
|
359 |
+
"defenceless": "defenseless",
|
360 |
+
"defences": "defenses",
|
361 |
+
"dehumanisation": "dehumanization",
|
362 |
+
"dehumanise": "dehumanize",
|
363 |
+
"dehumanised": "dehumanized",
|
364 |
+
"dehumanises": "dehumanizes",
|
365 |
+
"dehumanising": "dehumanizing",
|
366 |
+
"demeanour": "demeanor",
|
367 |
+
"demilitarisation": "demilitarization",
|
368 |
+
"demilitarise": "demilitarize",
|
369 |
+
"demilitarised": "demilitarized",
|
370 |
+
"demilitarises": "demilitarizes",
|
371 |
+
"demilitarising": "demilitarizing",
|
372 |
+
"demobilisation": "demobilization",
|
373 |
+
"demobilise": "demobilize",
|
374 |
+
"demobilised": "demobilized",
|
375 |
+
"demobilises": "demobilizes",
|
376 |
+
"demobilising": "demobilizing",
|
377 |
+
"democratisation": "democratization",
|
378 |
+
"democratise": "democratize",
|
379 |
+
"democratised": "democratized",
|
380 |
+
"democratises": "democratizes",
|
381 |
+
"democratising": "democratizing",
|
382 |
+
"demonise": "demonize",
|
383 |
+
"demonised": "demonized",
|
384 |
+
"demonises": "demonizes",
|
385 |
+
"demonising": "demonizing",
|
386 |
+
"demoralisation": "demoralization",
|
387 |
+
"demoralise": "demoralize",
|
388 |
+
"demoralised": "demoralized",
|
389 |
+
"demoralises": "demoralizes",
|
390 |
+
"demoralising": "demoralizing",
|
391 |
+
"denationalisation": "denationalization",
|
392 |
+
"denationalise": "denationalize",
|
393 |
+
"denationalised": "denationalized",
|
394 |
+
"denationalises": "denationalizes",
|
395 |
+
"denationalising": "denationalizing",
|
396 |
+
"deodorise": "deodorize",
|
397 |
+
"deodorised": "deodorized",
|
398 |
+
"deodorises": "deodorizes",
|
399 |
+
"deodorising": "deodorizing",
|
400 |
+
"depersonalise": "depersonalize",
|
401 |
+
"depersonalised": "depersonalized",
|
402 |
+
"depersonalises": "depersonalizes",
|
403 |
+
"depersonalising": "depersonalizing",
|
404 |
+
"deputise": "deputize",
|
405 |
+
"deputised": "deputized",
|
406 |
+
"deputises": "deputizes",
|
407 |
+
"deputising": "deputizing",
|
408 |
+
"desensitisation": "desensitization",
|
409 |
+
"desensitise": "desensitize",
|
410 |
+
"desensitised": "desensitized",
|
411 |
+
"desensitises": "desensitizes",
|
412 |
+
"desensitising": "desensitizing",
|
413 |
+
"destabilisation": "destabilization",
|
414 |
+
"destabilise": "destabilize",
|
415 |
+
"destabilised": "destabilized",
|
416 |
+
"destabilises": "destabilizes",
|
417 |
+
"destabilising": "destabilizing",
|
418 |
+
"dialled": "dialed",
|
419 |
+
"dialling": "dialing",
|
420 |
+
"dialogue": "dialog",
|
421 |
+
"dialogues": "dialogs",
|
422 |
+
"diarrhoea": "diarrhea",
|
423 |
+
"digitise": "digitize",
|
424 |
+
"digitised": "digitized",
|
425 |
+
"digitises": "digitizes",
|
426 |
+
"digitising": "digitizing",
|
427 |
+
"disc": "disk",
|
428 |
+
"discolour": "discolor",
|
429 |
+
"discoloured": "discolored",
|
430 |
+
"discolouring": "discoloring",
|
431 |
+
"discolours": "discolors",
|
432 |
+
"discs": "disks",
|
433 |
+
"disembowelled": "disemboweled",
|
434 |
+
"disembowelling": "disemboweling",
|
435 |
+
"disfavour": "disfavor",
|
436 |
+
"dishevelled": "disheveled",
|
437 |
+
"dishonour": "dishonor",
|
438 |
+
"dishonourable": "dishonorable",
|
439 |
+
"dishonourably": "dishonorably",
|
440 |
+
"dishonoured": "dishonored",
|
441 |
+
"dishonouring": "dishonoring",
|
442 |
+
"dishonours": "dishonors",
|
443 |
+
"disorganisation": "disorganization",
|
444 |
+
"disorganised": "disorganized",
|
445 |
+
"distil": "distill",
|
446 |
+
"distils": "distills",
|
447 |
+
"dramatisation": "dramatization",
|
448 |
+
"dramatisations": "dramatizations",
|
449 |
+
"dramatise": "dramatize",
|
450 |
+
"dramatised": "dramatized",
|
451 |
+
"dramatises": "dramatizes",
|
452 |
+
"dramatising": "dramatizing",
|
453 |
+
"draught": "draft",
|
454 |
+
"draughtboard": "draftboard",
|
455 |
+
"draughtboards": "draftboards",
|
456 |
+
"draughtier": "draftier",
|
457 |
+
"draughtiest": "draftiest",
|
458 |
+
"draughts": "drafts",
|
459 |
+
"draughtsman": "draftsman",
|
460 |
+
"draughtsmanship": "draftsmanship",
|
461 |
+
"draughtsmen": "draftsmen",
|
462 |
+
"draughtswoman": "draftswoman",
|
463 |
+
"draughtswomen": "draftswomen",
|
464 |
+
"draughty": "drafty",
|
465 |
+
"drivelled": "driveled",
|
466 |
+
"drivelling": "driveling",
|
467 |
+
"duelled": "dueled",
|
468 |
+
"duelling": "dueling",
|
469 |
+
"economise": "economize",
|
470 |
+
"economised": "economized",
|
471 |
+
"economises": "economizes",
|
472 |
+
"economising": "economizing",
|
473 |
+
"editorialise": "editorialize",
|
474 |
+
"editorialised": "editorialized",
|
475 |
+
"editorialises": "editorializes",
|
476 |
+
"editorialising": "editorializing",
|
477 |
+
"edoema": "edema",
|
478 |
+
"empathise": "empathize",
|
479 |
+
"empathised": "empathized",
|
480 |
+
"empathises": "empathizes",
|
481 |
+
"empathising": "empathizing",
|
482 |
+
"emphasise": "emphasize",
|
483 |
+
"emphasised": "emphasized",
|
484 |
+
"emphasises": "emphasizes",
|
485 |
+
"emphasising": "emphasizing",
|
486 |
+
"enamelled": "enameled",
|
487 |
+
"enamelling": "enameling",
|
488 |
+
"enamoured": "enamored",
|
489 |
+
"encyclopaedia": "encyclopedia",
|
490 |
+
"encyclopaedias": "encyclopedias",
|
491 |
+
"encyclopaedic": "encyclopedic",
|
492 |
+
"endeavour": "endeavor",
|
493 |
+
"endeavoured": "endeavored",
|
494 |
+
"endeavouring": "endeavoring",
|
495 |
+
"endeavours": "endeavors",
|
496 |
+
"energise": "energize",
|
497 |
+
"energised": "energized",
|
498 |
+
"energises": "energizes",
|
499 |
+
"energising": "energizing",
|
500 |
+
"enrol": "enroll",
|
501 |
+
"enrols": "enrolls",
|
502 |
+
"enthral": "enthrall",
|
503 |
+
"enthrals": "enthralls",
|
504 |
+
"epaulette": "epaulet",
|
505 |
+
"epaulettes": "epaulets",
|
506 |
+
"epicentre": "epicenter",
|
507 |
+
"epicentres": "epicenters",
|
508 |
+
"epilogue": "epilog",
|
509 |
+
"epilogues": "epilogs",
|
510 |
+
"epitomise": "epitomize",
|
511 |
+
"epitomised": "epitomized",
|
512 |
+
"epitomises": "epitomizes",
|
513 |
+
"epitomising": "epitomizing",
|
514 |
+
"equalisation": "equalization",
|
515 |
+
"equalise": "equalize",
|
516 |
+
"equalised": "equalized",
|
517 |
+
"equaliser": "equalizer",
|
518 |
+
"equalisers": "equalizers",
|
519 |
+
"equalises": "equalizes",
|
520 |
+
"equalising": "equalizing",
|
521 |
+
"eulogise": "eulogize",
|
522 |
+
"eulogised": "eulogized",
|
523 |
+
"eulogises": "eulogizes",
|
524 |
+
"eulogising": "eulogizing",
|
525 |
+
"evangelise": "evangelize",
|
526 |
+
"evangelised": "evangelized",
|
527 |
+
"evangelises": "evangelizes",
|
528 |
+
"evangelising": "evangelizing",
|
529 |
+
"exorcise": "exorcize",
|
530 |
+
"exorcised": "exorcized",
|
531 |
+
"exorcises": "exorcizes",
|
532 |
+
"exorcising": "exorcizing",
|
533 |
+
"extemporisation": "extemporization",
|
534 |
+
"extemporise": "extemporize",
|
535 |
+
"extemporised": "extemporized",
|
536 |
+
"extemporises": "extemporizes",
|
537 |
+
"extemporising": "extemporizing",
|
538 |
+
"externalisation": "externalization",
|
539 |
+
"externalisations": "externalizations",
|
540 |
+
"externalise": "externalize",
|
541 |
+
"externalised": "externalized",
|
542 |
+
"externalises": "externalizes",
|
543 |
+
"externalising": "externalizing",
|
544 |
+
"factorise": "factorize",
|
545 |
+
"factorised": "factorized",
|
546 |
+
"factorises": "factorizes",
|
547 |
+
"factorising": "factorizing",
|
548 |
+
"faecal": "fecal",
|
549 |
+
"faeces": "feces",
|
550 |
+
"familiarisation": "familiarization",
|
551 |
+
"familiarise": "familiarize",
|
552 |
+
"familiarised": "familiarized",
|
553 |
+
"familiarises": "familiarizes",
|
554 |
+
"familiarising": "familiarizing",
|
555 |
+
"fantasise": "fantasize",
|
556 |
+
"fantasised": "fantasized",
|
557 |
+
"fantasises": "fantasizes",
|
558 |
+
"fantasising": "fantasizing",
|
559 |
+
"favour": "favor",
|
560 |
+
"favourable": "favorable",
|
561 |
+
"favourably": "favorably",
|
562 |
+
"favoured": "favored",
|
563 |
+
"favouring": "favoring",
|
564 |
+
"favourite": "favorite",
|
565 |
+
"favourites": "favorites",
|
566 |
+
"favouritism": "favoritism",
|
567 |
+
"favours": "favors",
|
568 |
+
"feminise": "feminize",
|
569 |
+
"feminised": "feminized",
|
570 |
+
"feminises": "feminizes",
|
571 |
+
"feminising": "feminizing",
|
572 |
+
"fertilisation": "fertilization",
|
573 |
+
"fertilise": "fertilize",
|
574 |
+
"fertilised": "fertilized",
|
575 |
+
"fertiliser": "fertilizer",
|
576 |
+
"fertilisers": "fertilizers",
|
577 |
+
"fertilises": "fertilizes",
|
578 |
+
"fertilising": "fertilizing",
|
579 |
+
"fervour": "fervor",
|
580 |
+
"fibre": "fiber",
|
581 |
+
"fibreglass": "fiberglass",
|
582 |
+
"fibres": "fibers",
|
583 |
+
"fictionalisation": "fictionalization",
|
584 |
+
"fictionalisations": "fictionalizations",
|
585 |
+
"fictionalise": "fictionalize",
|
586 |
+
"fictionalised": "fictionalized",
|
587 |
+
"fictionalises": "fictionalizes",
|
588 |
+
"fictionalising": "fictionalizing",
|
589 |
+
"fillet": "filet",
|
590 |
+
"filleted": "fileted",
|
591 |
+
"filleting": "fileting",
|
592 |
+
"fillets": "filets",
|
593 |
+
"finalisation": "finalization",
|
594 |
+
"finalise": "finalize",
|
595 |
+
"finalised": "finalized",
|
596 |
+
"finalises": "finalizes",
|
597 |
+
"finalising": "finalizing",
|
598 |
+
"flautist": "flutist",
|
599 |
+
"flautists": "flutists",
|
600 |
+
"flavour": "flavor",
|
601 |
+
"flavoured": "flavored",
|
602 |
+
"flavouring": "flavoring",
|
603 |
+
"flavourings": "flavorings",
|
604 |
+
"flavourless": "flavorless",
|
605 |
+
"flavours": "flavors",
|
606 |
+
"flavoursome": "flavorsome",
|
607 |
+
"flyer / flier": "flier / flyer",
|
608 |
+
"foetal": "fetal",
|
609 |
+
"foetid": "fetid",
|
610 |
+
"foetus": "fetus",
|
611 |
+
"foetuses": "fetuses",
|
612 |
+
"formalisation": "formalization",
|
613 |
+
"formalise": "formalize",
|
614 |
+
"formalised": "formalized",
|
615 |
+
"formalises": "formalizes",
|
616 |
+
"formalising": "formalizing",
|
617 |
+
"fossilisation": "fossilization",
|
618 |
+
"fossilise": "fossilize",
|
619 |
+
"fossilised": "fossilized",
|
620 |
+
"fossilises": "fossilizes",
|
621 |
+
"fossilising": "fossilizing",
|
622 |
+
"fraternisation": "fraternization",
|
623 |
+
"fraternise": "fraternize",
|
624 |
+
"fraternised": "fraternized",
|
625 |
+
"fraternises": "fraternizes",
|
626 |
+
"fraternising": "fraternizing",
|
627 |
+
"fulfil": "fulfill",
|
628 |
+
"fulfilment": "fulfillment",
|
629 |
+
"fulfils": "fulfills",
|
630 |
+
"funnelled": "funneled",
|
631 |
+
"funnelling": "funneling",
|
632 |
+
"gage": "gauge",
|
633 |
+
"gaged": "gauged",
|
634 |
+
"gages": "gauges",
|
635 |
+
"gaging": "gauging",
|
636 |
+
"galvanise": "galvanize",
|
637 |
+
"galvanised": "galvanized",
|
638 |
+
"galvanises": "galvanizes",
|
639 |
+
"galvanising": "galvanizing",
|
640 |
+
"gambolled": "gamboled",
|
641 |
+
"gambolling": "gamboling",
|
642 |
+
"gaol": "jail",
|
643 |
+
"gaolbird": "jailbird",
|
644 |
+
"gaolbirds": "jailbirds",
|
645 |
+
"gaolbreak": "jailbreak",
|
646 |
+
"gaolbreaks": "jailbreaks",
|
647 |
+
"gaoled": "jailed",
|
648 |
+
"gaoler": "jailer",
|
649 |
+
"gaolers": "jailers",
|
650 |
+
"gaoling": "jailing",
|
651 |
+
"gaols": "jails",
|
652 |
+
"gasses": "gases",
|
653 |
+
"generalisation": "generalization",
|
654 |
+
"generalisations": "generalizations",
|
655 |
+
"generalise": "generalize",
|
656 |
+
"generalised": "generalized",
|
657 |
+
"generalises": "generalizes",
|
658 |
+
"generalising": "generalizing",
|
659 |
+
"ghettoise": "ghettoize",
|
660 |
+
"ghettoised": "ghettoized",
|
661 |
+
"ghettoises": "ghettoizes",
|
662 |
+
"ghettoising": "ghettoizing",
|
663 |
+
"gipsies": "gypsies",
|
664 |
+
"glamor": "glamour",
|
665 |
+
"glamorise": "glamorize",
|
666 |
+
"glamorised": "glamorized",
|
667 |
+
"glamorises": "glamorizes",
|
668 |
+
"glamorising": "glamorizing",
|
669 |
+
"globalisation": "globalization",
|
670 |
+
"globalise": "globalize",
|
671 |
+
"globalised": "globalized",
|
672 |
+
"globalises": "globalizes",
|
673 |
+
"globalising": "globalizing",
|
674 |
+
"glueing": "gluing",
|
675 |
+
"goitre": "goiter",
|
676 |
+
"goitres": "goiters",
|
677 |
+
"gonorrhoea": "gonorrhea",
|
678 |
+
"gramme": "gram",
|
679 |
+
"grammes": "grams",
|
680 |
+
"gravelled": "graveled",
|
681 |
+
"grey": "gray",
|
682 |
+
"greyed": "grayed",
|
683 |
+
"greying": "graying",
|
684 |
+
"greyish": "grayish",
|
685 |
+
"greyness": "grayness",
|
686 |
+
"greys": "grays",
|
687 |
+
"grovelled": "groveled",
|
688 |
+
"grovelling": "groveling",
|
689 |
+
"groyne": "groin",
|
690 |
+
"groynes": "groins",
|
691 |
+
"gruelling": "grueling",
|
692 |
+
"gruellingly": "gruelingly",
|
693 |
+
"gryphon": "griffin",
|
694 |
+
"gryphons": "griffins",
|
695 |
+
"gynaecological": "gynecological",
|
696 |
+
"gynaecologist": "gynecologist",
|
697 |
+
"gynaecologists": "gynecologists",
|
698 |
+
"gynaecology": "gynecology",
|
699 |
+
"haematological": "hematological",
|
700 |
+
"haematologist": "hematologist",
|
701 |
+
"haematologists": "hematologists",
|
702 |
+
"haematology": "hematology",
|
703 |
+
"haemoglobin": "hemoglobin",
|
704 |
+
"haemophilia": "hemophilia",
|
705 |
+
"haemophiliac": "hemophiliac",
|
706 |
+
"haemophiliacs": "hemophiliacs",
|
707 |
+
"haemorrhage": "hemorrhage",
|
708 |
+
"haemorrhaged": "hemorrhaged",
|
709 |
+
"haemorrhages": "hemorrhages",
|
710 |
+
"haemorrhaging": "hemorrhaging",
|
711 |
+
"haemorrhoids": "hemorrhoids",
|
712 |
+
"harbour": "harbor",
|
713 |
+
"harboured": "harbored",
|
714 |
+
"harbouring": "harboring",
|
715 |
+
"harbours": "harbors",
|
716 |
+
"harmonisation": "harmonization",
|
717 |
+
"harmonise": "harmonize",
|
718 |
+
"harmonised": "harmonized",
|
719 |
+
"harmonises": "harmonizes",
|
720 |
+
"harmonising": "harmonizing",
|
721 |
+
"homoeopath": "homeopath",
|
722 |
+
"homoeopathic": "homeopathic",
|
723 |
+
"homoeopaths": "homeopaths",
|
724 |
+
"homoeopathy": "homeopathy",
|
725 |
+
"homogenise": "homogenize",
|
726 |
+
"homogenised": "homogenized",
|
727 |
+
"homogenises": "homogenizes",
|
728 |
+
"homogenising": "homogenizing",
|
729 |
+
"honour": "honor",
|
730 |
+
"honourable": "honorable",
|
731 |
+
"honourably": "honorably",
|
732 |
+
"honoured": "honored",
|
733 |
+
"honouring": "honoring",
|
734 |
+
"honours": "honors",
|
735 |
+
"hospitalisation": "hospitalization",
|
736 |
+
"hospitalise": "hospitalize",
|
737 |
+
"hospitalised": "hospitalized",
|
738 |
+
"hospitalises": "hospitalizes",
|
739 |
+
"hospitalising": "hospitalizing",
|
740 |
+
"humanise": "humanize",
|
741 |
+
"humanised": "humanized",
|
742 |
+
"humanises": "humanizes",
|
743 |
+
"humanising": "humanizing",
|
744 |
+
"humour": "humor",
|
745 |
+
"humoured": "humored",
|
746 |
+
"humouring": "humoring",
|
747 |
+
"humourless": "humorless",
|
748 |
+
"humours": "humors",
|
749 |
+
"hybridise": "hybridize",
|
750 |
+
"hybridised": "hybridized",
|
751 |
+
"hybridises": "hybridizes",
|
752 |
+
"hybridising": "hybridizing",
|
753 |
+
"hypnotise": "hypnotize",
|
754 |
+
"hypnotised": "hypnotized",
|
755 |
+
"hypnotises": "hypnotizes",
|
756 |
+
"hypnotising": "hypnotizing",
|
757 |
+
"hypothesise": "hypothesize",
|
758 |
+
"hypothesised": "hypothesized",
|
759 |
+
"hypothesises": "hypothesizes",
|
760 |
+
"hypothesising": "hypothesizing",
|
761 |
+
"idealisation": "idealization",
|
762 |
+
"idealise": "idealize",
|
763 |
+
"idealised": "idealized",
|
764 |
+
"idealises": "idealizes",
|
765 |
+
"idealising": "idealizing",
|
766 |
+
"idolise": "idolize",
|
767 |
+
"idolised": "idolized",
|
768 |
+
"idolises": "idolizes",
|
769 |
+
"idolising": "idolizing",
|
770 |
+
"immobilisation": "immobilization",
|
771 |
+
"immobilise": "immobilize",
|
772 |
+
"immobilised": "immobilized",
|
773 |
+
"immobiliser": "immobilizer",
|
774 |
+
"immobilisers": "immobilizers",
|
775 |
+
"immobilises": "immobilizes",
|
776 |
+
"immobilising": "immobilizing",
|
777 |
+
"immortalise": "immortalize",
|
778 |
+
"immortalised": "immortalized",
|
779 |
+
"immortalises": "immortalizes",
|
780 |
+
"immortalising": "immortalizing",
|
781 |
+
"immunisation": "immunization",
|
782 |
+
"immunise": "immunize",
|
783 |
+
"immunised": "immunized",
|
784 |
+
"immunises": "immunizes",
|
785 |
+
"immunising": "immunizing",
|
786 |
+
"impanelled": "impaneled",
|
787 |
+
"impanelling": "impaneling",
|
788 |
+
"imperilled": "imperiled",
|
789 |
+
"imperilling": "imperiling",
|
790 |
+
"individualise": "individualize",
|
791 |
+
"individualised": "individualized",
|
792 |
+
"individualises": "individualizes",
|
793 |
+
"individualising": "individualizing",
|
794 |
+
"industrialise": "industrialize",
|
795 |
+
"industrialised": "industrialized",
|
796 |
+
"industrialises": "industrializes",
|
797 |
+
"industrialising": "industrializing",
|
798 |
+
"inflexion": "inflection",
|
799 |
+
"inflexions": "inflections",
|
800 |
+
"initialise": "initialize",
|
801 |
+
"initialised": "initialized",
|
802 |
+
"initialises": "initializes",
|
803 |
+
"initialising": "initializing",
|
804 |
+
"initialled": "initialed",
|
805 |
+
"initialling": "initialing",
|
806 |
+
"instal": "install",
|
807 |
+
"instalment": "installment",
|
808 |
+
"instalments": "installments",
|
809 |
+
"instals": "installs",
|
810 |
+
"instil": "instill",
|
811 |
+
"instils": "instills",
|
812 |
+
"institutionalisation": "institutionalization",
|
813 |
+
"institutionalise": "institutionalize",
|
814 |
+
"institutionalised": "institutionalized",
|
815 |
+
"institutionalises": "institutionalizes",
|
816 |
+
"institutionalising": "institutionalizing",
|
817 |
+
"intellectualise": "intellectualize",
|
818 |
+
"intellectualised": "intellectualized",
|
819 |
+
"intellectualises": "intellectualizes",
|
820 |
+
"intellectualising": "intellectualizing",
|
821 |
+
"internalisation": "internalization",
|
822 |
+
"internalise": "internalize",
|
823 |
+
"internalised": "internalized",
|
824 |
+
"internalises": "internalizes",
|
825 |
+
"internalising": "internalizing",
|
826 |
+
"internationalisation": "internationalization",
|
827 |
+
"internationalise": "internationalize",
|
828 |
+
"internationalised": "internationalized",
|
829 |
+
"internationalises": "internationalizes",
|
830 |
+
"internationalising": "internationalizing",
|
831 |
+
"ionisation": "ionization",
|
832 |
+
"ionise": "ionize",
|
833 |
+
"ionised": "ionized",
|
834 |
+
"ioniser": "ionizer",
|
835 |
+
"ionisers": "ionizers",
|
836 |
+
"ionises": "ionizes",
|
837 |
+
"ionising": "ionizing",
|
838 |
+
"italicise": "italicize",
|
839 |
+
"italicised": "italicized",
|
840 |
+
"italicises": "italicizes",
|
841 |
+
"italicising": "italicizing",
|
842 |
+
"itemise": "itemize",
|
843 |
+
"itemised": "itemized",
|
844 |
+
"itemises": "itemizes",
|
845 |
+
"itemising": "itemizing",
|
846 |
+
"jeopardise": "jeopardize",
|
847 |
+
"jeopardised": "jeopardized",
|
848 |
+
"jeopardises": "jeopardizes",
|
849 |
+
"jeopardising": "jeopardizing",
|
850 |
+
"jewelled": "jeweled",
|
851 |
+
"jeweller": "jeweler",
|
852 |
+
"jewellers": "jewelers",
|
853 |
+
"jewellery": "jewelry",
|
854 |
+
"judgement": "judgment",
|
855 |
+
"kilogramme": "kilogram",
|
856 |
+
"kilogrammes": "kilograms",
|
857 |
+
"kilometre": "kilometer",
|
858 |
+
"kilometres": "kilometers",
|
859 |
+
"labelled": "labeled",
|
860 |
+
"labelling": "labeling",
|
861 |
+
"labour": "labor",
|
862 |
+
"laboured": "labored",
|
863 |
+
"labourer": "laborer",
|
864 |
+
"labourers": "laborers",
|
865 |
+
"labouring": "laboring",
|
866 |
+
"labours": "labors",
|
867 |
+
"lacklustre": "lackluster",
|
868 |
+
"legalisation": "legalization",
|
869 |
+
"legalise": "legalize",
|
870 |
+
"legalised": "legalized",
|
871 |
+
"legalises": "legalizes",
|
872 |
+
"legalising": "legalizing",
|
873 |
+
"legitimise": "legitimize",
|
874 |
+
"legitimised": "legitimized",
|
875 |
+
"legitimises": "legitimizes",
|
876 |
+
"legitimising": "legitimizing",
|
877 |
+
"leukaemia": "leukemia",
|
878 |
+
"levelled": "leveled",
|
879 |
+
"leveller": "leveler",
|
880 |
+
"levellers": "levelers",
|
881 |
+
"levelling": "leveling",
|
882 |
+
"libelled": "libeled",
|
883 |
+
"libelling": "libeling",
|
884 |
+
"libellous": "libelous",
|
885 |
+
"liberalisation": "liberalization",
|
886 |
+
"liberalise": "liberalize",
|
887 |
+
"liberalised": "liberalized",
|
888 |
+
"liberalises": "liberalizes",
|
889 |
+
"liberalising": "liberalizing",
|
890 |
+
"licence": "license",
|
891 |
+
"licenced": "licensed",
|
892 |
+
"licences": "licenses",
|
893 |
+
"licencing": "licensing",
|
894 |
+
"likeable": "likable",
|
895 |
+
"lionisation": "lionization",
|
896 |
+
"lionise": "lionize",
|
897 |
+
"lionised": "lionized",
|
898 |
+
"lionises": "lionizes",
|
899 |
+
"lionising": "lionizing",
|
900 |
+
"liquidise": "liquidize",
|
901 |
+
"liquidised": "liquidized",
|
902 |
+
"liquidiser": "liquidizer",
|
903 |
+
"liquidisers": "liquidizers",
|
904 |
+
"liquidises": "liquidizes",
|
905 |
+
"liquidising": "liquidizing",
|
906 |
+
"litre": "liter",
|
907 |
+
"litres": "liters",
|
908 |
+
"localise": "localize",
|
909 |
+
"localised": "localized",
|
910 |
+
"localises": "localizes",
|
911 |
+
"localising": "localizing",
|
912 |
+
"louvre": "louver",
|
913 |
+
"louvred": "louvered",
|
914 |
+
"louvres": "louvers",
|
915 |
+
"lustre": "luster",
|
916 |
+
"magnetise": "magnetize",
|
917 |
+
"magnetised": "magnetized",
|
918 |
+
"magnetises": "magnetizes",
|
919 |
+
"magnetising": "magnetizing",
|
920 |
+
"manoeuvrability": "maneuverability",
|
921 |
+
"manoeuvrable": "maneuverable",
|
922 |
+
"manoeuvre": "maneuver",
|
923 |
+
"manoeuvred": "maneuvered",
|
924 |
+
"manoeuvres": "maneuvers",
|
925 |
+
"manoeuvring": "maneuvering",
|
926 |
+
"manoeuvrings": "maneuverings",
|
927 |
+
"marginalisation": "marginalization",
|
928 |
+
"marginalise": "marginalize",
|
929 |
+
"marginalised": "marginalized",
|
930 |
+
"marginalises": "marginalizes",
|
931 |
+
"marginalising": "marginalizing",
|
932 |
+
"marshalled": "marshaled",
|
933 |
+
"marshalling": "marshaling",
|
934 |
+
"marvelled": "marveled",
|
935 |
+
"marvelling": "marveling",
|
936 |
+
"marvellous": "marvelous",
|
937 |
+
"marvellously": "marvelously",
|
938 |
+
"materialisation": "materialization",
|
939 |
+
"materialise": "materialize",
|
940 |
+
"materialised": "materialized",
|
941 |
+
"materialises": "materializes",
|
942 |
+
"materialising": "materializing",
|
943 |
+
"maximisation": "maximization",
|
944 |
+
"maximise": "maximize",
|
945 |
+
"maximised": "maximized",
|
946 |
+
"maximises": "maximizes",
|
947 |
+
"maximising": "maximizing",
|
948 |
+
"meagre": "meager",
|
949 |
+
"mechanisation": "mechanization",
|
950 |
+
"mechanise": "mechanize",
|
951 |
+
"mechanised": "mechanized",
|
952 |
+
"mechanises": "mechanizes",
|
953 |
+
"mechanising": "mechanizing",
|
954 |
+
"mediaeval": "medieval",
|
955 |
+
"memorialise": "memorialize",
|
956 |
+
"memorialised": "memorialized",
|
957 |
+
"memorialises": "memorializes",
|
958 |
+
"memorialising": "memorializing",
|
959 |
+
"memorise": "memorize",
|
960 |
+
"memorised": "memorized",
|
961 |
+
"memorises": "memorizes",
|
962 |
+
"memorising": "memorizing",
|
963 |
+
"mesmerise": "mesmerize",
|
964 |
+
"mesmerised": "mesmerized",
|
965 |
+
"mesmerises": "mesmerizes",
|
966 |
+
"mesmerising": "mesmerizing",
|
967 |
+
"metabolise": "metabolize",
|
968 |
+
"metabolised": "metabolized",
|
969 |
+
"metabolises": "metabolizes",
|
970 |
+
"metabolising": "metabolizing",
|
971 |
+
"metre": "meter",
|
972 |
+
"metres": "meters",
|
973 |
+
"mhm": "hmm",
|
974 |
+
"micrometre": "micrometer",
|
975 |
+
"micrometres": "micrometers",
|
976 |
+
"militarise": "militarize",
|
977 |
+
"militarised": "militarized",
|
978 |
+
"militarises": "militarizes",
|
979 |
+
"militarising": "militarizing",
|
980 |
+
"milligramme": "milligram",
|
981 |
+
"milligrammes": "milligrams",
|
982 |
+
"millilitre": "milliliter",
|
983 |
+
"millilitres": "milliliters",
|
984 |
+
"millimetre": "millimeter",
|
985 |
+
"millimetres": "millimeters",
|
986 |
+
"miniaturisation": "miniaturization",
|
987 |
+
"miniaturise": "miniaturize",
|
988 |
+
"miniaturised": "miniaturized",
|
989 |
+
"miniaturises": "miniaturizes",
|
990 |
+
"miniaturising": "miniaturizing",
|
991 |
+
"minibusses": "minibuses",
|
992 |
+
"minimise": "minimize",
|
993 |
+
"minimised": "minimized",
|
994 |
+
"minimises": "minimizes",
|
995 |
+
"minimising": "minimizing",
|
996 |
+
"misbehaviour": "misbehavior",
|
997 |
+
"misdemeanour": "misdemeanor",
|
998 |
+
"misdemeanours": "misdemeanors",
|
999 |
+
"misspelt": "misspelled",
|
1000 |
+
"mitre": "miter",
|
1001 |
+
"mitres": "miters",
|
1002 |
+
"mm": "hmm",
|
1003 |
+
"mmm": "hmm",
|
1004 |
+
"mobilisation": "mobilization",
|
1005 |
+
"mobilise": "mobilize",
|
1006 |
+
"mobilised": "mobilized",
|
1007 |
+
"mobilises": "mobilizes",
|
1008 |
+
"mobilising": "mobilizing",
|
1009 |
+
"modelled": "modeled",
|
1010 |
+
"modeller": "modeler",
|
1011 |
+
"modellers": "modelers",
|
1012 |
+
"modelling": "modeling",
|
1013 |
+
"modernise": "modernize",
|
1014 |
+
"modernised": "modernized",
|
1015 |
+
"modernises": "modernizes",
|
1016 |
+
"modernising": "modernizing",
|
1017 |
+
"moisturise": "moisturize",
|
1018 |
+
"moisturised": "moisturized",
|
1019 |
+
"moisturiser": "moisturizer",
|
1020 |
+
"moisturisers": "moisturizers",
|
1021 |
+
"moisturises": "moisturizes",
|
1022 |
+
"moisturising": "moisturizing",
|
1023 |
+
"monologue": "monolog",
|
1024 |
+
"monologues": "monologs",
|
1025 |
+
"monopolisation": "monopolization",
|
1026 |
+
"monopolise": "monopolize",
|
1027 |
+
"monopolised": "monopolized",
|
1028 |
+
"monopolises": "monopolizes",
|
1029 |
+
"monopolising": "monopolizing",
|
1030 |
+
"moralise": "moralize",
|
1031 |
+
"moralised": "moralized",
|
1032 |
+
"moralises": "moralizes",
|
1033 |
+
"moralising": "moralizing",
|
1034 |
+
"motorised": "motorized",
|
1035 |
+
"mould": "mold",
|
1036 |
+
"moulded": "molded",
|
1037 |
+
"moulder": "molder",
|
1038 |
+
"mouldered": "moldered",
|
1039 |
+
"mouldering": "moldering",
|
1040 |
+
"moulders": "molders",
|
1041 |
+
"mouldier": "moldier",
|
1042 |
+
"mouldiest": "moldiest",
|
1043 |
+
"moulding": "molding",
|
1044 |
+
"mouldings": "moldings",
|
1045 |
+
"moulds": "molds",
|
1046 |
+
"mouldy": "moldy",
|
1047 |
+
"moult": "molt",
|
1048 |
+
"moulted": "molted",
|
1049 |
+
"moulting": "molting",
|
1050 |
+
"moults": "molts",
|
1051 |
+
"moustache": "mustache",
|
1052 |
+
"moustached": "mustached",
|
1053 |
+
"moustaches": "mustaches",
|
1054 |
+
"moustachioed": "mustachioed",
|
1055 |
+
"multicoloured": "multicolored",
|
1056 |
+
"nationalisation": "nationalization",
|
1057 |
+
"nationalisations": "nationalizations",
|
1058 |
+
"nationalise": "nationalize",
|
1059 |
+
"nationalised": "nationalized",
|
1060 |
+
"nationalises": "nationalizes",
|
1061 |
+
"nationalising": "nationalizing",
|
1062 |
+
"naturalisation": "naturalization",
|
1063 |
+
"naturalise": "naturalize",
|
1064 |
+
"naturalised": "naturalized",
|
1065 |
+
"naturalises": "naturalizes",
|
1066 |
+
"naturalising": "naturalizing",
|
1067 |
+
"neighbour": "neighbor",
|
1068 |
+
"neighbourhood": "neighborhood",
|
1069 |
+
"neighbourhoods": "neighborhoods",
|
1070 |
+
"neighbouring": "neighboring",
|
1071 |
+
"neighbourliness": "neighborliness",
|
1072 |
+
"neighbourly": "neighborly",
|
1073 |
+
"neighbours": "neighbors",
|
1074 |
+
"neutralisation": "neutralization",
|
1075 |
+
"neutralise": "neutralize",
|
1076 |
+
"neutralised": "neutralized",
|
1077 |
+
"neutralises": "neutralizes",
|
1078 |
+
"neutralising": "neutralizing",
|
1079 |
+
"normalisation": "normalization",
|
1080 |
+
"normalise": "normalize",
|
1081 |
+
"normalised": "normalized",
|
1082 |
+
"normalises": "normalizes",
|
1083 |
+
"normalising": "normalizing",
|
1084 |
+
"odour": "odor",
|
1085 |
+
"odourless": "odorless",
|
1086 |
+
"odours": "odors",
|
1087 |
+
"oesophagus": "esophagus",
|
1088 |
+
"oesophaguses": "esophaguses",
|
1089 |
+
"oestrogen": "estrogen",
|
1090 |
+
"offence": "offense",
|
1091 |
+
"offences": "offenses",
|
1092 |
+
"omelette": "omelet",
|
1093 |
+
"omelettes": "omelets",
|
1094 |
+
"optimise": "optimize",
|
1095 |
+
"optimised": "optimized",
|
1096 |
+
"optimises": "optimizes",
|
1097 |
+
"optimising": "optimizing",
|
1098 |
+
"organisation": "organization",
|
1099 |
+
"organisational": "organizational",
|
1100 |
+
"organisations": "organizations",
|
1101 |
+
"organise": "organize",
|
1102 |
+
"organised": "organized",
|
1103 |
+
"organiser": "organizer",
|
1104 |
+
"organisers": "organizers",
|
1105 |
+
"organises": "organizes",
|
1106 |
+
"organising": "organizing",
|
1107 |
+
"orthopaedic": "orthopedic",
|
1108 |
+
"orthopaedics": "orthopedics",
|
1109 |
+
"ostracise": "ostracize",
|
1110 |
+
"ostracised": "ostracized",
|
1111 |
+
"ostracises": "ostracizes",
|
1112 |
+
"ostracising": "ostracizing",
|
1113 |
+
"outmanoeuvre": "outmaneuver",
|
1114 |
+
"outmanoeuvred": "outmaneuvered",
|
1115 |
+
"outmanoeuvres": "outmaneuvers",
|
1116 |
+
"outmanoeuvring": "outmaneuvering",
|
1117 |
+
"overemphasise": "overemphasize",
|
1118 |
+
"overemphasised": "overemphasized",
|
1119 |
+
"overemphasises": "overemphasizes",
|
1120 |
+
"overemphasising": "overemphasizing",
|
1121 |
+
"oxidisation": "oxidization",
|
1122 |
+
"oxidise": "oxidize",
|
1123 |
+
"oxidised": "oxidized",
|
1124 |
+
"oxidises": "oxidizes",
|
1125 |
+
"oxidising": "oxidizing",
|
1126 |
+
"paederast": "pederast",
|
1127 |
+
"paederasts": "pederasts",
|
1128 |
+
"paediatric": "pediatric",
|
1129 |
+
"paediatrician": "pediatrician",
|
1130 |
+
"paediatricians": "pediatricians",
|
1131 |
+
"paediatrics": "pediatrics",
|
1132 |
+
"paedophile": "pedophile",
|
1133 |
+
"paedophiles": "pedophiles",
|
1134 |
+
"paedophilia": "pedophilia",
|
1135 |
+
"palaeolithic": "paleolithic",
|
1136 |
+
"palaeontologist": "paleontologist",
|
1137 |
+
"palaeontologists": "paleontologists",
|
1138 |
+
"palaeontology": "paleontology",
|
1139 |
+
"panelled": "paneled",
|
1140 |
+
"panelling": "paneling",
|
1141 |
+
"panellist": "panelist",
|
1142 |
+
"panellists": "panelists",
|
1143 |
+
"paralyse": "paralyze",
|
1144 |
+
"paralysed": "paralyzed",
|
1145 |
+
"paralyses": "paralyzes",
|
1146 |
+
"paralysing": "paralyzing",
|
1147 |
+
"parcelled": "parceled",
|
1148 |
+
"parcelling": "parceling",
|
1149 |
+
"parlour": "parlor",
|
1150 |
+
"parlours": "parlors",
|
1151 |
+
"particularise": "particularize",
|
1152 |
+
"particularised": "particularized",
|
1153 |
+
"particularises": "particularizes",
|
1154 |
+
"particularising": "particularizing",
|
1155 |
+
"passivisation": "passivization",
|
1156 |
+
"passivise": "passivize",
|
1157 |
+
"passivised": "passivized",
|
1158 |
+
"passivises": "passivizes",
|
1159 |
+
"passivising": "passivizing",
|
1160 |
+
"pasteurisation": "pasteurization",
|
1161 |
+
"pasteurise": "pasteurize",
|
1162 |
+
"pasteurised": "pasteurized",
|
1163 |
+
"pasteurises": "pasteurizes",
|
1164 |
+
"pasteurising": "pasteurizing",
|
1165 |
+
"patronise": "patronize",
|
1166 |
+
"patronised": "patronized",
|
1167 |
+
"patronises": "patronizes",
|
1168 |
+
"patronising": "patronizing",
|
1169 |
+
"patronisingly": "patronizingly",
|
1170 |
+
"pedalled": "pedaled",
|
1171 |
+
"pedalling": "pedaling",
|
1172 |
+
"pedestrianisation": "pedestrianization",
|
1173 |
+
"pedestrianise": "pedestrianize",
|
1174 |
+
"pedestrianised": "pedestrianized",
|
1175 |
+
"pedestrianises": "pedestrianizes",
|
1176 |
+
"pedestrianising": "pedestrianizing",
|
1177 |
+
"penalise": "penalize",
|
1178 |
+
"penalised": "penalized",
|
1179 |
+
"penalises": "penalizes",
|
1180 |
+
"penalising": "penalizing",
|
1181 |
+
"pencilled": "penciled",
|
1182 |
+
"pencilling": "penciling",
|
1183 |
+
"personalise": "personalize",
|
1184 |
+
"personalised": "personalized",
|
1185 |
+
"personalises": "personalizes",
|
1186 |
+
"personalising": "personalizing",
|
1187 |
+
"pharmacopoeia": "pharmacopeia",
|
1188 |
+
"pharmacopoeias": "pharmacopeias",
|
1189 |
+
"philosophise": "philosophize",
|
1190 |
+
"philosophised": "philosophized",
|
1191 |
+
"philosophises": "philosophizes",
|
1192 |
+
"philosophising": "philosophizing",
|
1193 |
+
"philtre": "filter",
|
1194 |
+
"philtres": "filters",
|
1195 |
+
"phoney": "phony",
|
1196 |
+
"plagiarise": "plagiarize",
|
1197 |
+
"plagiarised": "plagiarized",
|
1198 |
+
"plagiarises": "plagiarizes",
|
1199 |
+
"plagiarising": "plagiarizing",
|
1200 |
+
"plough": "plow",
|
1201 |
+
"ploughed": "plowed",
|
1202 |
+
"ploughing": "plowing",
|
1203 |
+
"ploughman": "plowman",
|
1204 |
+
"ploughmen": "plowmen",
|
1205 |
+
"ploughs": "plows",
|
1206 |
+
"ploughshare": "plowshare",
|
1207 |
+
"ploughshares": "plowshares",
|
1208 |
+
"polarisation": "polarization",
|
1209 |
+
"polarise": "polarize",
|
1210 |
+
"polarised": "polarized",
|
1211 |
+
"polarises": "polarizes",
|
1212 |
+
"polarising": "polarizing",
|
1213 |
+
"politicisation": "politicization",
|
1214 |
+
"politicise": "politicize",
|
1215 |
+
"politicised": "politicized",
|
1216 |
+
"politicises": "politicizes",
|
1217 |
+
"politicising": "politicizing",
|
1218 |
+
"popularisation": "popularization",
|
1219 |
+
"popularise": "popularize",
|
1220 |
+
"popularised": "popularized",
|
1221 |
+
"popularises": "popularizes",
|
1222 |
+
"popularising": "popularizing",
|
1223 |
+
"pouffe": "pouf",
|
1224 |
+
"pouffes": "poufs",
|
1225 |
+
"practise": "practice",
|
1226 |
+
"practised": "practiced",
|
1227 |
+
"practises": "practices",
|
1228 |
+
"practising": "practicing",
|
1229 |
+
"praesidium": "presidium",
|
1230 |
+
"praesidiums": "presidiums",
|
1231 |
+
"pressurisation": "pressurization",
|
1232 |
+
"pressurise": "pressurize",
|
1233 |
+
"pressurised": "pressurized",
|
1234 |
+
"pressurises": "pressurizes",
|
1235 |
+
"pressurising": "pressurizing",
|
1236 |
+
"pretence": "pretense",
|
1237 |
+
"pretences": "pretenses",
|
1238 |
+
"primaeval": "primeval",
|
1239 |
+
"prioritisation": "prioritization",
|
1240 |
+
"prioritise": "prioritize",
|
1241 |
+
"prioritised": "prioritized",
|
1242 |
+
"prioritises": "prioritizes",
|
1243 |
+
"prioritising": "prioritizing",
|
1244 |
+
"privatisation": "privatization",
|
1245 |
+
"privatisations": "privatizations",
|
1246 |
+
"privatise": "privatize",
|
1247 |
+
"privatised": "privatized",
|
1248 |
+
"privatises": "privatizes",
|
1249 |
+
"privatising": "privatizing",
|
1250 |
+
"professionalisation": "professionalization",
|
1251 |
+
"professionalise": "professionalize",
|
1252 |
+
"professionalised": "professionalized",
|
1253 |
+
"professionalises": "professionalizes",
|
1254 |
+
"professionalising": "professionalizing",
|
1255 |
+
"programme": "program",
|
1256 |
+
"programmes": "programs",
|
1257 |
+
"prologue": "prolog",
|
1258 |
+
"prologues": "prologs",
|
1259 |
+
"propagandise": "propagandize",
|
1260 |
+
"propagandised": "propagandized",
|
1261 |
+
"propagandises": "propagandizes",
|
1262 |
+
"propagandising": "propagandizing",
|
1263 |
+
"proselytise": "proselytize",
|
1264 |
+
"proselytised": "proselytized",
|
1265 |
+
"proselytiser": "proselytizer",
|
1266 |
+
"proselytisers": "proselytizers",
|
1267 |
+
"proselytises": "proselytizes",
|
1268 |
+
"proselytising": "proselytizing",
|
1269 |
+
"psychoanalyse": "psychoanalyze",
|
1270 |
+
"psychoanalysed": "psychoanalyzed",
|
1271 |
+
"psychoanalyses": "psychoanalyzes",
|
1272 |
+
"psychoanalysing": "psychoanalyzing",
|
1273 |
+
"publicise": "publicize",
|
1274 |
+
"publicised": "publicized",
|
1275 |
+
"publicises": "publicizes",
|
1276 |
+
"publicising": "publicizing",
|
1277 |
+
"pulverisation": "pulverization",
|
1278 |
+
"pulverise": "pulverize",
|
1279 |
+
"pulverised": "pulverized",
|
1280 |
+
"pulverises": "pulverizes",
|
1281 |
+
"pulverising": "pulverizing",
|
1282 |
+
"pummelled": "pummel",
|
1283 |
+
"pummelling": "pummeled",
|
1284 |
+
"pyjama": "pajama",
|
1285 |
+
"pyjamas": "pajamas",
|
1286 |
+
"pzazz": "pizzazz",
|
1287 |
+
"quarrelled": "quarreled",
|
1288 |
+
"quarrelling": "quarreling",
|
1289 |
+
"radicalise": "radicalize",
|
1290 |
+
"radicalised": "radicalized",
|
1291 |
+
"radicalises": "radicalizes",
|
1292 |
+
"radicalising": "radicalizing",
|
1293 |
+
"rancour": "rancor",
|
1294 |
+
"randomise": "randomize",
|
1295 |
+
"randomised": "randomized",
|
1296 |
+
"randomises": "randomizes",
|
1297 |
+
"randomising": "randomizing",
|
1298 |
+
"rationalisation": "rationalization",
|
1299 |
+
"rationalisations": "rationalizations",
|
1300 |
+
"rationalise": "rationalize",
|
1301 |
+
"rationalised": "rationalized",
|
1302 |
+
"rationalises": "rationalizes",
|
1303 |
+
"rationalising": "rationalizing",
|
1304 |
+
"ravelled": "raveled",
|
1305 |
+
"ravelling": "raveling",
|
1306 |
+
"realisable": "realizable",
|
1307 |
+
"realisation": "realization",
|
1308 |
+
"realisations": "realizations",
|
1309 |
+
"realise": "realize",
|
1310 |
+
"realised": "realized",
|
1311 |
+
"realises": "realizes",
|
1312 |
+
"realising": "realizing",
|
1313 |
+
"recognisable": "recognizable",
|
1314 |
+
"recognisably": "recognizably",
|
1315 |
+
"recognisance": "recognizance",
|
1316 |
+
"recognise": "recognize",
|
1317 |
+
"recognised": "recognized",
|
1318 |
+
"recognises": "recognizes",
|
1319 |
+
"recognising": "recognizing",
|
1320 |
+
"reconnoitre": "reconnoiter",
|
1321 |
+
"reconnoitred": "reconnoitered",
|
1322 |
+
"reconnoitres": "reconnoiters",
|
1323 |
+
"reconnoitring": "reconnoitering",
|
1324 |
+
"refuelled": "refueled",
|
1325 |
+
"refuelling": "refueling",
|
1326 |
+
"regularisation": "regularization",
|
1327 |
+
"regularise": "regularize",
|
1328 |
+
"regularised": "regularized",
|
1329 |
+
"regularises": "regularizes",
|
1330 |
+
"regularising": "regularizing",
|
1331 |
+
"remodelled": "remodeled",
|
1332 |
+
"remodelling": "remodeling",
|
1333 |
+
"remould": "remold",
|
1334 |
+
"remoulded": "remolded",
|
1335 |
+
"remoulding": "remolding",
|
1336 |
+
"remoulds": "remolds",
|
1337 |
+
"reorganisation": "reorganization",
|
1338 |
+
"reorganisations": "reorganizations",
|
1339 |
+
"reorganise": "reorganize",
|
1340 |
+
"reorganised": "reorganized",
|
1341 |
+
"reorganises": "reorganizes",
|
1342 |
+
"reorganising": "reorganizing",
|
1343 |
+
"revelled": "reveled",
|
1344 |
+
"reveller": "reveler",
|
1345 |
+
"revellers": "revelers",
|
1346 |
+
"revelling": "reveling",
|
1347 |
+
"revitalise": "revitalize",
|
1348 |
+
"revitalised": "revitalized",
|
1349 |
+
"revitalises": "revitalizes",
|
1350 |
+
"revitalising": "revitalizing",
|
1351 |
+
"revolutionise": "revolutionize",
|
1352 |
+
"revolutionised": "revolutionized",
|
1353 |
+
"revolutionises": "revolutionizes",
|
1354 |
+
"revolutionising": "revolutionizing",
|
1355 |
+
"rhapsodise": "rhapsodize",
|
1356 |
+
"rhapsodised": "rhapsodized",
|
1357 |
+
"rhapsodises": "rhapsodizes",
|
1358 |
+
"rhapsodising": "rhapsodizing",
|
1359 |
+
"rigour": "rigor",
|
1360 |
+
"rigours": "rigors",
|
1361 |
+
"ritualised": "ritualized",
|
1362 |
+
"rivalled": "rivaled",
|
1363 |
+
"rivalling": "rivaling",
|
1364 |
+
"romanticise": "romanticize",
|
1365 |
+
"romanticised": "romanticized",
|
1366 |
+
"romanticises": "romanticizes",
|
1367 |
+
"romanticising": "romanticizing",
|
1368 |
+
"rumour": "rumor",
|
1369 |
+
"rumoured": "rumored",
|
1370 |
+
"rumours": "rumors",
|
1371 |
+
"sabre": "saber",
|
1372 |
+
"sabres": "sabers",
|
1373 |
+
"saltpetre": "saltpeter",
|
1374 |
+
"sanitise": "sanitize",
|
1375 |
+
"sanitised": "sanitized",
|
1376 |
+
"sanitises": "sanitizes",
|
1377 |
+
"sanitising": "sanitizing",
|
1378 |
+
"satirise": "satirize",
|
1379 |
+
"satirised": "satirized",
|
1380 |
+
"satirises": "satirizes",
|
1381 |
+
"satirising": "satirizing",
|
1382 |
+
"saviour": "savior",
|
1383 |
+
"saviours": "saviors",
|
1384 |
+
"savour": "savor",
|
1385 |
+
"savoured": "savored",
|
1386 |
+
"savouries": "savories",
|
1387 |
+
"savouring": "savoring",
|
1388 |
+
"savours": "savors",
|
1389 |
+
"savoury": "savory",
|
1390 |
+
"scandalise": "scandalize",
|
1391 |
+
"scandalised": "scandalized",
|
1392 |
+
"scandalises": "scandalizes",
|
1393 |
+
"scandalising": "scandalizing",
|
1394 |
+
"sceptic": "skeptic",
|
1395 |
+
"sceptical": "skeptical",
|
1396 |
+
"sceptically": "skeptically",
|
1397 |
+
"scepticism": "skepticism",
|
1398 |
+
"sceptics": "skeptics",
|
1399 |
+
"sceptre": "scepter",
|
1400 |
+
"sceptres": "scepters",
|
1401 |
+
"scrutinise": "scrutinize",
|
1402 |
+
"scrutinised": "scrutinized",
|
1403 |
+
"scrutinises": "scrutinizes",
|
1404 |
+
"scrutinising": "scrutinizing",
|
1405 |
+
"secularisation": "secularization",
|
1406 |
+
"secularise": "secularize",
|
1407 |
+
"secularised": "secularized",
|
1408 |
+
"secularises": "secularizes",
|
1409 |
+
"secularising": "secularizing",
|
1410 |
+
"sensationalise": "sensationalize",
|
1411 |
+
"sensationalised": "sensationalized",
|
1412 |
+
"sensationalises": "sensationalizes",
|
1413 |
+
"sensationalising": "sensationalizing",
|
1414 |
+
"sensitise": "sensitize",
|
1415 |
+
"sensitised": "sensitized",
|
1416 |
+
"sensitises": "sensitizes",
|
1417 |
+
"sensitising": "sensitizing",
|
1418 |
+
"sentimentalise": "sentimentalize",
|
1419 |
+
"sentimentalised": "sentimentalized",
|
1420 |
+
"sentimentalises": "sentimentalizes",
|
1421 |
+
"sentimentalising": "sentimentalizing",
|
1422 |
+
"sepulchre": "sepulcher",
|
1423 |
+
"sepulchres": "sepulchers",
|
1424 |
+
"serialisation": "serialization",
|
1425 |
+
"serialisations": "serializations",
|
1426 |
+
"serialise": "serialize",
|
1427 |
+
"serialised": "serialized",
|
1428 |
+
"serialises": "serializes",
|
1429 |
+
"serialising": "serializing",
|
1430 |
+
"sermonise": "sermonize",
|
1431 |
+
"sermonised": "sermonized",
|
1432 |
+
"sermonises": "sermonizes",
|
1433 |
+
"sermonising": "sermonizing",
|
1434 |
+
"sheikh": "sheik",
|
1435 |
+
"shovelled": "shoveled",
|
1436 |
+
"shovelling": "shoveling",
|
1437 |
+
"shrivelled": "shriveled",
|
1438 |
+
"shrivelling": "shriveling",
|
1439 |
+
"signalise": "signalize",
|
1440 |
+
"signalised": "signalized",
|
1441 |
+
"signalises": "signalizes",
|
1442 |
+
"signalising": "signalizing",
|
1443 |
+
"signalled": "signaled",
|
1444 |
+
"signalling": "signaling",
|
1445 |
+
"smoulder": "smolder",
|
1446 |
+
"smouldered": "smoldered",
|
1447 |
+
"smouldering": "smoldering",
|
1448 |
+
"smoulders": "smolders",
|
1449 |
+
"snivelled": "sniveled",
|
1450 |
+
"snivelling": "sniveling",
|
1451 |
+
"snorkelled": "snorkeled",
|
1452 |
+
"snorkelling": "snorkeling",
|
1453 |
+
"snowplough": "snowplow",
|
1454 |
+
"snowploughs": "snowplow",
|
1455 |
+
"socialisation": "socialization",
|
1456 |
+
"socialise": "socialize",
|
1457 |
+
"socialised": "socialized",
|
1458 |
+
"socialises": "socializes",
|
1459 |
+
"socialising": "socializing",
|
1460 |
+
"sodomise": "sodomize",
|
1461 |
+
"sodomised": "sodomized",
|
1462 |
+
"sodomises": "sodomizes",
|
1463 |
+
"sodomising": "sodomizing",
|
1464 |
+
"solemnise": "solemnize",
|
1465 |
+
"solemnised": "solemnized",
|
1466 |
+
"solemnises": "solemnizes",
|
1467 |
+
"solemnising": "solemnizing",
|
1468 |
+
"sombre": "somber",
|
1469 |
+
"specialisation": "specialization",
|
1470 |
+
"specialisations": "specializations",
|
1471 |
+
"specialise": "specialize",
|
1472 |
+
"specialised": "specialized",
|
1473 |
+
"specialises": "specializes",
|
1474 |
+
"specialising": "specializing",
|
1475 |
+
"spectre": "specter",
|
1476 |
+
"spectres": "specters",
|
1477 |
+
"spiralled": "spiraled",
|
1478 |
+
"spiralling": "spiraling",
|
1479 |
+
"splendour": "splendor",
|
1480 |
+
"splendours": "splendors",
|
1481 |
+
"squirrelled": "squirreled",
|
1482 |
+
"squirrelling": "squirreling",
|
1483 |
+
"stabilisation": "stabilization",
|
1484 |
+
"stabilise": "stabilize",
|
1485 |
+
"stabilised": "stabilized",
|
1486 |
+
"stabiliser": "stabilizer",
|
1487 |
+
"stabilisers": "stabilizers",
|
1488 |
+
"stabilises": "stabilizes",
|
1489 |
+
"stabilising": "stabilizing",
|
1490 |
+
"standardisation": "standardization",
|
1491 |
+
"standardise": "standardize",
|
1492 |
+
"standardised": "standardized",
|
1493 |
+
"standardises": "standardizes",
|
1494 |
+
"standardising": "standardizing",
|
1495 |
+
"stencilled": "stenciled",
|
1496 |
+
"stencilling": "stenciling",
|
1497 |
+
"sterilisation": "sterilization",
|
1498 |
+
"sterilisations": "sterilizations",
|
1499 |
+
"sterilise": "sterilize",
|
1500 |
+
"sterilised": "sterilized",
|
1501 |
+
"steriliser": "sterilizer",
|
1502 |
+
"sterilisers": "sterilizers",
|
1503 |
+
"sterilises": "sterilizes",
|
1504 |
+
"sterilising": "sterilizing",
|
1505 |
+
"stigmatisation": "stigmatization",
|
1506 |
+
"stigmatise": "stigmatize",
|
1507 |
+
"stigmatised": "stigmatized",
|
1508 |
+
"stigmatises": "stigmatizes",
|
1509 |
+
"stigmatising": "stigmatizing",
|
1510 |
+
"storey": "story",
|
1511 |
+
"storeys": "stories",
|
1512 |
+
"subsidisation": "subsidization",
|
1513 |
+
"subsidise": "subsidize",
|
1514 |
+
"subsidised": "subsidized",
|
1515 |
+
"subsidiser": "subsidizer",
|
1516 |
+
"subsidisers": "subsidizers",
|
1517 |
+
"subsidises": "subsidizes",
|
1518 |
+
"subsidising": "subsidizing",
|
1519 |
+
"succour": "succor",
|
1520 |
+
"succoured": "succored",
|
1521 |
+
"succouring": "succoring",
|
1522 |
+
"succours": "succors",
|
1523 |
+
"sulphate": "sulfate",
|
1524 |
+
"sulphates": "sulfates",
|
1525 |
+
"sulphide": "sulfide",
|
1526 |
+
"sulphides": "sulfides",
|
1527 |
+
"sulphur": "sulfur",
|
1528 |
+
"sulphurous": "sulfurous",
|
1529 |
+
"summarise": "summarize",
|
1530 |
+
"summarised": "summarized",
|
1531 |
+
"summarises": "summarizes",
|
1532 |
+
"summarising": "summarizing",
|
1533 |
+
"swivelled": "swiveled",
|
1534 |
+
"swivelling": "swiveling",
|
1535 |
+
"symbolise": "symbolize",
|
1536 |
+
"symbolised": "symbolized",
|
1537 |
+
"symbolises": "symbolizes",
|
1538 |
+
"symbolising": "symbolizing",
|
1539 |
+
"sympathise": "sympathize",
|
1540 |
+
"sympathised": "sympathized",
|
1541 |
+
"sympathiser": "sympathizer",
|
1542 |
+
"sympathisers": "sympathizers",
|
1543 |
+
"sympathises": "sympathizes",
|
1544 |
+
"sympathising": "sympathizing",
|
1545 |
+
"synchronisation": "synchronization",
|
1546 |
+
"synchronise": "synchronize",
|
1547 |
+
"synchronised": "synchronized",
|
1548 |
+
"synchronises": "synchronizes",
|
1549 |
+
"synchronising": "synchronizing",
|
1550 |
+
"synthesise": "synthesize",
|
1551 |
+
"synthesised": "synthesized",
|
1552 |
+
"synthesiser": "synthesizer",
|
1553 |
+
"synthesisers": "synthesizers",
|
1554 |
+
"synthesises": "synthesizes",
|
1555 |
+
"synthesising": "synthesizing",
|
1556 |
+
"syphon": "siphon",
|
1557 |
+
"syphoned": "siphoned",
|
1558 |
+
"syphoning": "siphoning",
|
1559 |
+
"syphons": "siphons",
|
1560 |
+
"systematisation": "systematization",
|
1561 |
+
"systematise": "systematize",
|
1562 |
+
"systematised": "systematized",
|
1563 |
+
"systematises": "systematizes",
|
1564 |
+
"systematising": "systematizing",
|
1565 |
+
"tantalise": "tantalize",
|
1566 |
+
"tantalised": "tantalized",
|
1567 |
+
"tantalises": "tantalizes",
|
1568 |
+
"tantalising": "tantalizing",
|
1569 |
+
"tantalisingly": "tantalizingly",
|
1570 |
+
"tasselled": "tasseled",
|
1571 |
+
"technicolour": "technicolor",
|
1572 |
+
"temporise": "temporize",
|
1573 |
+
"temporised": "temporized",
|
1574 |
+
"temporises": "temporizes",
|
1575 |
+
"temporising": "temporizing",
|
1576 |
+
"tenderise": "tenderize",
|
1577 |
+
"tenderised": "tenderized",
|
1578 |
+
"tenderises": "tenderizes",
|
1579 |
+
"tenderising": "tenderizing",
|
1580 |
+
"terrorise": "terrorize",
|
1581 |
+
"terrorised": "terrorized",
|
1582 |
+
"terrorises": "terrorizes",
|
1583 |
+
"terrorising": "terrorizing",
|
1584 |
+
"theatre": "theater",
|
1585 |
+
"theatregoer": "theatergoer",
|
1586 |
+
"theatregoers": "theatergoers",
|
1587 |
+
"theatres": "theaters",
|
1588 |
+
"theorise": "theorize",
|
1589 |
+
"theorised": "theorized",
|
1590 |
+
"theorises": "theorizes",
|
1591 |
+
"theorising": "theorizing",
|
1592 |
+
"tonne": "ton",
|
1593 |
+
"tonnes": "tons",
|
1594 |
+
"towelled": "toweled",
|
1595 |
+
"towelling": "toweling",
|
1596 |
+
"toxaemia": "toxemia",
|
1597 |
+
"tranquillise": "tranquilize",
|
1598 |
+
"tranquillised": "tranquilized",
|
1599 |
+
"tranquilliser": "tranquilizer",
|
1600 |
+
"tranquillisers": "tranquilizers",
|
1601 |
+
"tranquillises": "tranquilizes",
|
1602 |
+
"tranquillising": "tranquilizing",
|
1603 |
+
"tranquillity": "tranquility",
|
1604 |
+
"tranquillize": "tranquilize",
|
1605 |
+
"tranquillized": "tranquilized",
|
1606 |
+
"tranquillizer": "tranquilizer",
|
1607 |
+
"tranquillizers": "tranquilizers",
|
1608 |
+
"tranquillizes": "tranquilizes",
|
1609 |
+
"tranquillizing": "tranquilizing",
|
1610 |
+
"tranquilly": "tranquility",
|
1611 |
+
"transistorised": "transistorized",
|
1612 |
+
"traumatise": "traumatize",
|
1613 |
+
"traumatised": "traumatized",
|
1614 |
+
"traumatises": "traumatizes",
|
1615 |
+
"traumatising": "traumatizing",
|
1616 |
+
"travelled": "traveled",
|
1617 |
+
"traveller": "traveler",
|
1618 |
+
"travellers": "travelers",
|
1619 |
+
"travelling": "traveling",
|
1620 |
+
"travelog": "travelogue",
|
1621 |
+
"travelogs": "travelogues",
|
1622 |
+
"trialled": "trialed",
|
1623 |
+
"trialling": "trialing",
|
1624 |
+
"tricolour": "tricolor",
|
1625 |
+
"tricolours": "tricolors",
|
1626 |
+
"trivialise": "trivialize",
|
1627 |
+
"trivialised": "trivialized",
|
1628 |
+
"trivialises": "trivializes",
|
1629 |
+
"trivialising": "trivializing",
|
1630 |
+
"tumour": "tumor",
|
1631 |
+
"tumours": "tumors",
|
1632 |
+
"tunnelled": "tunneled",
|
1633 |
+
"tunnelling": "tunneling",
|
1634 |
+
"tyrannise": "tyrannize",
|
1635 |
+
"tyrannised": "tyrannized",
|
1636 |
+
"tyrannises": "tyrannizes",
|
1637 |
+
"tyrannising": "tyrannizing",
|
1638 |
+
"tyre": "tire",
|
1639 |
+
"tyres": "tires",
|
1640 |
+
"unauthorised": "unauthorized",
|
1641 |
+
"uncivilised": "uncivilized",
|
1642 |
+
"underutilised": "underutilized",
|
1643 |
+
"unequalled": "unequaled",
|
1644 |
+
"unfavourable": "unfavorable",
|
1645 |
+
"unfavourably": "unfavorably",
|
1646 |
+
"unionisation": "unionization",
|
1647 |
+
"unionise": "unionize",
|
1648 |
+
"unionised": "unionized",
|
1649 |
+
"unionises": "unionizes",
|
1650 |
+
"unionising": "unionizing",
|
1651 |
+
"unorganised": "unorganized",
|
1652 |
+
"unravelled": "unraveled",
|
1653 |
+
"unravelling": "unraveling",
|
1654 |
+
"unrecognisable": "unrecognizable",
|
1655 |
+
"unrecognised": "unrecognized",
|
1656 |
+
"unrivalled": "unrivaled",
|
1657 |
+
"unsavoury": "unsavory",
|
1658 |
+
"untrammelled": "untrammeled",
|
1659 |
+
"urbanisation": "urbanization",
|
1660 |
+
"urbanise": "urbanize",
|
1661 |
+
"urbanised": "urbanized",
|
1662 |
+
"urbanises": "urbanizes",
|
1663 |
+
"urbanising": "urbanizing",
|
1664 |
+
"utilisable": "utilizable",
|
1665 |
+
"utilisation": "utilization",
|
1666 |
+
"utilise": "utilize",
|
1667 |
+
"utilised": "utilized",
|
1668 |
+
"utilises": "utilizes",
|
1669 |
+
"utilising": "utilizing",
|
1670 |
+
"valour": "valor",
|
1671 |
+
"vandalise": "vandalize",
|
1672 |
+
"vandalised": "vandalized",
|
1673 |
+
"vandalises": "vandalizes",
|
1674 |
+
"vandalising": "vandalizing",
|
1675 |
+
"vaporisation": "vaporization",
|
1676 |
+
"vaporise": "vaporize",
|
1677 |
+
"vaporised": "vaporized",
|
1678 |
+
"vaporises": "vaporizes",
|
1679 |
+
"vaporising": "vaporizing",
|
1680 |
+
"vapour": "vapor",
|
1681 |
+
"vapours": "vapors",
|
1682 |
+
"verbalise": "verbalize",
|
1683 |
+
"verbalised": "verbalized",
|
1684 |
+
"verbalises": "verbalizes",
|
1685 |
+
"verbalising": "verbalizing",
|
1686 |
+
"victimisation": "victimization",
|
1687 |
+
"victimise": "victimize",
|
1688 |
+
"victimised": "victimized",
|
1689 |
+
"victimises": "victimizes",
|
1690 |
+
"victimising": "victimizing",
|
1691 |
+
"videodisc": "videodisk",
|
1692 |
+
"videodiscs": "videodisks",
|
1693 |
+
"vigour": "vigor",
|
1694 |
+
"visualisation": "visualization",
|
1695 |
+
"visualisations": "visualizations",
|
1696 |
+
"visualise": "visualize",
|
1697 |
+
"visualised": "visualized",
|
1698 |
+
"visualises": "visualizes",
|
1699 |
+
"visualising": "visualizing",
|
1700 |
+
"vocalisation": "vocalization",
|
1701 |
+
"vocalisations": "vocalizations",
|
1702 |
+
"vocalise": "vocalize",
|
1703 |
+
"vocalised": "vocalized",
|
1704 |
+
"vocalises": "vocalizes",
|
1705 |
+
"vocalising": "vocalizing",
|
1706 |
+
"vulcanised": "vulcanized",
|
1707 |
+
"vulgarisation": "vulgarization",
|
1708 |
+
"vulgarise": "vulgarize",
|
1709 |
+
"vulgarised": "vulgarized",
|
1710 |
+
"vulgarises": "vulgarizes",
|
1711 |
+
"vulgarising": "vulgarizing",
|
1712 |
+
"waggon": "wagon",
|
1713 |
+
"waggons": "wagons",
|
1714 |
+
"watercolour": "watercolor",
|
1715 |
+
"watercolours": "watercolors",
|
1716 |
+
"weaselled": "weaseled",
|
1717 |
+
"weaselling": "weaseling",
|
1718 |
+
"westernisation": "westernization",
|
1719 |
+
"westernise": "westernize",
|
1720 |
+
"westernised": "westernized",
|
1721 |
+
"westernises": "westernizes",
|
1722 |
+
"westernising": "westernizing",
|
1723 |
+
"womanise": "womanize",
|
1724 |
+
"womanised": "womanized",
|
1725 |
+
"womaniser": "womanizer",
|
1726 |
+
"womanisers": "womanizers",
|
1727 |
+
"womanises": "womanizes",
|
1728 |
+
"womanising": "womanizing",
|
1729 |
+
"woollen": "woolen",
|
1730 |
+
"woollens": "woolens",
|
1731 |
+
"woollies": "woolies",
|
1732 |
+
"woolly": "wooly",
|
1733 |
+
"worshipped": "worshiped",
|
1734 |
+
"worshipper": "worshiper",
|
1735 |
+
"worshipping": "worshiping",
|
1736 |
+
"yodelled": "yodeled",
|
1737 |
+
"yodelling": "yodeling",
|
1738 |
+
"yoghourt": "yogurt",
|
1739 |
+
"yoghourts": "yogurts",
|
1740 |
+
"yoghurt": "yogurt",
|
1741 |
+
"yoghurts": "yogurts"
|
1742 |
+
}
|
preprocessor_config.json
ADDED
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"chunk_length": 30,
|
3 |
+
"feature_extractor_type": "WhisperFeatureExtractor",
|
4 |
+
"feature_size": 80,
|
5 |
+
"hop_length": 160,
|
6 |
+
"n_fft": 400,
|
7 |
+
"n_samples": 480000,
|
8 |
+
"nb_max_frames": 3000,
|
9 |
+
"padding_side": "right",
|
10 |
+
"padding_value": 0.0,
|
11 |
+
"processor_class": "WhisperProcessor",
|
12 |
+
"return_attention_mask": false,
|
13 |
+
"sampling_rate": 16000
|
14 |
+
}
|
run.sh
ADDED
@@ -0,0 +1,33 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env bash
|
2 |
+
|
3 |
+
python run_finetuning.py \
|
4 |
+
--model_name_or_path "openai/whisper-large-v2" \
|
5 |
+
--dataset_name "sanchit-gandhi/librispeech-data" \
|
6 |
+
--dataset_config_name "all" \
|
7 |
+
--train_split_name "train.clean.100+train.clean.360+train.other.500" \
|
8 |
+
--eval_split_name "validation.clean" \
|
9 |
+
--text_column_name "text" \
|
10 |
+
--cache_dir "/home/sanchitgandhi/cache" \
|
11 |
+
--dataset_cache_dir "/home/sanchitgandhi/cache" \
|
12 |
+
--output_dir "./" \
|
13 |
+
--wandb_name "large-v2-ft-ls" \
|
14 |
+
--wandb_dir "/home/sanchitgandhi/cache" \
|
15 |
+
--wandb_project "flax-whisper-librispeech" \
|
16 |
+
--per_device_train_batch_size 4 \
|
17 |
+
--per_device_eval_batch_size 16 \
|
18 |
+
--dtype "bfloat16" \
|
19 |
+
--optim "adafactor" \
|
20 |
+
--learning_rate 1e-4 \
|
21 |
+
--warmup_steps 500 \
|
22 |
+
--do_train \
|
23 |
+
--do_eval \
|
24 |
+
--num_train_epochs 10 \
|
25 |
+
--preprocessing_num_workers 16 \
|
26 |
+
--dataloader_num_workers 64 \
|
27 |
+
--logging_steps 25 \
|
28 |
+
--use_scan \
|
29 |
+
--gradient_checkpointing \
|
30 |
+
--overwrite_output_dir \
|
31 |
+
--predict_with_generate \
|
32 |
+
--push_to_hub
|
33 |
+
|
run_finetuning.py
ADDED
@@ -0,0 +1,1111 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python
|
2 |
+
# coding=utf-8
|
3 |
+
# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
|
4 |
+
#
|
5 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
6 |
+
# you may not use this file except in compliance with the License.
|
7 |
+
# You may obtain a copy of the License at
|
8 |
+
#
|
9 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
10 |
+
#
|
11 |
+
# Unless required by applicable law or agreed to in writing, software
|
12 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
13 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
14 |
+
# See the License for the specific language governing permissions and
|
15 |
+
# limitations under the License.
|
16 |
+
"""
|
17 |
+
Fine-tuning the Whisper model for sequence to sequence speech recognition.
|
18 |
+
"""
|
19 |
+
# You can also adapt this script for your own speech recognition task. Pointers for this are left as comments.
|
20 |
+
|
21 |
+
import logging
|
22 |
+
import os
|
23 |
+
import string
|
24 |
+
import sys
|
25 |
+
import time
|
26 |
+
from dataclasses import dataclass, field
|
27 |
+
from functools import partial
|
28 |
+
from pathlib import Path
|
29 |
+
from typing import Any, Callable, Dict, List, Optional, Union
|
30 |
+
|
31 |
+
import datasets
|
32 |
+
import evaluate
|
33 |
+
import flax
|
34 |
+
import jax
|
35 |
+
import jax.numpy as jnp
|
36 |
+
import numpy as np
|
37 |
+
import optax
|
38 |
+
import transformers
|
39 |
+
from datasets import Dataset, DatasetDict, load_dataset
|
40 |
+
from flax import jax_utils, traverse_util
|
41 |
+
from flax.jax_utils import pad_shard_unpad, unreplicate
|
42 |
+
from flax.training import train_state
|
43 |
+
from flax.training.common_utils import get_metrics, onehot, shard, shard_prng_key
|
44 |
+
from huggingface_hub import Repository, create_repo
|
45 |
+
from torch.utils.data import DataLoader
|
46 |
+
from tqdm import tqdm
|
47 |
+
from transformers import (
|
48 |
+
AutoConfig,
|
49 |
+
AutoFeatureExtractor,
|
50 |
+
AutoProcessor,
|
51 |
+
AutoTokenizer,
|
52 |
+
HfArgumentParser,
|
53 |
+
Seq2SeqTrainingArguments,
|
54 |
+
is_tensorboard_available,
|
55 |
+
is_wandb_available,
|
56 |
+
)
|
57 |
+
from transformers.file_utils import get_full_repo_name
|
58 |
+
from transformers.models.whisper.english_normalizer import EnglishTextNormalizer
|
59 |
+
from transformers.utils import check_min_version, send_example_telemetry
|
60 |
+
from transformers.utils.versions import require_version
|
61 |
+
|
62 |
+
from distil_whisper import FlaxWhisperForConditionalGeneration
|
63 |
+
|
64 |
+
|
65 |
+
# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
|
66 |
+
check_min_version("4.27.0.dev0")
|
67 |
+
|
68 |
+
require_version(
|
69 |
+
"datasets>=1.18.0",
|
70 |
+
"To fix: pip install -r examples/flax/speech-recogintion/requirements.txt",
|
71 |
+
)
|
72 |
+
|
73 |
+
logger = logging.getLogger(__name__)
|
74 |
+
|
75 |
+
|
76 |
+
@flax.struct.dataclass
|
77 |
+
class ModelArguments:
|
78 |
+
"""
|
79 |
+
Arguments pertaining to which model/config/tokenizer we are going to fine-tune from.
|
80 |
+
"""
|
81 |
+
|
82 |
+
model_name_or_path: str = field(
|
83 |
+
metadata={"help": ("Path to pretrained model or model identifier from huggingface.co/models")}
|
84 |
+
)
|
85 |
+
config_name: Optional[str] = field(
|
86 |
+
default=None,
|
87 |
+
metadata={"help": "Pretrained config name or path if not the same as model_name"},
|
88 |
+
)
|
89 |
+
tokenizer_name: Optional[str] = field(
|
90 |
+
default=None,
|
91 |
+
metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"},
|
92 |
+
)
|
93 |
+
feature_extractor_name: Optional[str] = field(
|
94 |
+
default=None,
|
95 |
+
metadata={"help": "feature extractor name or path if not the same as model_name"},
|
96 |
+
)
|
97 |
+
cache_dir: Optional[str] = field(
|
98 |
+
default=None,
|
99 |
+
metadata={"help": ("Where to store the pretrained models downloaded from huggingface.co")},
|
100 |
+
)
|
101 |
+
use_fast_tokenizer: bool = field(
|
102 |
+
default=True,
|
103 |
+
metadata={"help": ("Whether to use one of the fast tokenizer (backed by the tokenizers library) or not.")},
|
104 |
+
)
|
105 |
+
model_revision: str = field(
|
106 |
+
default="main",
|
107 |
+
metadata={"help": ("The specific model version to use (can be a branch name, tag name or commit id).")},
|
108 |
+
)
|
109 |
+
use_auth_token: bool = field(
|
110 |
+
default=False,
|
111 |
+
metadata={
|
112 |
+
"help": (
|
113 |
+
"Will use the token generated when running `transformers-cli login`"
|
114 |
+
" (necessary to use this script with private models)."
|
115 |
+
)
|
116 |
+
},
|
117 |
+
)
|
118 |
+
dtype: Optional[str] = field(
|
119 |
+
default="float32",
|
120 |
+
metadata={
|
121 |
+
"help": (
|
122 |
+
"Floating-point format in which the model weights should be initialized"
|
123 |
+
" and trained. Choose one of `[float32, float16, bfloat16]`."
|
124 |
+
)
|
125 |
+
},
|
126 |
+
)
|
127 |
+
|
128 |
+
|
129 |
+
@flax.struct.dataclass
|
130 |
+
class DataTrainingArguments:
|
131 |
+
"""
|
132 |
+
Arguments pertaining to what data we are going to input our model for training and eval.
|
133 |
+
"""
|
134 |
+
|
135 |
+
dataset_name: str = field(
|
136 |
+
default=None,
|
137 |
+
metadata={"help": "The name of the dataset to use (via the datasets library)."},
|
138 |
+
)
|
139 |
+
dataset_config_name: Optional[str] = field(
|
140 |
+
default=None,
|
141 |
+
metadata={"help": ("The configuration name of the dataset to use (via the datasets library).")},
|
142 |
+
)
|
143 |
+
dataset_cache_dir: Optional[str] = field(
|
144 |
+
default=None,
|
145 |
+
metadata={"help": "Path to cache directory for saving and loading datasets"},
|
146 |
+
)
|
147 |
+
overwrite_cache: bool = field(
|
148 |
+
default=False,
|
149 |
+
metadata={"help": "Overwrite the cached training and evaluation sets"},
|
150 |
+
)
|
151 |
+
preprocessing_num_workers: Optional[int] = field(
|
152 |
+
default=None,
|
153 |
+
metadata={"help": "The number of processes to use for the preprocessing."},
|
154 |
+
)
|
155 |
+
max_train_samples: Optional[int] = field(
|
156 |
+
default=None,
|
157 |
+
metadata={
|
158 |
+
"help": (
|
159 |
+
"For debugging purposes or quicker training, truncate the number of"
|
160 |
+
" training examples to this value if set."
|
161 |
+
)
|
162 |
+
},
|
163 |
+
)
|
164 |
+
max_eval_samples: Optional[int] = field(
|
165 |
+
default=None,
|
166 |
+
metadata={
|
167 |
+
"help": (
|
168 |
+
"For debugging purposes or quicker training, truncate the number of"
|
169 |
+
" evaluation examples to this value if set."
|
170 |
+
)
|
171 |
+
},
|
172 |
+
)
|
173 |
+
audio_column_name: str = field(
|
174 |
+
default="audio",
|
175 |
+
metadata={"help": ("The name of the dataset column containing the audio data. Defaults to 'audio'")},
|
176 |
+
)
|
177 |
+
text_column_name: str = field(
|
178 |
+
default="whisper_transcript",
|
179 |
+
metadata={
|
180 |
+
"help": (
|
181 |
+
"The name of the dataset column containing the text data. Defaults to"
|
182 |
+
" 'whisper_transcript'which is the pseudo-labelled Whisper"
|
183 |
+
" transcription data."
|
184 |
+
)
|
185 |
+
},
|
186 |
+
)
|
187 |
+
max_duration_in_seconds: float = field(
|
188 |
+
default=30.0,
|
189 |
+
metadata={"help": ("Filter audio files that are longer than `max_duration_in_seconds` seconds")},
|
190 |
+
)
|
191 |
+
min_duration_in_seconds: float = field(
|
192 |
+
default=0.0,
|
193 |
+
metadata={"help": ("Filter audio files that are shorter than `min_duration_in_seconds` seconds")},
|
194 |
+
)
|
195 |
+
max_label_length: int = field(
|
196 |
+
default=128,
|
197 |
+
metadata={"help": "Truncate transcriptions that are longer `max_label_length` tokens."},
|
198 |
+
)
|
199 |
+
pad_target_to_multiple_of: Optional[int] = field(
|
200 |
+
default=None,
|
201 |
+
metadata={
|
202 |
+
"help": (
|
203 |
+
"If set will pad the target sequence to a multiple of the provided"
|
204 |
+
" value. This is important to avoid triggering recompilations on TPU."
|
205 |
+
" If unspecified, will default to padding the targets to max length."
|
206 |
+
)
|
207 |
+
},
|
208 |
+
)
|
209 |
+
preprocessing_only: bool = field(
|
210 |
+
default=False,
|
211 |
+
metadata={
|
212 |
+
"help": (
|
213 |
+
"Whether to only do data preprocessing and skip training. This is"
|
214 |
+
" especially useful when data preprocessing errors out in distributed"
|
215 |
+
" training due to timeout. In this case, one should run the"
|
216 |
+
" preprocessing in a non-distributed setup with"
|
217 |
+
" `preprocessing_only=True` so that the cached datasets can"
|
218 |
+
" consequently be loaded in distributed training"
|
219 |
+
)
|
220 |
+
},
|
221 |
+
)
|
222 |
+
train_split_name: str = field(
|
223 |
+
default="train",
|
224 |
+
metadata={
|
225 |
+
"help": ("The name of the training data set split to use (via the datasets library). Defaults to 'train'")
|
226 |
+
},
|
227 |
+
)
|
228 |
+
eval_split_name: str = field(
|
229 |
+
default="validation",
|
230 |
+
metadata={
|
231 |
+
"help": (
|
232 |
+
"The name of the evaluation data set split to use (via the datasets"
|
233 |
+
" library). Defaults to 'validation'"
|
234 |
+
)
|
235 |
+
},
|
236 |
+
)
|
237 |
+
wandb_project: str = field(
|
238 |
+
default="distil-whisper",
|
239 |
+
metadata={"help": "The name of the wandb project."},
|
240 |
+
)
|
241 |
+
wandb_name: str = field(
|
242 |
+
default=None,
|
243 |
+
metadata={"help": "The name of the wandb run."},
|
244 |
+
)
|
245 |
+
wandb_job_type: str = field(
|
246 |
+
default="distil-whisper",
|
247 |
+
metadata={"help": "The name of the wandb job type."},
|
248 |
+
)
|
249 |
+
wandb_dir: str = field(
|
250 |
+
default=None,
|
251 |
+
metadata={"help": "The absolute path to save the wandb logs."},
|
252 |
+
)
|
253 |
+
save_code_to_wandb: bool = field(
|
254 |
+
default=False,
|
255 |
+
metadata={
|
256 |
+
"help": (
|
257 |
+
"Whether to save main script to wandb. This is valuable for improving"
|
258 |
+
" experimentreproducibility and to diff code across experiments in"
|
259 |
+
" the UI."
|
260 |
+
)
|
261 |
+
},
|
262 |
+
)
|
263 |
+
|
264 |
+
|
265 |
+
@dataclass
|
266 |
+
class FlaxSeq2SeqTrainingArguments(Seq2SeqTrainingArguments):
|
267 |
+
use_scan: Optional[bool] = field(
|
268 |
+
default=True,
|
269 |
+
metadata={
|
270 |
+
"help": (
|
271 |
+
"Whether or not to use `scan_with_axes` over the encoder and decoder"
|
272 |
+
" blocks. Using scan results in faster compile times and more efficient"
|
273 |
+
" memory use during training, since all of the layers in the"
|
274 |
+
" encoder/decoder are stacked, and we perform a lax.scan over the"
|
275 |
+
" stacked block to index each layer. However, it results in slower"
|
276 |
+
" inference time due to the overhead of stacking the layers this way."
|
277 |
+
" Thus, we always default to disabling scan for the inference step."
|
278 |
+
)
|
279 |
+
},
|
280 |
+
)
|
281 |
+
freeze_encoder: Optional[bool] = field(
|
282 |
+
default=False,
|
283 |
+
metadata={
|
284 |
+
"help": (
|
285 |
+
"Whether to freeze the entire encoder model. Only recommended when the"
|
286 |
+
" entire encoder has been copiedfrom the teacher model."
|
287 |
+
)
|
288 |
+
},
|
289 |
+
)
|
290 |
+
|
291 |
+
|
292 |
+
def shift_tokens_right(label_ids: np.array, decoder_start_token_id: int) -> np.ndarray:
|
293 |
+
"""
|
294 |
+
Shift label ids one token to the right.
|
295 |
+
"""
|
296 |
+
shifted_label_ids = np.zeros_like(label_ids)
|
297 |
+
shifted_label_ids[:, 1:] = label_ids[:, :-1]
|
298 |
+
shifted_label_ids[:, 0] = decoder_start_token_id
|
299 |
+
|
300 |
+
return shifted_label_ids
|
301 |
+
|
302 |
+
|
303 |
+
@flax.struct.dataclass
|
304 |
+
class FlaxDataCollatorSpeechSeq2SeqWithPadding:
|
305 |
+
"""
|
306 |
+
Data collator that will dynamically pad the inputs received.
|
307 |
+
Args:
|
308 |
+
processor ([`Wav2Vec2Processor`])
|
309 |
+
The processor used for proccessing the data.
|
310 |
+
decoder_start_token_id (:obj: `int`)
|
311 |
+
The begin-of-sentence of the decoder.
|
312 |
+
input_padding (:obj:`bool`, :obj:`str` or :class:`~transformers.tokenization_utils_base.PaddingStrategy`, `optional`, defaults to :obj:`True`):
|
313 |
+
Select a strategy to pad the returned input sequences (according to the model's padding side and padding index)
|
314 |
+
among:
|
315 |
+
* :obj:`True` or :obj:`'longest'`: Pad to the longest sequence in the batch (or no padding if only a single
|
316 |
+
sequence if provided).
|
317 |
+
* :obj:`'max_length'`: Pad to a maximum length specified with the argument :obj:`max_length` or to the
|
318 |
+
maximum acceptable input length for the model if that argument is not provided.
|
319 |
+
* :obj:`False` or :obj:`'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of
|
320 |
+
different lengths).
|
321 |
+
target_padding (:obj:`bool`, :obj:`str` or :class:`~transformers.tokenization_utils_base.PaddingStrategy`, `optional`, defaults to :obj:`True`):
|
322 |
+
Select a strategy to pad the returned target sequences (according to the model's padding side and padding index).
|
323 |
+
See above for details.
|
324 |
+
max_target_length (:obj:`int`, `optional`):
|
325 |
+
Maximum length of the ``labels`` of the returned list and optionally padding length (see above).
|
326 |
+
"""
|
327 |
+
|
328 |
+
processor: Any
|
329 |
+
decoder_start_token_id: int
|
330 |
+
input_padding: Union[bool, str] = "max_length"
|
331 |
+
target_padding: Union[bool, str] = "max_length"
|
332 |
+
max_target_length: Optional[int] = None
|
333 |
+
|
334 |
+
def __call__(self, features: List[Dict[str, Union[List[int], np.ndarray]]]) -> Dict[str, np.ndarray]:
|
335 |
+
# split inputs and labels since they have to be of different lengths and need
|
336 |
+
# different padding methods
|
337 |
+
model_input_name = self.processor.model_input_names[0]
|
338 |
+
|
339 |
+
# dataloader returns a list of features which we convert to a dict
|
340 |
+
input_features = {model_input_name: [feature[model_input_name] for feature in features]}
|
341 |
+
label_features = {"input_ids": [feature["labels"] for feature in features]}
|
342 |
+
|
343 |
+
# reformat list to dict and set to pytorch format
|
344 |
+
batch = self.processor.feature_extractor.pad(
|
345 |
+
input_features,
|
346 |
+
padding=self.input_padding,
|
347 |
+
return_tensors="np",
|
348 |
+
)
|
349 |
+
|
350 |
+
labels_batch = self.processor.tokenizer.pad(
|
351 |
+
label_features,
|
352 |
+
max_length=self.max_target_length,
|
353 |
+
padding=self.target_padding,
|
354 |
+
return_tensors="np",
|
355 |
+
)
|
356 |
+
|
357 |
+
# if bos token is appended in previous tokenization step,
|
358 |
+
# cut bos token here as it's append later anyways
|
359 |
+
labels = labels_batch["input_ids"]
|
360 |
+
if (labels[:, 0] == self.decoder_start_token_id).all().item():
|
361 |
+
labels = labels[:, 1:]
|
362 |
+
labels_batch.attention_mask = labels_batch.attention_mask[:, 1:]
|
363 |
+
|
364 |
+
decoder_input_ids = shift_tokens_right(labels, self.decoder_start_token_id)
|
365 |
+
|
366 |
+
# replace padding with -100 to ignore correctly when computing the loss
|
367 |
+
labels = np.ma.array(labels, mask=np.not_equal(labels_batch.attention_mask, 1))
|
368 |
+
labels = labels.filled(fill_value=-100)
|
369 |
+
|
370 |
+
batch["labels"] = labels
|
371 |
+
batch["decoder_input_ids"] = decoder_input_ids
|
372 |
+
|
373 |
+
return batch
|
374 |
+
|
375 |
+
|
376 |
+
def get_data_loader(
|
377 |
+
rng: jax.random.PRNGKey,
|
378 |
+
dataset: Dataset,
|
379 |
+
batch_size: int,
|
380 |
+
data_collator: FlaxDataCollatorSpeechSeq2SeqWithPadding,
|
381 |
+
shuffle: bool = True,
|
382 |
+
drop_last: bool = True,
|
383 |
+
dataloader_num_workers: int = 0,
|
384 |
+
pin_memory: bool = True,
|
385 |
+
) -> DataLoader:
|
386 |
+
"""
|
387 |
+
Returns batches of size `batch_size` from `dataset`. If `drop_last` is set to `False`, the final batch may be incomplete,
|
388 |
+
and range in size from 1 to `batch_size`. Shuffle batches if `shuffle` is `True`.
|
389 |
+
|
390 |
+
Args:
|
391 |
+
rng (List(int)): JAX rng for generating pseudo random numbers. Used if shuffling the dataset.
|
392 |
+
dataset (Dataset): dataset from which to load the data.
|
393 |
+
batch_size (int): how many samples per batch to load.
|
394 |
+
data_collator (FlaxDataCollatorSpeechSeq2SeqWithPadding, optional): merges a list of samples to form a
|
395 |
+
mini-batch of Tensor(s). Used when using batched loading from a map-style dataset.
|
396 |
+
shuffle (bool, optional): set to `True` to have the batches reshuffled.
|
397 |
+
drop_last (bool, optional): set to ``True`` to drop the last incomplete batch,
|
398 |
+
if the dataset size is not divisible by the batch size. If ``False`` and
|
399 |
+
the size of dataset is not divisible by the batch size, then the last batch
|
400 |
+
will be smaller. (default: ``False``)
|
401 |
+
dataloader_num_workers (int, optional): how many subprocesses to use for data
|
402 |
+
loading. ``0`` means that the data will be loaded in the main process.
|
403 |
+
(default: ``0``)
|
404 |
+
pin_memory (bool, optional): If ``True``, the data loader will copy Tensors
|
405 |
+
into device/CUDA pinned memory before returning them. If your data elements
|
406 |
+
are a custom type, or your :attr:`collate_fn` returns a batch that is a custom type,
|
407 |
+
see the example below.
|
408 |
+
|
409 |
+
"""
|
410 |
+
if shuffle:
|
411 |
+
batch_idx = jax.random.permutation(rng, len(dataset))
|
412 |
+
batch_idx = np.asarray(batch_idx)
|
413 |
+
dataset = dataset.select(batch_idx)
|
414 |
+
|
415 |
+
data_loader = DataLoader(
|
416 |
+
dataset,
|
417 |
+
batch_size=batch_size,
|
418 |
+
drop_last=drop_last,
|
419 |
+
pin_memory=pin_memory,
|
420 |
+
collate_fn=data_collator,
|
421 |
+
num_workers=dataloader_num_workers,
|
422 |
+
)
|
423 |
+
|
424 |
+
return data_loader
|
425 |
+
|
426 |
+
|
427 |
+
class TrainState(train_state.TrainState):
|
428 |
+
dropout_rng: jnp.ndarray
|
429 |
+
|
430 |
+
def replicate(self):
|
431 |
+
return jax_utils.replicate(self).replace(dropout_rng=shard_prng_key(self.dropout_rng))
|
432 |
+
|
433 |
+
|
434 |
+
def write_metric(summary_writer, train_metrics, eval_metrics, train_time, step, logging_steps):
|
435 |
+
summary_writer.scalar("train/time", train_time, step)
|
436 |
+
|
437 |
+
train_metrics = get_metrics(train_metrics)
|
438 |
+
for key, vals in train_metrics.items():
|
439 |
+
steps_arr = np.arange(0, step, logging_steps)[-len(vals) :]
|
440 |
+
tag = f"train/{key}"
|
441 |
+
for i, val in enumerate(vals):
|
442 |
+
summary_writer.scalar(tag, val, steps_arr[i])
|
443 |
+
|
444 |
+
for metric_name, value in eval_metrics.items():
|
445 |
+
summary_writer.scalar(f"eval/{metric_name}", value, step)
|
446 |
+
|
447 |
+
|
448 |
+
def write_wandb_metric(wandb_logger, metrics, train_time, step, prefix):
|
449 |
+
log_metrics = {}
|
450 |
+
for k, v in metrics.items():
|
451 |
+
log_metrics[f"{prefix}/{k}"] = v
|
452 |
+
log_metrics[f"{prefix}/time"] = train_time
|
453 |
+
wandb_logger.log(log_metrics, step)
|
454 |
+
|
455 |
+
|
456 |
+
def write_wandb_pred(wandb_logger, pred_str, label_str, prefix="eval", num_lines=100):
|
457 |
+
# convert str data to a wandb compatible format
|
458 |
+
if num_lines < len(pred_str):
|
459 |
+
str_data = [[label_str[i], pred_str[i]] for i in range(num_lines)]
|
460 |
+
else:
|
461 |
+
str_data = [[label_str[i], pred_str[i]] for i in range(len(pred_str))]
|
462 |
+
# log as a table with the appropriate headers
|
463 |
+
wandb_logger.log(
|
464 |
+
{f"{prefix}/predictions": wandb_logger.Table(columns=["label_str", "pred_str"], data=str_data)},
|
465 |
+
)
|
466 |
+
|
467 |
+
|
468 |
+
def create_learning_rate_fn(
|
469 |
+
num_train_steps: int, num_warmup_steps: int, learning_rate: float
|
470 |
+
) -> Callable[[int], jnp.array]:
|
471 |
+
"""Returns a linear warmup, linear_decay learning rate function."""
|
472 |
+
warmup_fn = optax.linear_schedule(init_value=0.0, end_value=learning_rate, transition_steps=num_warmup_steps)
|
473 |
+
decay_fn = optax.linear_schedule(
|
474 |
+
init_value=learning_rate,
|
475 |
+
end_value=0,
|
476 |
+
transition_steps=num_train_steps - num_warmup_steps,
|
477 |
+
)
|
478 |
+
schedule_fn = optax.join_schedules(schedules=[warmup_fn, decay_fn], boundaries=[num_warmup_steps])
|
479 |
+
return schedule_fn
|
480 |
+
|
481 |
+
|
482 |
+
def main():
|
483 |
+
# 1. Parse input arguments
|
484 |
+
# See all possible arguments in src/transformers/training_args.py
|
485 |
+
# or by passing the --help flag to this script.
|
486 |
+
# We now keep distinct sets of args, for a cleaner separation of concerns.
|
487 |
+
parser = HfArgumentParser((ModelArguments, DataTrainingArguments, FlaxSeq2SeqTrainingArguments))
|
488 |
+
|
489 |
+
if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
|
490 |
+
# If we pass only one argument to the script and it's the path to a json file,
|
491 |
+
# let's parse it to get our arguments.
|
492 |
+
model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
|
493 |
+
else:
|
494 |
+
model_args, data_args, training_args = parser.parse_args_into_dataclasses()
|
495 |
+
|
496 |
+
# Sending telemetry. Tracking the example usage helps us better allocate resources to maintain them. The
|
497 |
+
# information sent is the one passed as arguments along with your JAX/Flax versions.
|
498 |
+
send_example_telemetry("run_flax_speech_recognition_seq2seq", model_args, data_args, framework="flax")
|
499 |
+
|
500 |
+
# 2. Setup logging
|
501 |
+
# Make one log on every process with the configuration for debugging.
|
502 |
+
logging.basicConfig(
|
503 |
+
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
|
504 |
+
datefmt="%m/%d/%Y %H:%M:%S",
|
505 |
+
handlers=[logging.StreamHandler(sys.stdout)],
|
506 |
+
)
|
507 |
+
# Set the verbosity to info of the Transformers logger.
|
508 |
+
# We only want one process per machine to log things on the screen.
|
509 |
+
logger.setLevel(logging.INFO if jax.process_index() == 0 else logging.ERROR)
|
510 |
+
if jax.process_index() == 0:
|
511 |
+
datasets.utils.logging.set_verbosity_warning()
|
512 |
+
transformers.utils.logging.set_verbosity_info()
|
513 |
+
else:
|
514 |
+
datasets.utils.logging.set_verbosity_error()
|
515 |
+
transformers.utils.logging.set_verbosity_error()
|
516 |
+
|
517 |
+
logger.info("Training/evaluation parameters %s", training_args)
|
518 |
+
|
519 |
+
# Check the output dir is valid
|
520 |
+
if (
|
521 |
+
os.path.exists(training_args.output_dir)
|
522 |
+
and os.listdir(training_args.output_dir)
|
523 |
+
and training_args.do_train
|
524 |
+
and not training_args.overwrite_output_dir
|
525 |
+
):
|
526 |
+
raise ValueError(
|
527 |
+
f"Output directory ({training_args.output_dir}) already exists and is not"
|
528 |
+
" empty.Use `--overwrite_output_dir` to overcome."
|
529 |
+
)
|
530 |
+
|
531 |
+
# Handle the repository creation
|
532 |
+
if training_args.push_to_hub:
|
533 |
+
if training_args.hub_model_id is None:
|
534 |
+
repo_name = get_full_repo_name(
|
535 |
+
Path(training_args.output_dir).absolute().name,
|
536 |
+
token=training_args.hub_token,
|
537 |
+
)
|
538 |
+
else:
|
539 |
+
repo_name = training_args.hub_model_id
|
540 |
+
create_repo(repo_name, exist_ok=True, token=training_args.hub_token)
|
541 |
+
repo = Repository(
|
542 |
+
training_args.output_dir,
|
543 |
+
clone_from=repo_name,
|
544 |
+
token=training_args.hub_token,
|
545 |
+
)
|
546 |
+
|
547 |
+
# 3. Load dataset
|
548 |
+
raw_datasets = DatasetDict()
|
549 |
+
|
550 |
+
if training_args.do_train:
|
551 |
+
raw_datasets["train"] = load_dataset(
|
552 |
+
data_args.dataset_name,
|
553 |
+
data_args.dataset_config_name,
|
554 |
+
split=data_args.train_split_name,
|
555 |
+
cache_dir=data_args.dataset_cache_dir,
|
556 |
+
use_auth_token=True if model_args.use_auth_token else None,
|
557 |
+
num_proc=data_args.preprocessing_num_workers,
|
558 |
+
)
|
559 |
+
|
560 |
+
if training_args.do_eval:
|
561 |
+
raw_datasets["eval"] = load_dataset(
|
562 |
+
data_args.dataset_name,
|
563 |
+
data_args.dataset_config_name,
|
564 |
+
split=data_args.eval_split_name,
|
565 |
+
cache_dir=data_args.dataset_cache_dir,
|
566 |
+
use_auth_token=True if model_args.use_auth_token else None,
|
567 |
+
num_proc=data_args.preprocessing_num_workers,
|
568 |
+
)
|
569 |
+
|
570 |
+
if not training_args.do_train and not training_args.do_eval:
|
571 |
+
raise ValueError(
|
572 |
+
"Cannot not train and not do evaluation. At least one of training or evaluation has to be performed."
|
573 |
+
)
|
574 |
+
|
575 |
+
if data_args.audio_column_name not in next(iter(raw_datasets.values())).column_names:
|
576 |
+
raise ValueError(
|
577 |
+
f"--audio_column_name '{data_args.audio_column_name}' not found in dataset"
|
578 |
+
f" '{data_args.dataset_name}'. Make sure to set `--audio_column_name` to"
|
579 |
+
" the correct audio column - one of"
|
580 |
+
f" {', '.join(next(iter(raw_datasets.values())).column_names)}."
|
581 |
+
)
|
582 |
+
|
583 |
+
if data_args.text_column_name not in next(iter(raw_datasets.values())).column_names:
|
584 |
+
raise ValueError(
|
585 |
+
f"--text_column_name {data_args.text_column_name} not found in dataset"
|
586 |
+
f" '{data_args.dataset_name}'. Make sure to set `--text_column_name` to the"
|
587 |
+
" correct text column - one of"
|
588 |
+
f" {', '.join(next(iter(raw_datasets.values())).column_names)}."
|
589 |
+
)
|
590 |
+
|
591 |
+
# 5. Load pretrained model, tokenizer, and feature extractor
|
592 |
+
config = AutoConfig.from_pretrained(
|
593 |
+
(model_args.config_name if model_args.config_name else model_args.model_name_or_path),
|
594 |
+
cache_dir=model_args.cache_dir,
|
595 |
+
revision=model_args.model_revision,
|
596 |
+
use_auth_token=True if model_args.use_auth_token else None,
|
597 |
+
)
|
598 |
+
feature_extractor = AutoFeatureExtractor.from_pretrained(
|
599 |
+
(model_args.feature_extractor_name if model_args.feature_extractor_name else model_args.model_name_or_path),
|
600 |
+
cache_dir=model_args.cache_dir,
|
601 |
+
revision=model_args.model_revision,
|
602 |
+
use_auth_token=True if model_args.use_auth_token else None,
|
603 |
+
)
|
604 |
+
tokenizer = AutoTokenizer.from_pretrained(
|
605 |
+
(model_args.tokenizer_name if model_args.tokenizer_name else model_args.model_name_or_path),
|
606 |
+
cache_dir=model_args.cache_dir,
|
607 |
+
use_fast=model_args.use_fast_tokenizer,
|
608 |
+
revision=model_args.model_revision,
|
609 |
+
use_auth_token=True if model_args.use_auth_token else None,
|
610 |
+
)
|
611 |
+
|
612 |
+
model, params = FlaxWhisperForConditionalGeneration.from_pretrained(
|
613 |
+
model_args.model_name_or_path,
|
614 |
+
config=config,
|
615 |
+
dtype=getattr(jnp, model_args.dtype),
|
616 |
+
cache_dir=model_args.cache_dir,
|
617 |
+
revision=model_args.model_revision,
|
618 |
+
use_auth_token=True if model_args.use_auth_token else None,
|
619 |
+
_do_init=False,
|
620 |
+
)
|
621 |
+
|
622 |
+
if model.config.decoder_start_token_id is None:
|
623 |
+
raise ValueError("Make sure that `config.decoder_start_token_id` is correctly defined")
|
624 |
+
|
625 |
+
# enable scan / gradient checkpointing if necessary
|
626 |
+
if training_args.use_scan:
|
627 |
+
model.enable_scan() # to enable scan in the nn.Module
|
628 |
+
params = model.convert_unroll_to_scan(params) # to convert the unrolled params to scan
|
629 |
+
|
630 |
+
if training_args.gradient_checkpointing:
|
631 |
+
model.enable_gradient_checkpointing() # to enable checkpointing in the nn.Module, there is no change to the params structure
|
632 |
+
|
633 |
+
if hasattr(model.generation_config, "is_multilingual") and model.generation_config.is_multilingual:
|
634 |
+
# We need to set the language and task ids for previously multilingual checkpoints
|
635 |
+
tokenizer.set_prefix_tokens(language="English", task="transcribe", predict_timestamps=False)
|
636 |
+
model.generation_config.forced_decoder_ids = tokenizer.get_decoder_prompt_ids(
|
637 |
+
language="English", task="transcribe", no_timestamps=True
|
638 |
+
)
|
639 |
+
|
640 |
+
# 6. Resample speech dataset: `datasets` takes care of automatically loading and resampling the audio,
|
641 |
+
# so we just need to set the correct target sampling rate.
|
642 |
+
raw_datasets = raw_datasets.cast_column(
|
643 |
+
data_args.audio_column_name,
|
644 |
+
datasets.features.Audio(sampling_rate=feature_extractor.sampling_rate),
|
645 |
+
)
|
646 |
+
|
647 |
+
# 7. Preprocessing the datasets.
|
648 |
+
# We need to read the audio files as arrays and tokenize the targets.
|
649 |
+
max_input_length = int(data_args.max_duration_in_seconds * feature_extractor.sampling_rate)
|
650 |
+
min_input_length = int(data_args.min_duration_in_seconds * feature_extractor.sampling_rate)
|
651 |
+
max_label_length = (
|
652 |
+
data_args.max_label_length if data_args.max_label_length is not None else model.config.max_length
|
653 |
+
)
|
654 |
+
audio_column_name = data_args.audio_column_name
|
655 |
+
num_workers = data_args.preprocessing_num_workers
|
656 |
+
dataloader_num_workers = training_args.dataloader_num_workers
|
657 |
+
text_column_name = data_args.text_column_name
|
658 |
+
model_input_name = feature_extractor.model_input_names[0]
|
659 |
+
normalizer = EnglishTextNormalizer(tokenizer.english_spelling_normalizer)
|
660 |
+
|
661 |
+
if training_args.do_train and data_args.max_train_samples is not None:
|
662 |
+
raw_datasets["train"] = raw_datasets["train"].select(range(data_args.max_train_samples))
|
663 |
+
|
664 |
+
if training_args.do_eval and data_args.max_eval_samples is not None:
|
665 |
+
raw_datasets["eval"] = raw_datasets["eval"].select(range(data_args.max_eval_samples))
|
666 |
+
|
667 |
+
def prepare_dataset(batch):
|
668 |
+
# process audio
|
669 |
+
sample = batch[audio_column_name]
|
670 |
+
inputs = feature_extractor(sample["array"], sampling_rate=sample["sampling_rate"])
|
671 |
+
# process audio length
|
672 |
+
batch[model_input_name] = inputs.get(model_input_name)[0]
|
673 |
+
batch["input_length"] = len(sample["array"])
|
674 |
+
|
675 |
+
# process targets
|
676 |
+
input_str = " " + batch[text_column_name].lower()
|
677 |
+
batch["labels"] = tokenizer(input_str).input_ids
|
678 |
+
return batch
|
679 |
+
|
680 |
+
vectorized_datasets = raw_datasets.map(
|
681 |
+
prepare_dataset,
|
682 |
+
remove_columns=next(iter(raw_datasets.values())).column_names,
|
683 |
+
num_proc=num_workers,
|
684 |
+
desc="preprocess train dataset",
|
685 |
+
)
|
686 |
+
|
687 |
+
# filter training data with inputs longer than max_input_length
|
688 |
+
def is_audio_in_length_range(length):
|
689 |
+
return min_input_length < length < max_input_length
|
690 |
+
|
691 |
+
vectorized_datasets = vectorized_datasets.filter(
|
692 |
+
is_audio_in_length_range,
|
693 |
+
num_proc=num_workers,
|
694 |
+
input_columns=["input_length"],
|
695 |
+
)
|
696 |
+
|
697 |
+
# filter training data with labels longer than max_label_length
|
698 |
+
def is_labels_in_length_range(labels):
|
699 |
+
return 0 < len(labels) < max_label_length
|
700 |
+
|
701 |
+
vectorized_datasets = vectorized_datasets.filter(
|
702 |
+
is_labels_in_length_range,
|
703 |
+
num_proc=num_workers,
|
704 |
+
input_columns=["labels"],
|
705 |
+
)
|
706 |
+
|
707 |
+
# for large datasets it is advised to run the preprocessing on a
|
708 |
+
# single machine first with `args.preprocessing_only` since there will mostly likely
|
709 |
+
# be a timeout when running the script in distributed mode.
|
710 |
+
# In a second step `args.preprocessing_only` can then be set to `False` to load the
|
711 |
+
# cached dataset
|
712 |
+
if data_args.preprocessing_only:
|
713 |
+
cache = {k: v.cache_files for k, v in vectorized_datasets.items()}
|
714 |
+
logger.info(f"Data preprocessing finished. Files cached at {cache}.")
|
715 |
+
return
|
716 |
+
|
717 |
+
# 8. Load Metric
|
718 |
+
metric = evaluate.load("wer")
|
719 |
+
all_punctuation = list(string.punctuation.replace("'", ""))
|
720 |
+
|
721 |
+
def compute_metrics(preds, labels):
|
722 |
+
# replace padded labels by the padding token
|
723 |
+
for idx in range(len(labels)):
|
724 |
+
labels[idx][labels[idx] == -100] = tokenizer.pad_token_id
|
725 |
+
|
726 |
+
pred_str = tokenizer.batch_decode(preds, skip_special_tokens=True)
|
727 |
+
# we do not want to group tokens when computing the metrics
|
728 |
+
label_str = tokenizer.batch_decode(labels, skip_special_tokens=True)
|
729 |
+
|
730 |
+
# space punctuation for orthographic WER (c.f. ESB paper https://arxiv.org/abs/2210.13352)
|
731 |
+
spaced_pred_str = [
|
732 |
+
pred_str[i].replace(punctuation, "") for punctuation in all_punctuation for i in range(len(pred_str))
|
733 |
+
]
|
734 |
+
spaced_label_str = [
|
735 |
+
label_str[i].replace(punctuation, "") for punctuation in all_punctuation for i in range(len(label_str))
|
736 |
+
]
|
737 |
+
wer_ortho = 100 * metric.compute(predictions=spaced_pred_str, references=spaced_label_str)
|
738 |
+
|
739 |
+
# normalize everything and re-compute the WER
|
740 |
+
norm_pred_str = [normalizer(pred) for pred in pred_str]
|
741 |
+
norm_label_str = [normalizer(label) for label in label_str]
|
742 |
+
# filtering step to only evaluate the samples that correspond to non-zero normalized references:
|
743 |
+
norm_pred_str = [norm_pred_str[i] for i in range(len(norm_pred_str)) if len(norm_label_str[i]) > 0]
|
744 |
+
norm_label_str = [norm_label_str[i] for i in range(len(norm_label_str)) if len(norm_label_str[i]) > 0]
|
745 |
+
|
746 |
+
wer = 100 * metric.compute(predictions=norm_pred_str, references=norm_label_str)
|
747 |
+
|
748 |
+
return {"wer": wer, "wer_ortho": wer_ortho}, pred_str, label_str
|
749 |
+
|
750 |
+
# 9. Save feature extractor, tokenizer, config and generation config
|
751 |
+
feature_extractor.save_pretrained(training_args.output_dir)
|
752 |
+
tokenizer.save_pretrained(training_args.output_dir)
|
753 |
+
config.save_pretrained(training_args.output_dir)
|
754 |
+
model.generation_config.save_pretrained(
|
755 |
+
training_args.output_dir
|
756 |
+
) # generation config stays bound to model to make it easy to jit
|
757 |
+
|
758 |
+
processor = AutoProcessor.from_pretrained(training_args.output_dir)
|
759 |
+
|
760 |
+
data_collator = FlaxDataCollatorSpeechSeq2SeqWithPadding(
|
761 |
+
processor=processor,
|
762 |
+
decoder_start_token_id=model.config.decoder_start_token_id,
|
763 |
+
input_padding="longest",
|
764 |
+
target_padding="max_length",
|
765 |
+
max_target_length=max_label_length,
|
766 |
+
)
|
767 |
+
|
768 |
+
# Enable tensorboard only on the master node
|
769 |
+
has_tensorboard = is_tensorboard_available()
|
770 |
+
if has_tensorboard and jax.process_index() == 0:
|
771 |
+
try:
|
772 |
+
from flax.metrics.tensorboard import SummaryWriter
|
773 |
+
|
774 |
+
summary_writer = SummaryWriter(log_dir=Path(training_args.output_dir))
|
775 |
+
except ImportError as ie:
|
776 |
+
has_tensorboard = False
|
777 |
+
logger.warning(
|
778 |
+
"Unable to display metrics through TensorBoard because some package" f" are not installed: {ie}"
|
779 |
+
)
|
780 |
+
else:
|
781 |
+
logger.warning(
|
782 |
+
"Unable to display metrics through TensorBoard because the package is not"
|
783 |
+
" installed: Please run `pip install tensorboard` to enable."
|
784 |
+
)
|
785 |
+
|
786 |
+
# Enable wandb only on the master node
|
787 |
+
has_wandb = is_wandb_available()
|
788 |
+
if has_wandb:
|
789 |
+
import wandb as wandb_logger
|
790 |
+
|
791 |
+
# Set up wandb run
|
792 |
+
if jax.process_index() == 0:
|
793 |
+
wandb_logger.init(
|
794 |
+
project=data_args.wandb_project,
|
795 |
+
name=data_args.wandb_name,
|
796 |
+
job_type=data_args.wandb_job_type,
|
797 |
+
dir=data_args.wandb_dir,
|
798 |
+
save_code=data_args.save_code_to_wandb,
|
799 |
+
)
|
800 |
+
else:
|
801 |
+
logger.warning("Wandb logging requires wandb to be installed. Run `pip install wandb` to enable.")
|
802 |
+
|
803 |
+
# Initialize our training
|
804 |
+
rng = jax.random.PRNGKey(training_args.seed)
|
805 |
+
rng, dropout_rng = jax.random.split(rng)
|
806 |
+
|
807 |
+
# Store some constant
|
808 |
+
num_epochs = int(training_args.num_train_epochs)
|
809 |
+
train_batch_size = int(training_args.per_device_train_batch_size) * jax.device_count()
|
810 |
+
per_device_eval_batch_size = int(training_args.per_device_eval_batch_size)
|
811 |
+
eval_batch_size = per_device_eval_batch_size * jax.device_count()
|
812 |
+
steps_per_epoch = len(vectorized_datasets["train"]) // train_batch_size
|
813 |
+
total_train_steps = steps_per_epoch * num_epochs
|
814 |
+
|
815 |
+
# Create learning rate schedule
|
816 |
+
linear_decay_lr_schedule_fn = create_learning_rate_fn(
|
817 |
+
total_train_steps,
|
818 |
+
training_args.warmup_steps,
|
819 |
+
training_args.learning_rate,
|
820 |
+
)
|
821 |
+
|
822 |
+
# We use Optax's "masking" functionality to not apply weight decay
|
823 |
+
# to bias and LayerNorm scale parameters. decay_mask_fn returns a
|
824 |
+
# mask boolean with the same structure as the parameters.
|
825 |
+
# The mask is True for parameters that should be decayed.
|
826 |
+
def decay_mask_fn(params):
|
827 |
+
flat_params = traverse_util.flatten_dict(params)
|
828 |
+
# find out all LayerNorm parameters
|
829 |
+
layer_norm_candidates = [
|
830 |
+
"layer_norm",
|
831 |
+
"self_attn_layer_norm",
|
832 |
+
"final_layer_norm",
|
833 |
+
"encoder_attn_layer_norm",
|
834 |
+
]
|
835 |
+
layer_norm_named_params = {
|
836 |
+
layer[-2:]
|
837 |
+
for layer_norm_name in layer_norm_candidates
|
838 |
+
for layer in flat_params.keys()
|
839 |
+
if layer_norm_name in "".join(layer).lower()
|
840 |
+
}
|
841 |
+
flat_mask = {path: path[-1] != "bias" and path[-2:] not in layer_norm_named_params for path in flat_params}
|
842 |
+
return traverse_util.unflatten_dict(flat_mask)
|
843 |
+
|
844 |
+
# create adam optimizer
|
845 |
+
if "adam" in training_args.optim:
|
846 |
+
optim = optax.adamw(
|
847 |
+
learning_rate=linear_decay_lr_schedule_fn,
|
848 |
+
b1=training_args.adam_beta1,
|
849 |
+
b2=training_args.adam_beta2,
|
850 |
+
eps=training_args.adam_epsilon,
|
851 |
+
weight_decay=training_args.weight_decay,
|
852 |
+
mask=decay_mask_fn,
|
853 |
+
)
|
854 |
+
elif training_args.optim == "adafactor":
|
855 |
+
optim = optax.adafactor(
|
856 |
+
learning_rate=linear_decay_lr_schedule_fn,
|
857 |
+
dtype_momentum=getattr(jnp, model_args.dtype),
|
858 |
+
eps=training_args.adam_epsilon,
|
859 |
+
weight_decay_rate=training_args.weight_decay,
|
860 |
+
weight_decay_mask=decay_mask_fn,
|
861 |
+
)
|
862 |
+
else:
|
863 |
+
raise ValueError(f"Got unknown optmiser {training_args.optim}. Should be one of `adamw` or `adafactor`")
|
864 |
+
|
865 |
+
# Setup train state
|
866 |
+
state = TrainState.create(apply_fn=model.__call__, params=params, tx=optim, dropout_rng=dropout_rng)
|
867 |
+
|
868 |
+
# label smoothed cross entropy
|
869 |
+
def loss_fn(logits, labels, label_smoothing_factor=0.0):
|
870 |
+
"""
|
871 |
+
The label smoothing implementation is adapted from Flax's official example:
|
872 |
+
https://github.com/google/flax/blob/87a211135c6a377c8f29048a1cac3840e38b9da4/examples/wmt/train.py#L104
|
873 |
+
"""
|
874 |
+
vocab_size = logits.shape[-1]
|
875 |
+
confidence = 1.0 - label_smoothing_factor
|
876 |
+
low_confidence = (1.0 - confidence) / (vocab_size - 1)
|
877 |
+
normalizing_constant = -(
|
878 |
+
confidence * jnp.log(confidence) + (vocab_size - 1) * low_confidence * jnp.log(low_confidence + 1e-20)
|
879 |
+
)
|
880 |
+
soft_labels = onehot(labels, vocab_size, on_value=confidence, off_value=low_confidence)
|
881 |
+
|
882 |
+
loss = optax.softmax_cross_entropy(logits, soft_labels)
|
883 |
+
loss = loss - normalizing_constant
|
884 |
+
|
885 |
+
# ignore padded tokens from loss, i.e. where labels are not set to -100
|
886 |
+
padding_mask = labels >= 0
|
887 |
+
loss = loss * padding_mask
|
888 |
+
loss = loss.sum()
|
889 |
+
num_labels = padding_mask.sum()
|
890 |
+
return loss, num_labels
|
891 |
+
|
892 |
+
# Define gradient update step fn
|
893 |
+
def train_step(state, batch, freeze_encoder, label_smoothing_factor=0.0):
|
894 |
+
dropout_rng, new_dropout_rng = jax.random.split(state.dropout_rng)
|
895 |
+
|
896 |
+
def compute_loss(params):
|
897 |
+
labels = batch.pop("labels")
|
898 |
+
logits = state.apply_fn(
|
899 |
+
**batch,
|
900 |
+
params=params,
|
901 |
+
dropout_rng=dropout_rng,
|
902 |
+
freeze_encoder=freeze_encoder,
|
903 |
+
train=True,
|
904 |
+
)[0]
|
905 |
+
loss, num_labels = loss_fn(logits, labels, label_smoothing_factor)
|
906 |
+
return loss, num_labels
|
907 |
+
|
908 |
+
grad_fn = jax.value_and_grad(compute_loss, has_aux=True)
|
909 |
+
(loss, num_labels), grad = grad_fn(state.params)
|
910 |
+
num_labels = jax.lax.psum(num_labels, "batch")
|
911 |
+
|
912 |
+
# true loss = total loss / total samples
|
913 |
+
loss = jax.lax.psum(loss, "batch")
|
914 |
+
loss = jax.tree_util.tree_map(lambda x: x / num_labels, loss)
|
915 |
+
|
916 |
+
# true grad = total grad / total samples
|
917 |
+
grad = jax.lax.psum(grad, "batch")
|
918 |
+
grad = jax.tree_util.tree_map(lambda x: x / num_labels, grad)
|
919 |
+
new_state = state.apply_gradients(grads=grad, dropout_rng=new_dropout_rng)
|
920 |
+
|
921 |
+
metrics = {
|
922 |
+
"loss": loss,
|
923 |
+
"learning_rate": linear_decay_lr_schedule_fn(state.step),
|
924 |
+
}
|
925 |
+
return new_state, metrics
|
926 |
+
|
927 |
+
# Define eval fn
|
928 |
+
def eval_step(params, batch, label_smoothing_factor=0.0):
|
929 |
+
labels = batch.pop("labels")
|
930 |
+
logits = model(**batch, params=params, train=False)[0]
|
931 |
+
|
932 |
+
loss, num_labels = loss_fn(logits, labels, label_smoothing_factor)
|
933 |
+
num_labels = jax.lax.psum(num_labels, "batch")
|
934 |
+
|
935 |
+
# true loss = total loss / total samples
|
936 |
+
loss = jax.lax.psum(loss, "batch")
|
937 |
+
loss = jax.tree_util.tree_map(lambda x: x / num_labels, loss)
|
938 |
+
|
939 |
+
metrics = {"loss": loss}
|
940 |
+
return metrics
|
941 |
+
|
942 |
+
# Define generation function
|
943 |
+
num_beams = (
|
944 |
+
training_args.generation_num_beams
|
945 |
+
if training_args.generation_num_beams is not None
|
946 |
+
else model.config.num_beams
|
947 |
+
)
|
948 |
+
gen_kwargs = {"max_length": max_label_length, "num_beams": num_beams}
|
949 |
+
|
950 |
+
def generate_step(params, batch):
|
951 |
+
output_ids = model.generate(
|
952 |
+
batch[model_input_name],
|
953 |
+
attention_mask=batch.get("attention_mask"),
|
954 |
+
params=params,
|
955 |
+
**gen_kwargs,
|
956 |
+
)
|
957 |
+
return output_ids.sequences
|
958 |
+
|
959 |
+
# Create parallel version of the train and eval step
|
960 |
+
p_train_step = jax.pmap(
|
961 |
+
partial(train_step, label_smoothing_factor=training_args.label_smoothing_factor),
|
962 |
+
"batch",
|
963 |
+
donate_argnums=(0,),
|
964 |
+
static_broadcasted_argnums=(2,),
|
965 |
+
)
|
966 |
+
p_eval_step = jax.pmap(
|
967 |
+
partial(eval_step, label_smoothing_factor=training_args.label_smoothing_factor),
|
968 |
+
"batch",
|
969 |
+
)
|
970 |
+
p_generate_step = jax.pmap(generate_step, "batch")
|
971 |
+
|
972 |
+
# Replicate the train state on each device
|
973 |
+
state = state.replicate()
|
974 |
+
|
975 |
+
logger.info("***** Running training *****")
|
976 |
+
logger.info(f" Num examples = {len(vectorized_datasets['train'])}")
|
977 |
+
logger.info(f" Num Epochs = {num_epochs}")
|
978 |
+
logger.info(" Instantaneous batch size per device =" f" {training_args.per_device_train_batch_size}")
|
979 |
+
logger.info(f" Total train batch size (w. parallel & distributed) = {train_batch_size}")
|
980 |
+
logger.info(f" Total optimization steps = {total_train_steps}")
|
981 |
+
|
982 |
+
train_time = 0
|
983 |
+
epochs = tqdm(range(num_epochs), desc=f"Epoch ... (1/{num_epochs})", position=0)
|
984 |
+
for epoch in epochs:
|
985 |
+
# ======================== Training ================================
|
986 |
+
train_start = time.time()
|
987 |
+
|
988 |
+
# Create sampling rng
|
989 |
+
rng, input_rng = jax.random.split(rng)
|
990 |
+
train_metrics = []
|
991 |
+
|
992 |
+
# Generate an epoch by shuffling sampling indices from the train dataset
|
993 |
+
train_loader = get_data_loader(
|
994 |
+
input_rng,
|
995 |
+
vectorized_datasets["train"],
|
996 |
+
batch_size=train_batch_size,
|
997 |
+
data_collator=data_collator,
|
998 |
+
dataloader_num_workers=dataloader_num_workers,
|
999 |
+
)
|
1000 |
+
# train
|
1001 |
+
for step, batch in enumerate(tqdm(train_loader, desc="Training...", position=1), 1):
|
1002 |
+
batch = shard(batch.data)
|
1003 |
+
state, train_metric = p_train_step(state, batch, training_args.freeze_encoder)
|
1004 |
+
|
1005 |
+
cur_step = epoch * steps_per_epoch + step
|
1006 |
+
if cur_step % training_args.logging_steps == 0:
|
1007 |
+
train_metrics.append(train_metric)
|
1008 |
+
train_metric_to_write = unreplicate(train_metric)
|
1009 |
+
epochs.write(
|
1010 |
+
f"Step... ({cur_step} / {total_train_steps} | Loss:"
|
1011 |
+
f" {train_metric_to_write['loss']}, Learning Rate:"
|
1012 |
+
f" {train_metric_to_write['learning_rate']})"
|
1013 |
+
)
|
1014 |
+
if has_wandb and jax.process_index() == 0:
|
1015 |
+
write_wandb_metric(
|
1016 |
+
wandb_logger,
|
1017 |
+
train_metric_to_write,
|
1018 |
+
train_time + time.time() - train_start,
|
1019 |
+
cur_step,
|
1020 |
+
"train",
|
1021 |
+
)
|
1022 |
+
|
1023 |
+
train_time += time.time() - train_start
|
1024 |
+
|
1025 |
+
train_metric = unreplicate(train_metric)
|
1026 |
+
|
1027 |
+
epochs.write(
|
1028 |
+
f"Epoch... ({epoch + 1}/{num_epochs} | Loss: {train_metric['loss']},"
|
1029 |
+
f" Learning Rate: {train_metric['learning_rate']})"
|
1030 |
+
)
|
1031 |
+
|
1032 |
+
# ======================== Evaluating ==============================
|
1033 |
+
eval_metrics = []
|
1034 |
+
eval_preds = []
|
1035 |
+
eval_labels = []
|
1036 |
+
eval_start = time.time()
|
1037 |
+
|
1038 |
+
eval_loader = get_data_loader(
|
1039 |
+
input_rng,
|
1040 |
+
vectorized_datasets["eval"],
|
1041 |
+
batch_size=eval_batch_size,
|
1042 |
+
data_collator=data_collator,
|
1043 |
+
shuffle=False,
|
1044 |
+
drop_last=False,
|
1045 |
+
dataloader_num_workers=dataloader_num_workers,
|
1046 |
+
)
|
1047 |
+
for batch in tqdm(eval_loader, desc="Evaluating...", position=2):
|
1048 |
+
# Model forward
|
1049 |
+
labels = batch["labels"]
|
1050 |
+
|
1051 |
+
metrics = pad_shard_unpad(p_eval_step, static_return=True)(
|
1052 |
+
state.params, batch.data, min_device_batch=per_device_eval_batch_size
|
1053 |
+
)
|
1054 |
+
eval_metrics.append(metrics)
|
1055 |
+
|
1056 |
+
# generation
|
1057 |
+
if training_args.predict_with_generate:
|
1058 |
+
generated_ids = pad_shard_unpad(p_generate_step)(
|
1059 |
+
state.params, batch.data, min_device_batch=per_device_eval_batch_size
|
1060 |
+
)
|
1061 |
+
eval_preds.extend(jax.device_get(generated_ids.reshape(-1, gen_kwargs["max_length"])))
|
1062 |
+
eval_labels.extend(labels)
|
1063 |
+
|
1064 |
+
eval_time = time.time() - eval_start
|
1065 |
+
|
1066 |
+
# normalize eval metrics
|
1067 |
+
eval_metrics = get_metrics(eval_metrics)
|
1068 |
+
eval_metrics = jax.tree_util.tree_map(jnp.mean, eval_metrics)
|
1069 |
+
|
1070 |
+
# compute WER metric
|
1071 |
+
wer_desc = ""
|
1072 |
+
if training_args.predict_with_generate:
|
1073 |
+
wer_metric, pred_str, label_str = compute_metrics(eval_preds, eval_labels)
|
1074 |
+
eval_metrics.update(wer_metric)
|
1075 |
+
wer_desc = " ".join([f"Eval {key}: {value} |" for key, value in wer_metric.items()])
|
1076 |
+
|
1077 |
+
# Print metrics and update progress bar
|
1078 |
+
desc = f"Epoch... ({epoch + 1}/{num_epochs} | Eval Loss: {eval_metrics['loss']} |" f" {wer_desc})"
|
1079 |
+
epochs.write(desc)
|
1080 |
+
epochs.desc = desc
|
1081 |
+
|
1082 |
+
# Save metrics
|
1083 |
+
if has_tensorboard and jax.process_index() == 0:
|
1084 |
+
write_metric(
|
1085 |
+
summary_writer,
|
1086 |
+
train_metrics,
|
1087 |
+
eval_metrics,
|
1088 |
+
train_time,
|
1089 |
+
cur_step,
|
1090 |
+
training_args.logging_steps,
|
1091 |
+
)
|
1092 |
+
|
1093 |
+
if has_wandb and jax.process_index() == 0:
|
1094 |
+
write_wandb_metric(wandb_logger, eval_metrics, eval_time, cur_step, "eval")
|
1095 |
+
if training_args.predict_with_generate:
|
1096 |
+
write_wandb_pred(wandb_logger, pred_str, label_str)
|
1097 |
+
|
1098 |
+
# save checkpoint after each epoch and push checkpoint to the hub
|
1099 |
+
if jax.process_index() == 0:
|
1100 |
+
params = jax.device_get(jax.tree_util.tree_map(lambda x: x[0], state.params))
|
1101 |
+
model.save_pretrained(training_args.output_dir, params=params)
|
1102 |
+
tokenizer.save_pretrained(training_args.output_dir)
|
1103 |
+
if training_args.push_to_hub:
|
1104 |
+
repo.push_to_hub(
|
1105 |
+
commit_message=f"Saving weights and logs of epoch {epoch + 1}",
|
1106 |
+
blocking=False,
|
1107 |
+
)
|
1108 |
+
|
1109 |
+
|
1110 |
+
if __name__ == "__main__":
|
1111 |
+
main()
|
special_tokens_map.json
ADDED
@@ -0,0 +1,115 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"additional_special_tokens": [
|
3 |
+
"<|endoftext|>",
|
4 |
+
"<|startoftranscript|>",
|
5 |
+
"<|en|>",
|
6 |
+
"<|zh|>",
|
7 |
+
"<|de|>",
|
8 |
+
"<|es|>",
|
9 |
+
"<|ru|>",
|
10 |
+
"<|ko|>",
|
11 |
+
"<|fr|>",
|
12 |
+
"<|ja|>",
|
13 |
+
"<|pt|>",
|
14 |
+
"<|tr|>",
|
15 |
+
"<|pl|>",
|
16 |
+
"<|ca|>",
|
17 |
+
"<|nl|>",
|
18 |
+
"<|ar|>",
|
19 |
+
"<|sv|>",
|
20 |
+
"<|it|>",
|
21 |
+
"<|id|>",
|
22 |
+
"<|hi|>",
|
23 |
+
"<|fi|>",
|
24 |
+
"<|vi|>",
|
25 |
+
"<|he|>",
|
26 |
+
"<|uk|>",
|
27 |
+
"<|el|>",
|
28 |
+
"<|ms|>",
|
29 |
+
"<|cs|>",
|
30 |
+
"<|ro|>",
|
31 |
+
"<|da|>",
|
32 |
+
"<|hu|>",
|
33 |
+
"<|ta|>",
|
34 |
+
"<|no|>",
|
35 |
+
"<|th|>",
|
36 |
+
"<|ur|>",
|
37 |
+
"<|hr|>",
|
38 |
+
"<|bg|>",
|
39 |
+
"<|lt|>",
|
40 |
+
"<|la|>",
|
41 |
+
"<|mi|>",
|
42 |
+
"<|ml|>",
|
43 |
+
"<|cy|>",
|
44 |
+
"<|sk|>",
|
45 |
+
"<|te|>",
|
46 |
+
"<|fa|>",
|
47 |
+
"<|lv|>",
|
48 |
+
"<|bn|>",
|
49 |
+
"<|sr|>",
|
50 |
+
"<|az|>",
|
51 |
+
"<|sl|>",
|
52 |
+
"<|kn|>",
|
53 |
+
"<|et|>",
|
54 |
+
"<|mk|>",
|
55 |
+
"<|br|>",
|
56 |
+
"<|eu|>",
|
57 |
+
"<|is|>",
|
58 |
+
"<|hy|>",
|
59 |
+
"<|ne|>",
|
60 |
+
"<|mn|>",
|
61 |
+
"<|bs|>",
|
62 |
+
"<|kk|>",
|
63 |
+
"<|sq|>",
|
64 |
+
"<|sw|>",
|
65 |
+
"<|gl|>",
|
66 |
+
"<|mr|>",
|
67 |
+
"<|pa|>",
|
68 |
+
"<|si|>",
|
69 |
+
"<|km|>",
|
70 |
+
"<|sn|>",
|
71 |
+
"<|yo|>",
|
72 |
+
"<|so|>",
|
73 |
+
"<|af|>",
|
74 |
+
"<|oc|>",
|
75 |
+
"<|ka|>",
|
76 |
+
"<|be|>",
|
77 |
+
"<|tg|>",
|
78 |
+
"<|sd|>",
|
79 |
+
"<|gu|>",
|
80 |
+
"<|am|>",
|
81 |
+
"<|yi|>",
|
82 |
+
"<|lo|>",
|
83 |
+
"<|uz|>",
|
84 |
+
"<|fo|>",
|
85 |
+
"<|ht|>",
|
86 |
+
"<|ps|>",
|
87 |
+
"<|tk|>",
|
88 |
+
"<|nn|>",
|
89 |
+
"<|mt|>",
|
90 |
+
"<|sa|>",
|
91 |
+
"<|lb|>",
|
92 |
+
"<|my|>",
|
93 |
+
"<|bo|>",
|
94 |
+
"<|tl|>",
|
95 |
+
"<|mg|>",
|
96 |
+
"<|as|>",
|
97 |
+
"<|tt|>",
|
98 |
+
"<|haw|>",
|
99 |
+
"<|ln|>",
|
100 |
+
"<|ha|>",
|
101 |
+
"<|ba|>",
|
102 |
+
"<|jw|>",
|
103 |
+
"<|su|>",
|
104 |
+
"<|translate|>",
|
105 |
+
"<|transcribe|>",
|
106 |
+
"<|startoflm|>",
|
107 |
+
"<|startofprev|>",
|
108 |
+
"<|nocaptions|>",
|
109 |
+
"<|notimestamps|>"
|
110 |
+
],
|
111 |
+
"bos_token": "<|endoftext|>",
|
112 |
+
"eos_token": "<|endoftext|>",
|
113 |
+
"pad_token": "<|endoftext|>",
|
114 |
+
"unk_token": "<|endoftext|>"
|
115 |
+
}
|
tokenizer.json
ADDED
The diff for this file is too large to render.
See raw diff
|
|
tokenizer_config.json
ADDED
The diff for this file is too large to render.
See raw diff
|
|
vocab.json
ADDED
The diff for this file is too large to render.
See raw diff
|
|