PoTaTo721 commited on
Commit
28c720a
·
1 Parent(s): 315fa0c

Update to V1.4

Browse files
Files changed (50) hide show
  1. app.py +7 -7
  2. fish_speech/configs/firefly_gan_vq.yaml +2 -3
  3. fish_speech/configs/text2semantic_finetune.yaml +1 -1
  4. fish_speech/i18n/README.md +27 -0
  5. fish_speech/i18n/__init__.py +3 -0
  6. fish_speech/i18n/core.py +40 -0
  7. fish_speech/i18n/locale/en_US.json +122 -0
  8. fish_speech/i18n/locale/es_ES.json +122 -0
  9. fish_speech/i18n/locale/ja_JP.json +123 -0
  10. fish_speech/i18n/locale/pt_BR.json +133 -0
  11. fish_speech/i18n/locale/zh_CN.json +122 -0
  12. fish_speech/i18n/scan.py +122 -0
  13. fish_speech/models/text2semantic/llama.py +27 -0
  14. fish_speech/models/vqgan/__init__.py +0 -3
  15. fish_speech/models/vqgan/modules/firefly.py +167 -196
  16. fish_speech/models/vqgan/modules/fsq.py +4 -27
  17. fish_speech/scheduler.py +19 -1
  18. fish_speech/text/clean.py +1 -1
  19. fish_speech/train.py +5 -1
  20. fish_speech/utils/__init__.py +2 -0
  21. fish_speech/utils/context.py +13 -0
  22. fish_speech/utils/file.py +0 -103
  23. fish_speech/webui/css/style.css +161 -0
  24. fish_speech/webui/html/footer.html +11 -0
  25. fish_speech/webui/js/animate.js +69 -0
  26. fish_speech/webui/launch_utils.py +120 -0
  27. fish_speech/webui/manage.py +1237 -0
  28. requirements.txt +2 -1
  29. tools/api.py +93 -135
  30. tools/commons.py +35 -0
  31. tools/download_models.py +55 -0
  32. tools/extract_model.py +21 -0
  33. tools/file.py +125 -0
  34. tools/llama/build_dataset.py +1 -1
  35. tools/llama/generate.py +28 -10
  36. tools/llama/merge_lora.py +1 -1
  37. tools/llama/quantize.py +2 -2
  38. tools/msgpack_api.py +34 -0
  39. tools/post_api.py +205 -0
  40. tools/sensevoice/README.md +59 -0
  41. tools/sensevoice/__init__.py +0 -0
  42. tools/sensevoice/auto_model.py +573 -0
  43. tools/sensevoice/fun_asr.py +332 -0
  44. tools/sensevoice/vad_utils.py +61 -0
  45. tools/smart_pad.py +47 -0
  46. tools/vqgan/create_train_split.py +1 -1
  47. tools/vqgan/extract_vq.py +3 -3
  48. tools/vqgan/inference.py +5 -3
  49. tools/webui.py +619 -0
  50. tools/whisper_asr.py +176 -0
app.py CHANGED
@@ -10,7 +10,7 @@ import gc
10
 
11
  # Download if not exists
12
  os.makedirs("checkpoints", exist_ok=True)
13
- snapshot_download(repo_id="fishaudio/fish-speech-1.2-sft", local_dir="./checkpoints/fish-speech-1.2-sft")
14
 
15
  print("All checkpoints downloaded")
16
 
@@ -46,8 +46,8 @@ os.environ["EINX_FILTER_TRACEBACK"] = "false"
46
 
47
  HEADER_MD = """# Fish Speech
48
 
49
- ## The demo in this space is version 1.2, Please check [Fish Audio](https://fish.audio) for the best model.
50
- ## 该 Demo 为 Fish Speech 1.2 版本, 请在 [Fish Audio](https://fish.audio) 体验最新 DEMO.
51
 
52
  A text-to-speech model based on VQ-GAN and Llama developed by [Fish Audio](https://fish.audio).
53
  由 [Fish Audio](https://fish.audio) 研发的基于 VQ-GAN 和 Llama 的多语种语音合成.
@@ -61,8 +61,8 @@ Related code and weights are released under CC BY-NC-SA 4.0 License.
61
  We are not responsible for any misuse of the model, please consider your local laws and regulations before using it.
62
  我们不对模型的任何滥用负责,请在使用之前考虑您当地的法律法规.
63
 
64
- The model running in this WebUI is Fish Speech V1.2 Medium SFT.
65
- 在此 WebUI 中运行的模型是 Fish Speech V1.2 Medium SFT.
66
  """
67
 
68
  TEXTBOX_PLACEHOLDER = """Put your text here. 在此处输入文本."""
@@ -560,12 +560,12 @@ def parse_args():
560
  parser.add_argument(
561
  "--llama-checkpoint-path",
562
  type=Path,
563
- default="checkpoints/fish-speech-1.2-sft",
564
  )
565
  parser.add_argument(
566
  "--decoder-checkpoint-path",
567
  type=Path,
568
- default="checkpoints/fish-speech-1.2-sft/firefly-gan-vq-fsq-4x1024-42hz-generator.pth",
569
  )
570
  parser.add_argument("--decoder-config-name", type=str, default="firefly_gan_vq")
571
  parser.add_argument("--device", type=str, default="cuda")
 
10
 
11
  # Download if not exists
12
  os.makedirs("checkpoints", exist_ok=True)
13
+ snapshot_download(repo_id="fishaudio/fish-speech-1.4", local_dir="./checkpoints/fish-speech-1.4")
14
 
15
  print("All checkpoints downloaded")
16
 
 
46
 
47
  HEADER_MD = """# Fish Speech
48
 
49
+ ## The demo in this space is version 1.4, Please check [Fish Audio](https://fish.audio) for the best model.
50
+ ## 该 Demo 为 Fish Speech 1.4 版本, 请在 [Fish Audio](https://fish.audio) 体验最新 DEMO.
51
 
52
  A text-to-speech model based on VQ-GAN and Llama developed by [Fish Audio](https://fish.audio).
53
  由 [Fish Audio](https://fish.audio) 研发的基于 VQ-GAN 和 Llama 的多语种语音合成.
 
61
  We are not responsible for any misuse of the model, please consider your local laws and regulations before using it.
62
  我们不对模型的任何滥用负责,请在使用之前考虑您当地的法律法规.
63
 
64
+ The model running in this WebUI is Fish Speech V1.4 Medium.
65
+ 在此 WebUI 中运行的模型是 Fish Speech V1.4 Medium.
66
  """
67
 
68
  TEXTBOX_PLACEHOLDER = """Put your text here. 在此处输入文本."""
 
560
  parser.add_argument(
561
  "--llama-checkpoint-path",
562
  type=Path,
563
+ default="checkpoints/fish-speech-1.4",
564
  )
565
  parser.add_argument(
566
  "--decoder-checkpoint-path",
567
  type=Path,
568
+ default="checkpoints/fish-speech-1.4/firefly-gan-vq-fsq-8x1024-21hz-generator.pth",
569
  )
570
  parser.add_argument("--decoder-config-name", type=str, default="firefly_gan_vq")
571
  parser.add_argument("--device", type=str, default="cuda")
fish_speech/configs/firefly_gan_vq.yaml CHANGED
@@ -22,13 +22,12 @@ head:
22
  resblock_dilation_sizes: [[1, 3, 5], [1, 3, 5], [1, 3, 5]]
23
  num_mels: 512
24
  upsample_initial_channel: 512
25
- use_template: false
26
  pre_conv_kernel_size: 13
27
  post_conv_kernel_size: 13
28
  quantizer:
29
  _target_: fish_speech.models.vqgan.modules.fsq.DownsampleFiniteScalarQuantize
30
  input_dim: 512
31
- n_groups: 4
32
  n_codebooks: 1
33
  levels: [8, 5, 5, 5]
34
- downsample_factor: [2]
 
22
  resblock_dilation_sizes: [[1, 3, 5], [1, 3, 5], [1, 3, 5]]
23
  num_mels: 512
24
  upsample_initial_channel: 512
 
25
  pre_conv_kernel_size: 13
26
  post_conv_kernel_size: 13
27
  quantizer:
28
  _target_: fish_speech.models.vqgan.modules.fsq.DownsampleFiniteScalarQuantize
29
  input_dim: 512
30
+ n_groups: 8
31
  n_codebooks: 1
32
  levels: [8, 5, 5, 5]
33
+ downsample_factor: [2, 2]
fish_speech/configs/text2semantic_finetune.yaml CHANGED
@@ -4,7 +4,7 @@ defaults:
4
 
5
  project: text2semantic_finetune_dual_ar
6
  max_length: 4096
7
- pretrained_ckpt_path: checkpoints/fish-speech-1.2-sft
8
 
9
  # Lightning Trainer
10
  trainer:
 
4
 
5
  project: text2semantic_finetune_dual_ar
6
  max_length: 4096
7
+ pretrained_ckpt_path: checkpoints/fish-speech-1.4
8
 
9
  # Lightning Trainer
10
  trainer:
fish_speech/i18n/README.md ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ## i18n Folder Attribution
2
+
3
+ The `i18n` folder within the `fish_speech` directory contains files initially sourced from the RVC project. In compliance with the MIT license under which these files were released, we acknowledge the original authors and sources below:
4
+
5
+ ### fish_speech/i18n/core.py
6
+
7
+ **Related code from RVC:**
8
+ [https://github.com/RVC-Project/Retrieval-based-Voice-Conversion-WebUI/blob/83d6a64e675d9bbd6e92ee450c5f807ed2bb54d8/i18n/i18n.py](https://github.com/RVC-Project/Retrieval-based-Voice-Conversion-WebUI/blob/83d6a64e675d9bbd6e92ee450c5f807ed2bb54d8/i18n/i18n.py)
9
+
10
+ **Initial commit:**
11
+ add localization(添加本地化) [RVC-Project/Retrieval-based-Voice-Conversion-WebUI#35](https://github.com/RVC-Project/Retrieval-based-Voice-Conversion-WebUI/pull/35)
12
+
13
+ **Initial author:**
14
+ [@L4Ph](https://github.com/L4Ph)
15
+
16
+ ### fish_speech/i18n/scan.py
17
+
18
+ **Related code from RVC:**
19
+ [https://github.com/RVC-Project/Retrieval-based-Voice-Conversion-WebUI/blob/83d6a64e675d9bbd6e92ee450c5f807ed2bb54d8/i18n/scan_i18n.py](https://github.com/RVC-Project/Retrieval-based-Voice-Conversion-WebUI/blob/83d6a64e675d9bbd6e92ee450c5f807ed2bb54d8/i18n/scan_i18n.py)
20
+
21
+ **Initial commit:**
22
+ File for detecting i18n missing keys [RVC-Project/Retrieval-based-Voice-Conversion-WebUI#1058](https://github.com/RVC-Project/Retrieval-based-Voice-Conversion-WebUI/pull/1058)
23
+
24
+ **Initial author:**
25
+ [@towzeur](https://github.com/towzeur)
26
+
27
+ We appreciate the contributions of the RVC project and its authors.
fish_speech/i18n/__init__.py ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ from .core import i18n
2
+
3
+ __all__ = ["i18n"]
fish_speech/i18n/core.py ADDED
@@ -0,0 +1,40 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import locale
3
+ from pathlib import Path
4
+
5
+ I18N_FILE_PATH = Path(__file__).parent / "locale"
6
+ DEFAULT_LANGUAGE = "en_US"
7
+
8
+
9
+ def load_language_list(language):
10
+ with open(I18N_FILE_PATH / f"{language}.json", "r", encoding="utf-8") as f:
11
+ language_list = json.load(f)
12
+
13
+ return language_list
14
+
15
+
16
+ class I18nAuto:
17
+ def __init__(self):
18
+ i18n_file = Path(".locale")
19
+
20
+ if i18n_file.exists():
21
+ with open(i18n_file, "r", encoding="utf-8") as f:
22
+ language = f.read().strip()
23
+ else:
24
+ # getlocale can't identify the system's language ((None, None))
25
+ language = locale.getdefaultlocale()[0]
26
+
27
+ if (I18N_FILE_PATH / f"{language}.json").exists() is False:
28
+ language = DEFAULT_LANGUAGE
29
+
30
+ self.language = language
31
+ self.language_map = load_language_list(language)
32
+
33
+ def __call__(self, key):
34
+ return self.language_map.get(key, key)
35
+
36
+ def __repr__(self):
37
+ return "Use Language: " + self.language
38
+
39
+
40
+ i18n = I18nAuto()
fish_speech/i18n/locale/en_US.json ADDED
@@ -0,0 +1,122 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "16-mixed is recommended for 10+ series GPU": "16-mixed is recommended for 10+ series GPU",
3
+ "5 to 10 seconds of reference audio, useful for specifying speaker.": "5 to 10 seconds of reference audio, useful for specifying speaker.",
4
+ "A text-to-speech model based on VQ-GAN and Llama developed by [Fish Audio](https://fish.audio).": "A text-to-speech model based on VQ-GAN and Llama developed by [Fish Audio](https://fish.audio).",
5
+ "Accumulate Gradient Batches": "Accumulate Gradient Batches",
6
+ "Add to Processing Area": "Add to Processing Area",
7
+ "Added path successfully!": "Added path successfully!",
8
+ "Advanced Config": "Advanced Config",
9
+ "Base LLAMA Model": "Base LLAMA Model",
10
+ "Batch Inference": "Batch Inference",
11
+ "Batch Size": "Batch Size",
12
+ "Changing with the Model Path": "Changing with the Model Path",
13
+ "Chinese": "Chinese",
14
+ "Compile Model": "Compile Model",
15
+ "Compile the model can significantly reduce the inference time, but will increase cold start time": "Compile the model can significantly reduce the inference time, but will increase cold start time",
16
+ "Copy": "Copy",
17
+ "Data Preprocessing": "Data Preprocessing",
18
+ "Data Preprocessing Path": "Data Preprocessing Path",
19
+ "Data Source": "Data Source",
20
+ "Decoder Model Config": "Decoder Model Config",
21
+ "Decoder Model Path": "Decoder Model Path",
22
+ "Disabled": "Disabled",
23
+ "Enable Reference Audio": "Enable Reference Audio",
24
+ "English": "English",
25
+ "Error Message": "Error Message",
26
+ "File Preprocessing": "File Preprocessing",
27
+ "Generate": "Generate",
28
+ "Generated Audio": "Generated Audio",
29
+ "If there is no corresponding text for the audio, apply ASR for assistance, support .txt or .lab format": "If there is no corresponding text for the audio, apply ASR for assistance, support .txt or .lab format",
30
+ "Infer interface is closed": "Infer interface is closed",
31
+ "Inference Configuration": "Inference Configuration",
32
+ "Inference Server Configuration": "Inference Server Configuration",
33
+ "Inference Server Error": "Inference Server Error",
34
+ "Inferring interface is launched at {}": "Inferring interface is launched at {}",
35
+ "Initial Learning Rate": "Initial Learning Rate",
36
+ "Input Audio & Source Path for Transcription": "Input Audio & Source Path for Transcription",
37
+ "Input Text": "Input Text",
38
+ "Invalid path: {}": "Invalid path: {}",
39
+ "It is recommended to use CUDA, if you have low configuration, use CPU": "It is recommended to use CUDA, if you have low configuration, use CPU",
40
+ "Iterative Prompt Length, 0 means off": "Iterative Prompt Length, 0 means off",
41
+ "Japanese": "Japanese",
42
+ "LLAMA Configuration": "LLAMA Configuration",
43
+ "LLAMA Model Config": "LLAMA Model Config",
44
+ "LLAMA Model Path": "LLAMA Model Path",
45
+ "Labeling Device": "Labeling Device",
46
+ "LoRA Model to be merged": "LoRA Model to be merged",
47
+ "Maximum Audio Duration": "Maximum Audio Duration",
48
+ "Maximum Length per Sample": "Maximum Length per Sample",
49
+ "Maximum Training Steps": "Maximum Training Steps",
50
+ "Maximum tokens per batch, 0 means no limit": "Maximum tokens per batch, 0 means no limit",
51
+ "Merge": "Merge",
52
+ "Merge LoRA": "Merge LoRA",
53
+ "Merge successfully": "Merge successfully",
54
+ "Minimum Audio Duration": "Minimum Audio Duration",
55
+ "Model Output Path": "Model Output Path",
56
+ "Model Size": "Model Size",
57
+ "Move": "Move",
58
+ "Move files successfully": "Move files successfully",
59
+ "No audio generated, please check the input text.": "No audio generated, please check the input text.",
60
+ "No selected options": "No selected options",
61
+ "Number of Workers": "Number of Workers",
62
+ "Open Inference Server": "Open Inference Server",
63
+ "Open Labeler WebUI": "Open Labeler WebUI",
64
+ "Open Tensorboard": "Open Tensorboard",
65
+ "Opened labeler in browser": "Opened labeler in browser",
66
+ "Optional Label Language": "Optional Label Language",
67
+ "Optional online ver": "Optional online ver",
68
+ "Output Path": "Output Path",
69
+ "Path error, please check the model file exists in the corresponding path": "Path error, please check the model file exists in the corresponding path",
70
+ "Precision": "Precision",
71
+ "Probability of applying Speaker Condition": "Probability of applying Speaker Condition",
72
+ "Put your text here.": "Put your text here.",
73
+ "Reference Audio": "Reference Audio",
74
+ "Reference Text": "Reference Text",
75
+ "Related code and weights are released under CC BY-NC-SA 4.0 License.": "Related code and weights are released under CC BY-NC-SA 4.0 License.",
76
+ "Remove Selected Data": "Remove Selected Data",
77
+ "Removed path successfully!": "Removed path successfully!",
78
+ "Repetition Penalty": "Repetition Penalty",
79
+ "Save model every n steps": "Save model every n steps",
80
+ "Select LLAMA ckpt": "Select LLAMA ckpt",
81
+ "Select VITS ckpt": "Select VITS ckpt",
82
+ "Select VQGAN ckpt": "Select VQGAN ckpt",
83
+ "Select source file processing method": "Select source file processing method",
84
+ "Select the model to be trained (Depending on the Tab page you are on)": "Select the model to be trained (Depending on the Tab page you are on)",
85
+ "Selected: {}": "Selected: {}",
86
+ "Speaker": "Speaker",
87
+ "Speaker is identified by the folder name": "Speaker is identified by the folder name",
88
+ "Start Training": "Start Training",
89
+ "Streaming Audio": "Streaming Audio",
90
+ "Streaming Generate": "Streaming Generate",
91
+ "Tensorboard Host": "Tensorboard Host",
92
+ "Tensorboard Log Path": "Tensorboard Log Path",
93
+ "Tensorboard Port": "Tensorboard Port",
94
+ "Tensorboard interface is closed": "Tensorboard interface is closed",
95
+ "Tensorboard interface is launched at {}": "Tensorboard interface is launched at {}",
96
+ "Text is too long, please keep it under {} characters.": "Text is too long, please keep it under {} characters.",
97
+ "The path of the input folder on the left or the filelist. Whether checked or not, it will be used for subsequent training in this list.": "The path of the input folder on the left or the filelist. Whether checked or not, it will be used for subsequent training in this list.",
98
+ "Training Configuration": "Training Configuration",
99
+ "Training Error": "Training Error",
100
+ "Training stopped": "Training stopped",
101
+ "Type name of the speaker": "Type name of the speaker",
102
+ "Type the path or select from the dropdown": "Type the path or select from the dropdown",
103
+ "Use LoRA": "Use LoRA",
104
+ "Use LoRA can save GPU memory, but may reduce the quality of the model": "Use LoRA can save GPU memory, but may reduce the quality of the model",
105
+ "Use filelist": "Use filelist",
106
+ "Use large for 10G+ GPU, medium for 5G, small for 2G": "Use large for 10G+ GPU, medium for 5G, small for 2G",
107
+ "VITS Configuration": "VITS Configuration",
108
+ "VQGAN Configuration": "VQGAN Configuration",
109
+ "Validation Batch Size": "Validation Batch Size",
110
+ "View the status of the preprocessing folder (use the slider to control the depth of the tree)": "View the status of the preprocessing folder (use the slider to control the depth of the tree)",
111
+ "We are not responsible for any misuse of the model, please consider your local laws and regulations before using it.": "We are not responsible for any misuse of the model, please consider your local laws and regulations before using it.",
112
+ "WebUI Host": "WebUI Host",
113
+ "WebUI Port": "WebUI Port",
114
+ "Whisper Model": "Whisper Model",
115
+ "You can find the source code [here](https://github.com/fishaudio/fish-speech) and models [here](https://huggingface.co/fishaudio/fish-speech-1).": "You can find the source code [here](https://github.com/fishaudio/fish-speech) and models [here](https://huggingface.co/fishaudio/fish-speech-1).",
116
+ "bf16-true is recommended for 30+ series GPU, 16-mixed is recommended for 10+ series GPU": "bf16-true is recommended for 30+ series GPU, 16-mixed is recommended for 10+ series GPU",
117
+ "latest": "latest",
118
+ "new": "new",
119
+ "Realtime Transform Text": "Realtime Transform Text",
120
+ "Normalization Result Preview (Currently Only Chinese)": "Normalization Result Preview (Currently Only Chinese)",
121
+ "Text Normalization": "Text Normalization"
122
+ }
fish_speech/i18n/locale/es_ES.json ADDED
@@ -0,0 +1,122 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "16-mixed is recommended for 10+ series GPU": "se recomienda 16-mixed para GPU de la serie 10+",
3
+ "5 to 10 seconds of reference audio, useful for specifying speaker.": "5 a 10 segundos de audio de referencia, útil para especificar el hablante.",
4
+ "A text-to-speech model based on VQ-GAN and Llama developed by [Fish Audio](https://fish.audio).": "Un modelo de texto a voz basado en VQ-GAN y Llama desarrollado por [Fish Audio](https://fish.audio).",
5
+ "Accumulate Gradient Batches": "Acumular lotes de gradientes",
6
+ "Add to Processing Area": "Agregar al Área de Procesamiento",
7
+ "Added path successfully!": "¡Ruta agregada exitosamente!",
8
+ "Advanced Config": "Configuración Avanzada",
9
+ "Base LLAMA Model": "Modelo Base LLAMA",
10
+ "Batch Inference": "Inferencia por Lote",
11
+ "Batch Size": "Tamaño del Lote",
12
+ "Changing with the Model Path": "Cambiando con la Ruta del Modelo",
13
+ "Chinese": "Chino",
14
+ "Compile Model": "Compilar Modelo",
15
+ "Compile the model can significantly reduce the inference time, but will increase cold start time": "Compilar el modelo puede reducir significativamente el tiempo de inferencia, pero aumentará el tiempo de inicio en frío",
16
+ "Copy": "Copiar",
17
+ "Data Preprocessing": "Preprocesamiento de Datos",
18
+ "Data Preprocessing Path": "Ruta de Preprocesamiento de Datos",
19
+ "Data Source": "Fuente de Datos",
20
+ "Decoder Model Config": "Configuración del modelo decodificador",
21
+ "Decoder Model Path": "Ruta del modelo decodificador",
22
+ "Disabled": "Desactivado",
23
+ "Enable Reference Audio": "Habilitar Audio de Referencia",
24
+ "English": "Inglés",
25
+ "Error Message": "Mensaje de Error",
26
+ "File Preprocessing": "Preprocesamiento de Archivos",
27
+ "Generate": "Generar",
28
+ "Generated Audio": "Audio Generado",
29
+ "If there is no corresponding text for the audio, apply ASR for assistance, support .txt or .lab format": "Si no hay texto correspondiente para el audio, aplique ASR para asistencia, soporte para formato .txt o .lab",
30
+ "Infer interface is closed": "La interfaz de inferencia está cerrada",
31
+ "Inference Configuration": "Configuración de Inferencia",
32
+ "Inference Server Configuration": "Configuración del Servidor de Inferencia",
33
+ "Inference Server Error": "Error del Servidor de Inferencia",
34
+ "Inferring interface is launched at {}": "La interfaz de inferencia se ha lanzado en {}",
35
+ "Initial Learning Rate": "Tasa de Aprendizaje Inicial",
36
+ "Input Audio & Source Path for Transcription": "Audio de Entrada y Ruta de Origen para Transcripción",
37
+ "Input Text": "Texto de Entrada",
38
+ "Invalid path: {}": "Ruta inválida: {}",
39
+ "It is recommended to use CUDA, if you have low configuration, use CPU": "Se recomienda usar CUDA, si tiene una configuración baja, use CPU",
40
+ "Iterative Prompt Length, 0 means off": "Longitud de la Indicación Iterativa, 0 significa apagado",
41
+ "Japanese": "Japonés",
42
+ "LLAMA Configuration": "Configuración de LLAMA",
43
+ "LLAMA Model Config": "Configuración del Modelo LLAMA",
44
+ "LLAMA Model Path": "Ruta del Modelo LLAMA",
45
+ "Labeling Device": "Dispositivo de Etiquetado",
46
+ "LoRA Model to be merged": "Modelo LoRA a fusionar",
47
+ "Maximum Audio Duration": "Duración máxima de audio",
48
+ "Maximum Length per Sample": "Longitud Máxima por Muestra",
49
+ "Maximum Training Steps": "Pasos Máximos de Entrenamiento",
50
+ "Maximum tokens per batch, 0 means no limit": "Máximo de tokens por lote, 0 significa sin límite",
51
+ "Merge": "Fusionar",
52
+ "Merge LoRA": "Fusionar LoRA",
53
+ "Merge successfully": "Fusionado exitosamente",
54
+ "Minimum Audio Duration": "Duración mínima de audio",
55
+ "Model Output Path": "Ruta de Salida del Modelo",
56
+ "Model Size": "Tamaño del Modelo",
57
+ "Move": "Mover",
58
+ "Move files successfully": "Archivos movidos exitosamente",
59
+ "No audio generated, please check the input text.": "No se generó audio, por favor verifique el texto de entrada.",
60
+ "No selected options": "No hay opciones seleccionadas",
61
+ "Number of Workers": "Número de Trabajadores",
62
+ "Open Inference Server": "Abrir Servidor de Inferencia",
63
+ "Open Labeler WebUI": "Abrir Interfaz Web del Etiquetador",
64
+ "Open Tensorboard": "Abrir Tensorboard",
65
+ "Opened labeler in browser": "Se abrió el etiquetador en el navegador",
66
+ "Optional Label Language": "Idioma de Etiquetado Opcional",
67
+ "Optional online ver": "Ver en línea opcional",
68
+ "Output Path": "Ruta de Salida",
69
+ "Path error, please check the model file exists in the corresponding path": "Error de ruta, por favor verifique que el archivo del modelo exista en la ruta correspondiente",
70
+ "Precision": "Precisión",
71
+ "Probability of applying Speaker Condition": "Probabilidad de aplicar Condición de Hablante",
72
+ "Put your text here.": "Ponga su texto aquí.",
73
+ "Reference Audio": "Audio de Referencia",
74
+ "Reference Text": "Texto de Referencia",
75
+ "Related code and weights are released under CC BY-NC-SA 4.0 License.": "El código relacionado y los pesos se publican bajo la Licencia CC BY-NC-SA 4.0.",
76
+ "Remove Selected Data": "Eliminar Datos Seleccionados",
77
+ "Removed path successfully!": "¡Ruta eliminada exitosamente!",
78
+ "Repetition Penalty": "Penalización por Repetición",
79
+ "Save model every n steps": "Guardar modelo cada n pasos",
80
+ "Select LLAMA ckpt": "Seleccionar punto de control LLAMA",
81
+ "Select VITS ckpt": "Seleccionar punto de control VITS",
82
+ "Select VQGAN ckpt": "Seleccionar punto de control VQGAN",
83
+ "Select source file processing method": "Seleccione el método de procesamiento de archivos fuente",
84
+ "Select the model to be trained (Depending on the Tab page you are on)": "Seleccione el modelo a entrenar (Dependiendo de la pestaña en la que se encuentre)",
85
+ "Selected: {}": "Seleccionado: {}",
86
+ "Speaker": "Hablante",
87
+ "Speaker is identified by the folder name": "El hablante se identifica por el nombre de la carpeta",
88
+ "Start Training": "Iniciar Entrenamiento",
89
+ "Streaming Audio": "transmisión de audio",
90
+ "Streaming Generate": "síntesis en flujo",
91
+ "Tensorboard Host": "Host de Tensorboard",
92
+ "Tensorboard Log Path": "Ruta de Registro de Tensorboard",
93
+ "Tensorboard Port": "Puerto de Tensorboard",
94
+ "Tensorboard interface is closed": "La interfaz de Tensorboard está cerrada",
95
+ "Tensorboard interface is launched at {}": "La interfaz de Tensorboard se ha lanzado en {}",
96
+ "Text is too long, please keep it under {} characters.": "El texto es demasiado largo, por favor manténgalo por debajo de {} caracteres.",
97
+ "The path of the input folder on the left or the filelist. Whether checked or not, it will be used for subsequent training in this list.": "La ruta de la carpeta de entrada a la izquierda o la lista de archivos. Ya sea que esté marcado o no, se utilizará para el entrenamiento posterior en esta lista.",
98
+ "Training Configuration": "Configuración de Entrenamiento",
99
+ "Training Error": "Error de Entrenamiento",
100
+ "Training stopped": "Entrenamiento detenido",
101
+ "Type name of the speaker": "Escriba el nombre del hablante",
102
+ "Type the path or select from the dropdown": "Escriba la ruta o seleccione de la lista desplegable",
103
+ "Use LoRA": "Usar LoRA",
104
+ "Use LoRA can save GPU memory, but may reduce the quality of the model": "Usar LoRA puede ahorrar memoria GPU, pero puede reducir la calidad del modelo",
105
+ "Use filelist": "Usar lista de archivos",
106
+ "Use large for 10G+ GPU, medium for 5G, small for 2G": "Use grande para GPU de 10G+, mediano para 5G, pequeño para 2G",
107
+ "VITS Configuration": "Configuración de VITS",
108
+ "VQGAN Configuration": "Configuración de VQGAN",
109
+ "Validation Batch Size": "Tamaño del Lote de Validación",
110
+ "View the status of the preprocessing folder (use the slider to control the depth of the tree)": "Vea el estado de la carpeta de preprocesamiento (use el control deslizante para controlar la profundidad del árbol)",
111
+ "We are not responsible for any misuse of the model, please consider your local laws and regulations before using it.": "No somos responsables de ningún mal uso del modelo, por favor considere sus leyes y regulaciones locales antes de usarlo.",
112
+ "WebUI Host": "Host de WebUI",
113
+ "WebUI Port": "Puerto de WebUI",
114
+ "Whisper Model": "Modelo Whisper",
115
+ "You can find the source code [here](https://github.com/fishaudio/fish-speech) and models [here](https://huggingface.co/fishaudio/fish-speech-1).": "Puede encontrar el código fuente [aquí](https://github.com/fishaudio/fish-speech) y los modelos [aquí](https://huggingface.co/fishaudio/fish-speech-1).",
116
+ "bf16-true is recommended for 30+ series GPU, 16-mixed is recommended for 10+ series GPU": "Se recomienda bf16-true para GPU de la serie 30+, se recomienda 16-mixed para GPU de la serie 10+",
117
+ "latest": "más reciente",
118
+ "new": "nuevo",
119
+ "Realtime Transform Text": "Transformación de Texto en Tiempo Real",
120
+ "Normalization Result Preview (Currently Only Chinese)": "Vista Previa del Resultado de Normalización (Actualmente Solo Chino)",
121
+ "Text Normalization": "Normalización de Texto"
122
+ }
fish_speech/i18n/locale/ja_JP.json ADDED
@@ -0,0 +1,123 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "16-mixed is recommended for 10+ series GPU": "10シリーズ以降のGPUには16-mixedをお勧めします",
3
+ "5 to 10 seconds of reference audio, useful for specifying speaker.": "話者を指定するのに役立つ、5~10秒のリファレンスオーディオ。",
4
+ "A text-to-speech model based on VQ-GAN and Llama developed by [Fish Audio](https://fish.audio).": "[Fish Audio](https://fish.audio)が開発したVQ-GANとLlamaに基づくテキスト音声合成モデル。",
5
+ "Accumulate Gradient Batches": "勾配バッチの累積",
6
+ "Add to Processing Area": "処理エリアに追加",
7
+ "Added path successfully!": "パスの追加に成功しました!",
8
+ "Advanced Config": "詳細設定",
9
+ "Base LLAMA Model": "基本LLAMAモデル",
10
+ "Batch Inference": "バッチ推論",
11
+ "Batch Size": "バッチサイズ",
12
+ "Changing with the Model Path": "モデルのパスに伴って変化する",
13
+ "Chinese": "中国語",
14
+ "Compile Model": "モデルのコンパイル",
15
+ "Compile the model can significantly reduce the inference time, but will increase cold start time": "モデルをコンパイルすると推論時間を大幅に短縮できますが、コールドスタート時間が長くなります",
16
+ "Copy": "コピー",
17
+ "Data Preprocessing": "データ前処理",
18
+ "Data Preprocessing Path": "データ前処理パス",
19
+ "Data Source": "データソース",
20
+ "Decoder Model Config": "デコーダーモデルの構成",
21
+ "Decoder Model Path": "デコーダーモデルのパス",
22
+ "Disabled": "無効",
23
+ "Enable Reference Audio": "リファレンスオーディオを有効にする",
24
+ "English": "英語",
25
+ "Error Message": "エラーメッセージ",
26
+ "File Preprocessing": "文書前处理",
27
+ "Generate": "生成",
28
+ "Generated Audio": "生成されたオーディオ",
29
+ "If there is no corresponding text for the audio, apply ASR for assistance, support .txt or .lab format": "音声に対応するテキストがない場合は、ASRを適用してサポートします。.txtまたは.lab形式をサポートしています",
30
+ "Infer interface is closed": "推論インターフェースが閉じられています",
31
+ "Inference Configuration": "推論設定",
32
+ "Inference Server Configuration": "推論サーバー設定",
33
+ "Inference Server Error": "推論サーバーエラー",
34
+ "Inferring interface is launched at {}": "推論インターフェースが{}で起動しました",
35
+ "Initial Learning Rate": "初期学習率",
36
+ "Input Audio & Source Path for Transcription": "入力オーディオと文字起こしのソースパス",
37
+ "Input Text": "入力テキスト",
38
+ "Invalid path: {}": "無効なパス: {}",
39
+ "It is recommended to use CUDA, if you have low configuration, use CPU": "CUDAの使用をお勧めします。低い構成の場合はCPUを使用してください",
40
+ "Iterative Prompt Length, 0 means off": "反復プロンプト長。0はオフを意味します",
41
+ "Japanese": "日本語",
42
+ "LLAMA Configuration": "LLAMA設定",
43
+ "LLAMA Model Config": "LLAMAモデル設定",
44
+ "LLAMA Model Path": "LLAMAモデルパス",
45
+ "Labeling Device": "ラベリングデバイス",
46
+ "LoRA Model to be merged": "マージするLoRAモデル",
47
+ "Maximum Audio Duration": "最大オーディオの長さ",
48
+ "Maximum Length per Sample": "サンプルあたりの最大長",
49
+ "Maximum Training Steps": "最大トレーニングステップ数",
50
+ "Maximum tokens per batch, 0 means no limit": "バッチあたりの最大トークン数。0は制限なしを意味します",
51
+ "Merge": "マージ",
52
+ "Merge LoRA": "LoRAのマージ",
53
+ "Merge successfully": "マージに成功しました",
54
+ "Minimum Audio Duration": "最小オーディオの長さ",
55
+ "Model Output Path": "モデル出力パス",
56
+ "Model Size": "モデルサイズ",
57
+ "Move": "移動",
58
+ "Move files successfully": "ファイルの移動に成功しました",
59
+ "No audio generated, please check the input text.": "オーディオが生成されていません。入力テキストを確認してください。",
60
+ "No selected options": "選択されたオプションはありません",
61
+ "Number of Workers": "ワーカー数",
62
+ "Open Inference Server": "推論サーバーを開く",
63
+ "Open Labeler WebUI": "ラベラーWebUIを開く",
64
+ "Open Tensorboard": "Tensorboardを開く",
65
+ "Opened labeler in browser": "ブラウザでラベラーを開きました",
66
+ "Optional Label Language": "オプションのラベル言語",
67
+ "Optional online ver": "オプションのオンラインバージョン",
68
+ "Output Path": "出力パス",
69
+ "Path error, please check the model file exists in the corresponding path": "パスエラー。対応するパスにモデルファイルが存在するか確認してください",
70
+ "Precision": "精度",
71
+ "Probability of applying Speaker Condition": "話者条件を適用する確率",
72
+ "Put your text here.": "ここにテキストを入力してください。",
73
+ "Reference Audio": "リファレンスオーディオ",
74
+ "Reference Text": "リファレンステキスト",
75
+ "Related code and weights are released under CC BY-NC-SA 4.0 License.": "関連コードと重みはCC BY-NC-SA 4.0ライセンスの下でリリースされます。",
76
+ "Remove Selected Data": "選択したデータを削除",
77
+ "Removed path successfully!": "パスの削除に成功しました!",
78
+ "Repetition Penalty": "反復ペナルティ",
79
+ "Save model every n steps": "nステップごとにモデルを保存",
80
+ "Select LLAMA ckpt": " LLAMA チェックポイントを選択",
81
+ "Select VITS ckpt": "VITS チェックポイントを選択",
82
+ "Select VQGAN ckpt": "VQGAN チェックポイントを選択",
83
+ "Select source file processing method": "ソースファイルの処理方法を選択",
84
+ "Select the model to be trained (Depending on the Tab page you are on)": "タブページに応じてトレーニングするモデルを選択してください",
85
+ "Selected: {}": "選択済み: {}",
86
+ "Speaker": "話者",
87
+ "Speaker is identified by the folder name": "話者はフォルダ名で識別されます",
88
+ "Start Training": "トレーニング開始",
89
+ "Streaming Audio": "ストリーミングオーディオ",
90
+ "Streaming Generate": "ストリーミング合成",
91
+ "Tensorboard Host": "Tensorboardホスト",
92
+ "Tensorboard Log Path": "Tensorboardログパス",
93
+ "Tensorboard Port": "Tensorboardポート",
94
+ "Tensorboard interface is closed": "Tensorboardインターフェースが閉じられています",
95
+ "Tensorboard interface is launched at {}": "Tensorboardインターフェースが{}で起動されました",
96
+ "Text is too long, please keep it under {} characters.": "テキストが長すぎます。{}文字以内に抑えてください。",
97
+ "The path of the input folder on the left or the filelist. Whether checked or not, it will be used for subsequent training in this list.": "左側の入力フォルダまたはファイルリストのパス。チェックの有無にかかわらず、このリストの後続のトレーニングに使用されます。",
98
+ "Training Configuration": "トレーニング設定",
99
+ "Training Error": "トレーニングエラー",
100
+ "Training stopped": "トレーニングが停止しました",
101
+ "Type name of the speaker": "話者の名前を入力",
102
+ "Type the path or select from the dropdown": "パスを入力するか、ドロップダウンから選択してください",
103
+ "Use LoRA": "LoRAを使用",
104
+ "Use LoRA can save GPU memory, but may reduce the quality of the model": "LoRAを使用するとGPUメモリを節約できますが、モデルの品質が低下する可能性があります",
105
+ "Use filelist": "ファイルリストを使用",
106
+ "Use large for 10G+ GPU, medium for 5G, small for 2G": "10G以上のGPUには大、5Gには中、2Gには小を使用してください",
107
+ "VITS Configuration": "VITS の構成",
108
+ "VQGAN Configuration": "VQGAN の構成",
109
+ "Validation Batch Size": "検証バッチサイズ",
110
+ "View the status of the preprocessing folder (use the slider to control the depth of the tree)": "前処理フォルダの状態を表示(スライダーを使用してツリーの深さを制御)",
111
+ "We are not responsible for any misuse of the model, please consider your local laws and regulations before using it.": "モデルの誤用については一切責任を負いません。使用する前に、現地の法律と規制を考慮してください。",
112
+ "WebUI Host": "WebUIホスト",
113
+ "WebUI Port": "WebUIポート",
114
+ "Whisper Model": "Whisperモデル",
115
+ "You can find the source code [here](https://github.com/fishaudio/fish-speech) and models [here](https://huggingface.co/fishaudio/fish-speech-1).": "ソースコードは[こちら](https://github.com/fishaudio/fish-speech)、モデルは[こちら](https://huggingface.co/fishaudio/fish-speech-1)にあります。",
116
+ "bf16-true is recommended for 30+ series GPU, 16-mixed is recommended for 10+ series GPU": "30シリーズ以降のGPUにはbf16-trueを、10シリーズ以降のGPUには16-mixedをお勧めします",
117
+ "latest": "最新",
118
+ "new": "新規",
119
+ "Realtime Transform Text": "リアルタイム変換テキスト",
120
+ "Normalization Result Preview (Currently Only Chinese)": "正規化結果プレビュー(現在は中国語のみ)",
121
+ "Text Normalization": "テキスト正規化"
122
+
123
+ }
fish_speech/i18n/locale/pt_BR.json ADDED
@@ -0,0 +1,133 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "5 to 10 seconds of reference audio, useful for specifying speaker.": "5 a 10 segundos de áudio de referência, útil para especificar o orador.",
3
+ "A text-to-speech model based on VQ-GAN and Llama developed by [Fish Audio](https://fish.audio).": "Um modelo de texto para fala baseado em VQ-GAN e Llama desenvolvido por [Fish Audio](https://fish.audio).",
4
+ "Accumulate Gradient Batches": "Acumular Lotes de Gradiente",
5
+ "Add to Processing Area": "Adicionar à Área de Processamento",
6
+ "Added path successfully!": "Caminho adicionado com sucesso!",
7
+ "Advanced Config": "Configuração Avançada",
8
+ "Base LLAMA Model": "Modelo LLAMA Base",
9
+ "Batch Inference": "Inferência em Lote",
10
+ "Batch Size": "Tamanho do Lote",
11
+ "Changing with the Model Path": "Alterando com o Caminho do Modelo",
12
+
13
+ "Compile Model": "Compilar Modelo",
14
+ "Compile the model can significantly reduce the inference time, but will increase cold start time": "Compilar o modelo pode reduzir significativamente o tempo de inferência, mas aumentará a latência inicial",
15
+ "Copy": "Copiar",
16
+ "Data Preprocessing": "Pré-processamento de Dados",
17
+ "Data Preprocessing Path": "Caminho de Pré-processamento de Dados",
18
+ "Data Source": "Fonte de Dados",
19
+ "Decoder Model Config": "Configuração do Modelo Decodificador",
20
+ "Decoder Model Path": "Caminho do Modelo Decodificador",
21
+ "Disabled": "Desativado",
22
+ "Enable Initial Prompt": "Habilitar Prompt Inicial",
23
+ "Enable Reference Audio": "Habilitar Áudio de Referência",
24
+ "English": "Inglês",
25
+ "Japanese": "Japonês",
26
+ "Chinese": "Chinês",
27
+ "Portuguese": "Português",
28
+ "Spanish": "Espanhol",
29
+ "Error Message": "Mensagem de Erro",
30
+ "Faster Whisper, Up to 5g GPU memory usage": "Faster Whisper (Usa até 5 GB de vRAM)",
31
+ "File Preprocessing": "Pré-processamento de Arquivos",
32
+ "Generate": "Gerar",
33
+ "Generated Audio": "Áudio Gerado",
34
+ "If there is no corresponding text for the audio, apply ASR for assistance, support .txt or .lab format": "Se não houver texto correspondente ao áudio, utilize o ASR para assistência (formatos .txt ou .lab)",
35
+ "Infer interface is closed": "A interface de inferência foi fechada",
36
+ "Inference Configuration": "Configuração de Inferência",
37
+ "Inference Server Configuration": "Configuração do Servidor de Inferência",
38
+ "Inference Server Error": "Erro do Servidor de Inferência",
39
+ "Inferring interface is launched at {}": "A interface de inferência foi iniciada em {}",
40
+ "Initial Learning Rate": "Taxa de Aprendizagem Inicial",
41
+ "Initial Prompt": "Prompt Inicial",
42
+ "Initial prompt can provide contextual or vocabulary-specific guidance to the model.": "O prompt inicial pode fornecer orientação contextual ou específica de vocabulário para o modelo.",
43
+ "Input Audio & Source Path for Transcription": "Entrada de Áudio/Caminho de Origem para Transcrição",
44
+ "Input Text": "Texto de Entrada",
45
+ "Invalid path: {}": "Caminho inválido: {}",
46
+ "It is recommended to use CUDA, if you have low configuration, use CPU": "Para GPUs Nvidia é recomendado usar CUDA. Se não tiver uma GPU Nvidia, use CPU",
47
+ "Iterative Prompt Length, 0 means off": "Comprimento do Prompt Iterativo (0 = desativado)",
48
+ "LLAMA Configuration": "Configuração do LLAMA",
49
+ "LLAMA Model Config": "Configuração do Modelo LLAMA",
50
+ "LLAMA Model Path": "Caminho do Modelo LLAMA",
51
+ "Labeling Device": "Dispositivo de Rotulagem",
52
+ "LoRA Model to be merged": "Modelo LoRA para mesclagem",
53
+ "Maximum Length per Sample": "Comprimento Máximo por Amostra",
54
+ "Maximum Training Steps": "Etapas Máximas de Treinamento",
55
+ "Maximum tokens per batch, 0 means no limit": "Número máximo de tokens por lote, 0 significa sem limite",
56
+ "Merge": "Mesclar",
57
+ "Merge LoRA": "Mesclar LoRA",
58
+ "Merge successfully": "Mesclado com sucesso",
59
+ "Model Output Path": "Caminho de Saída do Modelo",
60
+ "Model Quantization": "Quantização do Modelo",
61
+ "Model Size": "Tamanho do Modelo",
62
+ "Move": "Mover",
63
+ "Move files successfully": "Arquivos movidos com sucesso",
64
+ "No audio generated, please check the input text.": "Nenhum áudio gerado, verifique o texto de entrada.",
65
+ "No selected options": "Nenhuma opção selecionada",
66
+ "Normalization Result Preview (Currently Only Chinese)": "Pré-visualização do Resultado da Normalização (Atualmente Apenas Chinês)",
67
+ "Number of Workers": "Número de Processos",
68
+ "Open Inference Server": "Abrir Servidor de Inferência",
69
+ "Open Labeler WebUI": "Abrir WebUI de Rotulagem",
70
+ "Open Tensorboard": "Abrir Tensorboard",
71
+ "Opened labeler in browser": "WebUI de rotulagem aberta no navegador",
72
+ "Optional Label Language": "Idioma do Rótulo (Opcional)",
73
+ "Optional online ver": "Versão online (opcional)",
74
+ "Output Path": "Caminho de Saída",
75
+ "Path error, please check the model file exists in the corresponding path": "Erro de caminho, verifique se o arquivo do modelo existe no caminho correspondente",
76
+ "Post-quantification Precision": "Precisão Pós-quantização",
77
+ "Precision": "Precisão",
78
+ "Probability of applying Speaker Condition": "Probabilidade de Aplicar Condição de Orador",
79
+ "Put your text here.": "Insira seu texto aqui.",
80
+ "Quantify": "Quantizar",
81
+ "Quantify successfully": "Quantizado com sucesso",
82
+ "Realtime Transform Text": "Transformar Texto em Tempo Real",
83
+ "Reference Audio": "Áudio de Referência",
84
+ "Reference Text": "Texto de Referência",
85
+ "warning": "Aviso",
86
+ "Pre-processing begins...": "O pré-processamento começou!",
87
+ "Related code and weights are released under CC BY-NC-SA 4.0 License.": "O código relacionado e os pesos são licenciados sob a Licença CC BY-NC-SA 4.0.",
88
+ "Remove Selected Data": "Remover Dados Selecionados",
89
+ "Removed path successfully!": "Caminho removido com sucesso!",
90
+ "Repetition Penalty": "Penalidade de Repetição",
91
+ "Save model every n steps": "Salvar modelo a cada n etapas",
92
+ "Select LLAMA ckpt": "Selecionar .ckpt do LLAMA",
93
+ "Select source file processing method": "Escolha como processar o arquivo de origem",
94
+ "Select the model to be trained (Depending on the Tab page you are on)": "Selecione o modelo para o treinamento (dependendo da aba em que você está)",
95
+ "Selected: {}": "Selecionado: {}",
96
+ "Speaker is identified by the folder name": "O orador é identificado pelo nome da pasta",
97
+ "Start Training": "Iniciar Treinamento",
98
+ "Streaming Audio": "Áudio em Streaming",
99
+ "Streaming Generate": "Geração em Streaming",
100
+ "Tensorboard Host": "Host do Tensorboard",
101
+ "Tensorboard Log Path": "Caminho de Log do Tensorboard",
102
+ "Tensorboard Port": "Porta do Tensorboard",
103
+ "Tensorboard interface is closed": "A interface do Tensorboard está fechada",
104
+ "Tensorboard interface is launched at {}": "A interface do Tensorboard foi iniciada em {}",
105
+ "Text Normalization": "Normalização de Texto",
106
+ "Text is too long, please keep it under {} characters.": "O texto é muito longo. Mantenha-o com menos de {} caracteres.",
107
+ "The lower the quantitative precision, the more the effectiveness may decrease, but the greater the efficiency will increase": "Quanto menor a precisão quantitativa, mais a eficácia pode diminuir, mas maior será o aumento da eficiência",
108
+ "The path of the input folder on the left or the filelist. Whether checked or not, it will be used for subsequent training in this list.": "O caminho da pasta de entrada à esquerda ou a lista de arquivos. Independentemente de estar marcada ou não, ela será utilizada para o treinamento subsequente nesta lista.",
109
+ "Training Configuration": "Configuração de Treinamento",
110
+ "Training Error": "Erro de Treinamento",
111
+ "Training stopped": "Treinamento interrompido!",
112
+ "Type the path or select from the dropdown": "Digite o caminho ou selecione no menu suspenso",
113
+ "Use LoRA": "Usar LoRA",
114
+ "Use LoRA can save GPU memory, but may reduce the quality of the model": "O uso de LoRAs pode economizar memória da GPU, mas também pode reduzir a qualidade",
115
+ "Use filelist": "Usar lista de arquivos",
116
+ "VQGAN Configuration": "Configuração do VQGAN",
117
+ "View the status of the preprocessing folder (use the slider to control the depth of the tree)": "Visualizar o status da pasta de pré-processamento (use o controle deslizante para controlar a profundidade da árvore)",
118
+ "We are not responsible for any misuse of the model, please consider your local laws and regulations before using it.": "Não nos responsabilizamos por qualquer uso indevido do modelo. Por favor, considere as leis e regulamentações locais antes de usá-lo.",
119
+ "WebUI Host": "Host da WebUI",
120
+ "WebUI Port": "Porta da WebUI",
121
+ "Whisper Model": "Modelo Whisper",
122
+ "You can find the source code [here](https://github.com/fishaudio/fish-speech) and models [here](https://huggingface.co/fishaudio/fish-speech-1).": "Você pode encontrar o código fonte [aqui](https://github.com/fishaudio/fish-speech) e os modelos [aqui](https://huggingface.co/fishaudio/fish-speech-1).",
123
+ "auto": "automático",
124
+ "bf16-true is recommended for 30+ series GPU, 16-mixed is recommended for 10+ series GPU": "bf16-true é recomendado para GPUs da série 30+, 16-mixed é recomendado para GPUs da série 10+",
125
+ "latest": "mais recente",
126
+ "new": "novo",
127
+ "This audio introduces the basic concepts and applications of artificial intelligence and machine learning.": "Este áudio introduz os conceitos básicos e aplicações de inteligência artificial e aprendizado de máquina.",
128
+ "You don't need to train this model!": "Não é necessário treinar este modelo!",
129
+ "Yes": "Sim",
130
+ "No": "Não",
131
+ "version:": "versão:",
132
+ "author:": "autor:"
133
+ }
fish_speech/i18n/locale/zh_CN.json ADDED
@@ -0,0 +1,122 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "16-mixed is recommended for 10+ series GPU": "10+ 系列 GPU 建议使用 16-mixed",
3
+ "5 to 10 seconds of reference audio, useful for specifying speaker.": "5 到 10 秒的参考音频,适用于指定音色。",
4
+ "A text-to-speech model based on VQ-GAN and Llama developed by [Fish Audio](https://fish.audio).": "由 [Fish Audio](https://fish.audio) 研发的基于 VQ-GAN 和 Llama 的多语种语音合成.",
5
+ "Accumulate Gradient Batches": "梯度累积批次",
6
+ "Add to Processing Area": "加入处理区",
7
+ "Added path successfully!": "添加路径成功!",
8
+ "Advanced Config": "高级参数",
9
+ "Base LLAMA Model": "基础 LLAMA 模型",
10
+ "Batch Inference": "批量推理",
11
+ "Batch Size": "批次大小",
12
+ "Changing with the Model Path": "随模型路径变化",
13
+ "Chinese": "中文",
14
+ "Compile Model": "编译模型",
15
+ "Compile the model can significantly reduce the inference time, but will increase cold start time": "编译模型可以显著减少推理时间,但会增加冷启动时间",
16
+ "Copy": "复制",
17
+ "Data Preprocessing": "数据预处理",
18
+ "Data Preprocessing Path": "数据预处理路径",
19
+ "Data Source": "数据源",
20
+ "Decoder Model Config": "解码器模型配置",
21
+ "Decoder Model Path": "解码器模型路径",
22
+ "Disabled": "禁用",
23
+ "Enable Reference Audio": "启用参考音频",
24
+ "English": "英文",
25
+ "Error Message": "错误信息",
26
+ "File Preprocessing": "文件预处理",
27
+ "Generate": "生成",
28
+ "Generated Audio": "音频",
29
+ "If there is no corresponding text for the audio, apply ASR for assistance, support .txt or .lab format": "如果音频没有对应的文本,可以应用 ASR 辅助,支持 .txt 或 .lab 格式",
30
+ "Infer interface is closed": "推理界面已关闭",
31
+ "Inference Configuration": "推理配置",
32
+ "Inference Server Configuration": "推理服务器配置",
33
+ "Inference Server Error": "推理服务器错误",
34
+ "Inferring interface is launched at {}": "推理界面已在 {} 上启动",
35
+ "Initial Learning Rate": "初始学习率",
36
+ "Input Audio & Source Path for Transcription": "输入音频和转录源路径",
37
+ "Input Text": "输入文本",
38
+ "Invalid path: {}": "无效路径: {}",
39
+ "It is recommended to use CUDA, if you have low configuration, use CPU": "建议使用 CUDA,如果配置较低,使用 CPU",
40
+ "Iterative Prompt Length, 0 means off": "迭代提示长度,0 表示关闭",
41
+ "Japanese": "日文",
42
+ "LLAMA Configuration": "LLAMA 配置",
43
+ "LLAMA Model Config": "LLAMA 模型配置",
44
+ "LLAMA Model Path": "LLAMA 模型路径",
45
+ "Labeling Device": "标注加速设备",
46
+ "LoRA Model to be merged": "要合并的 LoRA 模型",
47
+ "Maximum Audio Duration": "最大音频时长",
48
+ "Maximum Length per Sample": "每个样本的最大长度",
49
+ "Maximum Training Steps": "最大训练步数",
50
+ "Maximum tokens per batch, 0 means no limit": "每批最大令牌数,0 表示无限制",
51
+ "Merge": "合并",
52
+ "Merge LoRA": "合并 LoRA",
53
+ "Merge successfully": "合并成功",
54
+ "Minimum Audio Duration": "最小音频时长",
55
+ "Model Output Path": "模型输出路径",
56
+ "Model Size": "模型规模",
57
+ "Move": "移动",
58
+ "Move files successfully": "移动文件成功",
59
+ "No audio generated, please check the input text.": "没有生成音频,请检查输入文本.",
60
+ "No selected options": "没有选择的选项",
61
+ "Number of Workers": "数据加载进程数",
62
+ "Open Inference Server": "打开推理服务器",
63
+ "Open Labeler WebUI": "打开标注工具",
64
+ "Open Tensorboard": "打开 Tensorboard",
65
+ "Opened labeler in browser": "在浏览器中打开标注工具",
66
+ "Optional Label Language": "[可选] 标注语言",
67
+ "Optional online ver": "[可选] 使用在线版",
68
+ "Output Path": "输出路径",
69
+ "Path error, please check the model file exists in the corresponding path": "路径错误,请检查模型文件是否存在于相应路径",
70
+ "Precision": "精度",
71
+ "Probability of applying Speaker Condition": "应用说话人条件的概率",
72
+ "Put your text here.": "在此处输入文本.",
73
+ "Reference Audio": "参考音频",
74
+ "Reference Text": "参考文本",
75
+ "Related code and weights are released under CC BY-NC-SA 4.0 License.": "相关代码和权重使用 CC BY-NC-SA 4.0 许可证发布.",
76
+ "Remove Selected Data": "移除选中数据",
77
+ "Removed path successfully!": "移除路径成功!",
78
+ "Repetition Penalty": "重复惩罚",
79
+ "Save model every n steps": "每 n 步保存模型",
80
+ "Select LLAMA ckpt": "选择 LLAMA 检查点",
81
+ "Select VITS ckpt": "选择 VITS 检查点",
82
+ "Select VQGAN ckpt": "选择 VQGAN 检查点",
83
+ "Select source file processing method": "选择源文件处理方法",
84
+ "Select the model to be trained (Depending on the Tab page you are on)": "根据您所在的选项卡页面选择要训练的模型",
85
+ "Selected: {}": "已选择: {}",
86
+ "Speaker": "说话人",
87
+ "Speaker is identified by the folder name": "自动根据父目录名称识别说话人",
88
+ "Start Training": "开始训练",
89
+ "Streaming Audio": "流式音频",
90
+ "Streaming Generate": "流式合成",
91
+ "Tensorboard Host": "Tensorboard 监听地址",
92
+ "Tensorboard Log Path": "Tensorboard 日志路径",
93
+ "Tensorboard Port": "Tensorboard 端口",
94
+ "Tensorboard interface is closed": "Tensorboard 界面已关闭",
95
+ "Tensorboard interface is launched at {}": "Tensorboard 界面已在 {} 上启动",
96
+ "Text is too long, please keep it under {} characters.": "文本太长,请保持在 {} 个字符以内.",
97
+ "The path of the input folder on the left or the filelist. Whether checked or not, it will be used for subsequent training in this list.": "左侧输入文件夹的路径或文件列表。无论是否选中,都将在此列表中用于后续训练.",
98
+ "Training Configuration": "训练配置",
99
+ "Training Error": "训练错误",
100
+ "Training stopped": "训练已停止",
101
+ "Type name of the speaker": "输入说话人的名称",
102
+ "Type the path or select from the dropdown": "输入路径或从下拉菜单中选择",
103
+ "Use LoRA": "使用 LoRA",
104
+ "Use LoRA can save GPU memory, but may reduce the quality of the model": "使用 LoRA 可以节省 GPU 内存,但可能会降低模型质量",
105
+ "Use filelist": "使用文件列表",
106
+ "Use large for 10G+ GPU, medium for 5G, small for 2G": "10G+ GPU 使用 large, 5G 使用 medium, 2G 使用 small",
107
+ "VITS Configuration": "VITS 配置",
108
+ "VQGAN Configuration": "VQGAN 配置",
109
+ "Validation Batch Size": "验证批次大小",
110
+ "View the status of the preprocessing folder (use the slider to control the depth of the tree)": "查看预处理文件夹的状态 (使用滑块控制树的深度)",
111
+ "We are not responsible for any misuse of the model, please consider your local laws and regulations before using it.": "我们不对模型的任何滥用负责,请在使用之前考虑您当地的法律法规.",
112
+ "WebUI Host": "WebUI 监听地址",
113
+ "WebUI Port": "WebUI 端口",
114
+ "Whisper Model": "Whisper 模型",
115
+ "You can find the source code [here](https://github.com/fishaudio/fish-speech) and models [here](https://huggingface.co/fishaudio/fish-speech-1).": "你可以在 [这里](https://github.com/fishaudio/fish-speech) 找到源代码和 [这里](https://huggingface.co/fishaudio/fish-speech-1) 找到模型.",
116
+ "bf16-true is recommended for 30+ series GPU, 16-mixed is recommended for 10+ series GPU": "30+ 系列 GPU 建议使用 bf16-true, 10+ 系列 GPU 建议使用 16-mixed",
117
+ "latest": "最近的检查点",
118
+ "new": "创建新的检查点",
119
+ "Realtime Transform Text": "实时规范化文本",
120
+ "Normalization Result Preview (Currently Only Chinese)": "规范化结果预览",
121
+ "Text Normalization": "文本规范化"
122
+ }
fish_speech/i18n/scan.py ADDED
@@ -0,0 +1,122 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import ast
2
+ import glob
3
+ import json
4
+ from collections import OrderedDict
5
+ from pathlib import Path
6
+
7
+ from loguru import logger
8
+
9
+ from .core import DEFAULT_LANGUAGE, I18N_FILE_PATH
10
+
11
+
12
+ def extract_i18n_strings(node):
13
+ i18n_strings = []
14
+
15
+ if (
16
+ isinstance(node, ast.Call)
17
+ and isinstance(node.func, ast.Name)
18
+ and node.func.id == "i18n"
19
+ ):
20
+ for arg in node.args:
21
+ if isinstance(arg, ast.Str):
22
+ i18n_strings.append(arg.s)
23
+
24
+ for child_node in ast.iter_child_nodes(node):
25
+ i18n_strings.extend(extract_i18n_strings(child_node))
26
+
27
+ return i18n_strings
28
+
29
+
30
+ # scan the directory for all .py files (recursively)
31
+ # for each file, parse the code into an AST
32
+ # for each AST, extract the i18n strings
33
+
34
+ strings = []
35
+ folders = ["fish_speech", "tools"]
36
+ # for filename in glob.iglob("**/*.py", recursive=True):
37
+ for folder in folders:
38
+ for f in Path(folder).rglob("*.py"):
39
+ code = f.read_text(encoding="utf-8")
40
+ if "i18n(" in code:
41
+ tree = ast.parse(code)
42
+ i18n_strings = extract_i18n_strings(tree)
43
+ logger.info(f"Found {len(i18n_strings)} i18n strings in {f}")
44
+ strings.extend(i18n_strings)
45
+
46
+ code_keys = set(strings)
47
+ logger.info(f"Total unique: {len(code_keys)}")
48
+
49
+
50
+ standard_file = I18N_FILE_PATH / f"{DEFAULT_LANGUAGE}.json"
51
+ with open(standard_file, "r", encoding="utf-8") as f:
52
+ standard_data = json.load(f, object_pairs_hook=OrderedDict)
53
+ standard_keys = set(standard_data.keys())
54
+
55
+ # Define the standard file name
56
+ unused_keys = standard_keys - code_keys
57
+ logger.info(f"Found {len(unused_keys)} unused keys in {standard_file}")
58
+ for unused_key in unused_keys:
59
+ logger.info(f"\t{unused_key}")
60
+
61
+ missing_keys = code_keys - standard_keys
62
+ logger.info(f"Found {len(missing_keys)} missing keys in {standard_file}")
63
+ for missing_key in missing_keys:
64
+ logger.info(f"\t{missing_key}")
65
+
66
+ code_keys_dict = OrderedDict()
67
+ for s in strings:
68
+ code_keys_dict[s] = s
69
+
70
+ # write back
71
+ with open(standard_file, "w", encoding="utf-8") as f:
72
+ json.dump(code_keys_dict, f, ensure_ascii=False, indent=4, sort_keys=True)
73
+ f.write("\n")
74
+
75
+ logger.info(f"Updated {standard_file}")
76
+
77
+
78
+ # Define the standard file name
79
+ standard_file = I18N_FILE_PATH / f"{DEFAULT_LANGUAGE}.json"
80
+
81
+ # Find all JSON files in the directory
82
+ dir_path = I18N_FILE_PATH
83
+ languages = [f for f in dir_path.glob("*.json") if f.stem != DEFAULT_LANGUAGE]
84
+
85
+ # Load the standard file
86
+ with open(standard_file, "r", encoding="utf-8") as f:
87
+ standard_data = json.load(f, object_pairs_hook=OrderedDict)
88
+
89
+ # Loop through each language file
90
+ for lang_file in languages:
91
+ # Load the language file
92
+ with open(lang_file, "r", encoding="utf-8") as f:
93
+ lang_data = json.load(f, object_pairs_hook=OrderedDict)
94
+
95
+ # Find the difference between the language file and the standard file
96
+ diff = set(standard_data.keys()) - set(lang_data.keys())
97
+
98
+ miss = set(lang_data.keys()) - set(standard_data.keys())
99
+
100
+ # Add any missing keys to the language file
101
+ for key in diff:
102
+ lang_data[key] = "#!" + key
103
+ logger.info(f"Added missing key: {key} to {lang_file}")
104
+
105
+ # Del any extra keys to the language file
106
+ for key in miss:
107
+ del lang_data[key]
108
+ logger.info(f"Del extra key: {key} from {lang_file}")
109
+
110
+ # Sort the keys of the language file to match the order of the standard file
111
+ lang_data = OrderedDict(
112
+ sorted(lang_data.items(), key=lambda x: list(standard_data.keys()).index(x[0]))
113
+ )
114
+
115
+ # Save the updated language file
116
+ with open(lang_file, "w", encoding="utf-8") as f:
117
+ json.dump(lang_data, f, ensure_ascii=False, indent=4, sort_keys=True)
118
+ f.write("\n")
119
+
120
+ logger.info(f"Updated {lang_file}")
121
+
122
+ logger.info("Done")
fish_speech/models/text2semantic/llama.py CHANGED
@@ -1,5 +1,6 @@
1
  import json
2
  import math
 
3
  from dataclasses import dataclass
4
  from pathlib import Path
5
  from typing import Optional
@@ -370,6 +371,32 @@ class BaseTransformer(nn.Module):
370
  weights = torch.load(
371
  Path(path) / "model.pth", map_location="cpu", mmap=True
372
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
373
  err = model.load_state_dict(weights, strict=False, assign=True)
374
  log.info(f"Loaded weights with error: {err}")
375
 
 
1
  import json
2
  import math
3
+ from collections import OrderedDict
4
  from dataclasses import dataclass
5
  from pathlib import Path
6
  from typing import Optional
 
371
  weights = torch.load(
372
  Path(path) / "model.pth", map_location="cpu", mmap=True
373
  )
374
+
375
+ if "state_dict" in weights:
376
+ logger.warning(
377
+ "Using a TextToSemantic LightningModule checkpoint, "
378
+ "please make sure it is a full model, not a LoRA model."
379
+ )
380
+ weights = weights["state_dict"]
381
+
382
+ if next(iter(weights.keys())).startswith("model."):
383
+ logger.info(
384
+ f"Remove prefix 'model.' created by TextToSemantic LightningModule from keys"
385
+ )
386
+ new_weights = OrderedDict()
387
+ for k, v in weights.items():
388
+ new_weights[k.replace("model.", "")] = v
389
+ weights = new_weights
390
+
391
+ # Verify the name and shape of parameters since strict=False in load_state_dict.
392
+ for k, v in model.named_parameters():
393
+ if k not in weights:
394
+ logger.warning(f"No weight for {k}")
395
+ elif v.shape != weights[k].shape:
396
+ logger.warning(
397
+ f"Shape mismatch for {k}: {v.shape} vs {weights[k].shape}"
398
+ )
399
+
400
  err = model.load_state_dict(weights, strict=False, assign=True)
401
  log.info(f"Loaded weights with error: {err}")
402
 
fish_speech/models/vqgan/__init__.py CHANGED
@@ -1,3 +0,0 @@
1
- from .lit_module import VQGAN
2
-
3
- __all__ = ["VQGAN"]
 
 
 
 
fish_speech/models/vqgan/modules/firefly.py CHANGED
@@ -1,25 +1,26 @@
1
- # A inference only version of the FireflyGAN model
2
-
3
  import math
4
  from functools import partial
5
  from math import prod
6
  from typing import Callable
7
 
8
- import numpy as np
9
  import torch
10
  import torch.nn.functional as F
11
  from torch import nn
12
- from torch.nn import Conv1d
13
  from torch.nn.utils.parametrizations import weight_norm
14
  from torch.nn.utils.parametrize import remove_parametrizations
15
  from torch.utils.checkpoint import checkpoint
16
 
17
- from fish_speech.models.vqgan.utils import sequence_mask
 
 
 
 
 
18
 
19
 
20
  def init_weights(m, mean=0.0, std=0.01):
21
  classname = m.__class__.__name__
22
- if classname.find("Conv") != -1:
23
  m.weight.data.normal_(mean, std)
24
 
25
 
@@ -27,78 +28,141 @@ def get_padding(kernel_size, dilation=1):
27
  return (kernel_size * dilation - dilation) // 2
28
 
29
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
30
  class ResBlock1(torch.nn.Module):
31
  def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5)):
32
  super().__init__()
33
 
34
  self.convs1 = nn.ModuleList(
35
  [
36
- weight_norm(
37
- Conv1d(
38
- channels,
39
- channels,
40
- kernel_size,
41
- 1,
42
- dilation=dilation[0],
43
- padding=get_padding(kernel_size, dilation[0]),
44
- )
45
- ),
46
- weight_norm(
47
- Conv1d(
48
- channels,
49
- channels,
50
- kernel_size,
51
- 1,
52
- dilation=dilation[1],
53
- padding=get_padding(kernel_size, dilation[1]),
54
- )
55
- ),
56
- weight_norm(
57
- Conv1d(
58
- channels,
59
- channels,
60
- kernel_size,
61
- 1,
62
- dilation=dilation[2],
63
- padding=get_padding(kernel_size, dilation[2]),
64
- )
65
- ),
66
  ]
67
  )
68
  self.convs1.apply(init_weights)
69
 
70
  self.convs2 = nn.ModuleList(
71
  [
72
- weight_norm(
73
- Conv1d(
74
- channels,
75
- channels,
76
- kernel_size,
77
- 1,
78
- dilation=1,
79
- padding=get_padding(kernel_size, 1),
80
- )
81
- ),
82
- weight_norm(
83
- Conv1d(
84
- channels,
85
- channels,
86
- kernel_size,
87
- 1,
88
- dilation=1,
89
- padding=get_padding(kernel_size, 1),
90
- )
91
- ),
92
- weight_norm(
93
- Conv1d(
94
- channels,
95
- channels,
96
- kernel_size,
97
- 1,
98
- dilation=1,
99
- padding=get_padding(kernel_size, 1),
100
- )
101
- ),
102
  ]
103
  )
104
  self.convs2.apply(init_weights)
@@ -119,7 +183,7 @@ class ResBlock1(torch.nn.Module):
119
  remove_parametrizations(conv, tensor_name="weight")
120
 
121
 
122
- class ParralelBlock(nn.Module):
123
  def __init__(
124
  self,
125
  channels: int,
@@ -153,7 +217,6 @@ class HiFiGANGenerator(nn.Module):
153
  resblock_dilation_sizes: tuple[tuple[int]] = ((1, 3, 5), (1, 3, 5), (1, 3, 5)),
154
  num_mels: int = 128,
155
  upsample_initial_channel: int = 512,
156
- use_template: bool = True,
157
  pre_conv_kernel_size: int = 7,
158
  post_conv_kernel_size: int = 7,
159
  post_activation: Callable = partial(nn.SiLU, inplace=True),
@@ -164,84 +227,50 @@ class HiFiGANGenerator(nn.Module):
164
  prod(upsample_rates) == hop_length
165
  ), f"hop_length must be {prod(upsample_rates)}"
166
 
167
- self.conv_pre = weight_norm(
168
- nn.Conv1d(
169
- num_mels,
170
- upsample_initial_channel,
171
- pre_conv_kernel_size,
172
- 1,
173
- padding=get_padding(pre_conv_kernel_size),
174
- )
175
- )
176
 
177
  self.num_upsamples = len(upsample_rates)
178
  self.num_kernels = len(resblock_kernel_sizes)
179
 
180
  self.noise_convs = nn.ModuleList()
181
- self.use_template = use_template
182
  self.ups = nn.ModuleList()
183
 
184
  for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
185
- c_cur = upsample_initial_channel // (2 ** (i + 1))
186
  self.ups.append(
187
- weight_norm(
188
- nn.ConvTranspose1d(
189
- upsample_initial_channel // (2**i),
190
- upsample_initial_channel // (2 ** (i + 1)),
191
- k,
192
- u,
193
- padding=(k - u) // 2,
194
- )
195
- )
196
  )
197
 
198
- if not use_template:
199
- continue
200
-
201
- if i + 1 < len(upsample_rates):
202
- stride_f0 = np.prod(upsample_rates[i + 1 :])
203
- self.noise_convs.append(
204
- Conv1d(
205
- 1,
206
- c_cur,
207
- kernel_size=stride_f0 * 2,
208
- stride=stride_f0,
209
- padding=stride_f0 // 2,
210
- )
211
- )
212
- else:
213
- self.noise_convs.append(Conv1d(1, c_cur, kernel_size=1))
214
-
215
  self.resblocks = nn.ModuleList()
216
  for i in range(len(self.ups)):
217
  ch = upsample_initial_channel // (2 ** (i + 1))
218
  self.resblocks.append(
219
- ParralelBlock(ch, resblock_kernel_sizes, resblock_dilation_sizes)
220
  )
221
 
222
  self.activation_post = post_activation()
223
- self.conv_post = weight_norm(
224
- nn.Conv1d(
225
- ch,
226
- 1,
227
- post_conv_kernel_size,
228
- 1,
229
- padding=get_padding(post_conv_kernel_size),
230
- )
231
- )
232
  self.ups.apply(init_weights)
233
  self.conv_post.apply(init_weights)
234
 
235
- def forward(self, x, template=None):
236
  x = self.conv_pre(x)
237
 
238
  for i in range(self.num_upsamples):
239
  x = F.silu(x, inplace=True)
240
  x = self.ups[i](x)
241
 
242
- if self.use_template:
243
- x = x + self.noise_convs[i](template)
244
-
245
  if self.training and self.checkpointing:
246
  x = checkpoint(
247
  self.resblocks[i],
@@ -364,11 +393,11 @@ class ConvNeXtBlock(nn.Module):
364
  ):
365
  super().__init__()
366
 
367
- self.dwconv = nn.Conv1d(
368
  dim,
369
  dim,
370
  kernel_size=kernel_size,
371
- padding=int(dilation * (kernel_size - 1) / 2),
372
  groups=dim,
373
  ) # depthwise conv
374
  self.norm = LayerNorm(dim, eps=1e-6)
@@ -421,12 +450,13 @@ class ConvNeXtEncoder(nn.Module):
421
 
422
  self.downsample_layers = nn.ModuleList()
423
  stem = nn.Sequential(
424
- nn.Conv1d(
425
  input_channels,
426
  dims[0],
427
- kernel_size=kernel_size,
428
- padding=kernel_size // 2,
429
- padding_mode="zeros",
 
430
  ),
431
  LayerNorm(dims[0], eps=1e-6, data_format="channels_first"),
432
  )
@@ -491,6 +521,7 @@ class FireflyArchitecture(nn.Module):
491
  self.head = head
492
  self.quantizer = quantizer
493
  self.spec_transform = spec_transform
 
494
 
495
  def forward(self, x: torch.Tensor, template=None, mask=None) -> torch.Tensor:
496
  if self.spec_transform is not None:
@@ -528,25 +559,30 @@ class FireflyArchitecture(nn.Module):
528
 
529
  # Encode
530
  encoded_features = self.backbone(mels) * mel_masks_float_conv
531
- feature_lengths = mel_lengths // math.prod(self.quantizer.downsample_factor)
532
 
533
  return self.quantizer.encode(encoded_features), feature_lengths
534
 
535
  def decode(self, indices, feature_lengths) -> torch.Tensor:
536
- factor = math.prod(self.quantizer.downsample_factor)
537
- mel_masks = sequence_mask(feature_lengths * factor, indices.shape[2] * factor)
 
 
538
  mel_masks_float_conv = mel_masks[:, None, :].float()
 
 
 
539
 
540
  audio_masks = sequence_mask(
541
- feature_lengths * factor * self.spec_transform.hop_length,
542
- indices.shape[2] * factor * self.spec_transform.hop_length,
543
  )
544
  audio_masks_float_conv = audio_masks[:, None, :].float()
545
 
546
  z = self.quantizer.decode(indices) * mel_masks_float_conv
547
  x = self.head(z) * audio_masks_float_conv
548
 
549
- return x
550
 
551
  def remove_parametrizations(self):
552
  if hasattr(self.backbone, "remove_parametrizations"):
@@ -558,68 +594,3 @@ class FireflyArchitecture(nn.Module):
558
  @property
559
  def device(self):
560
  return next(self.parameters()).device
561
-
562
-
563
- class FireflyBase(nn.Module):
564
- def __init__(self, ckpt_path: str = None, pretrained: bool = True):
565
- super().__init__()
566
-
567
- self.backbone = ConvNeXtEncoder(
568
- input_channels=128,
569
- depths=[3, 3, 9, 3],
570
- dims=[128, 256, 384, 512],
571
- drop_path_rate=0.2,
572
- kernel_size=7,
573
- )
574
-
575
- self.head = HiFiGANGenerator(
576
- hop_length=512,
577
- upsample_rates=[8, 8, 2, 2, 2],
578
- upsample_kernel_sizes=[16, 16, 4, 4, 4],
579
- resblock_kernel_sizes=[3, 7, 11],
580
- resblock_dilation_sizes=[[1, 3, 5], [1, 3, 5], [1, 3, 5]],
581
- num_mels=512,
582
- upsample_initial_channel=512,
583
- use_template=False,
584
- pre_conv_kernel_size=13,
585
- post_conv_kernel_size=13,
586
- )
587
-
588
- if ckpt_path is not None:
589
- state_dict = torch.load(ckpt_path, map_location="cpu")
590
- elif pretrained:
591
- state_dict = torch.hub.load_state_dict_from_url(
592
- "https://github.com/fishaudio/vocoder/releases/download/1.0.0/firefly-gan-base-generator.ckpt",
593
- map_location="cpu",
594
- model_dir="checkpoints",
595
- )
596
-
597
- if "state_dict" in state_dict:
598
- state_dict = state_dict["state_dict"]
599
-
600
- if any("generator." in k for k in state_dict):
601
- state_dict = {
602
- k.replace("generator.", ""): v
603
- for k, v in state_dict.items()
604
- if "generator." in k
605
- }
606
-
607
- self.load_state_dict(state_dict, strict=True)
608
- self.head.remove_parametrizations()
609
-
610
- @torch.no_grad()
611
- def forward(self, x: torch.Tensor) -> torch.Tensor:
612
- x = self.backbone(x)
613
- x = self.head(x)
614
- if x.ndim == 2:
615
- x = x[:, None, :]
616
- return x
617
-
618
-
619
- if __name__ == "__main__":
620
- model = FireflyBase()
621
- model.eval()
622
- x = torch.randn(1, 128, 128)
623
- with torch.no_grad():
624
- y = model(x)
625
- print(y.shape)
 
 
 
1
  import math
2
  from functools import partial
3
  from math import prod
4
  from typing import Callable
5
 
 
6
  import torch
7
  import torch.nn.functional as F
8
  from torch import nn
 
9
  from torch.nn.utils.parametrizations import weight_norm
10
  from torch.nn.utils.parametrize import remove_parametrizations
11
  from torch.utils.checkpoint import checkpoint
12
 
13
+
14
+ def sequence_mask(length, max_length=None):
15
+ if max_length is None:
16
+ max_length = length.max()
17
+ x = torch.arange(max_length, dtype=length.dtype, device=length.device)
18
+ return x.unsqueeze(0) < length.unsqueeze(1)
19
 
20
 
21
  def init_weights(m, mean=0.0, std=0.01):
22
  classname = m.__class__.__name__
23
+ if classname.find("Conv1D") != -1:
24
  m.weight.data.normal_(mean, std)
25
 
26
 
 
28
  return (kernel_size * dilation - dilation) // 2
29
 
30
 
31
+ def unpad1d(x: torch.Tensor, paddings: tuple[int, int]):
32
+ """Remove padding from x, handling properly zero padding. Only for 1d!"""
33
+ padding_left, padding_right = paddings
34
+ assert padding_left >= 0 and padding_right >= 0, (padding_left, padding_right)
35
+ assert (padding_left + padding_right) <= x.shape[-1]
36
+ end = x.shape[-1] - padding_right
37
+ return x[..., padding_left:end]
38
+
39
+
40
+ def get_extra_padding_for_conv1d(
41
+ x: torch.Tensor, kernel_size: int, stride: int, padding_total: int = 0
42
+ ) -> int:
43
+ """See `pad_for_conv1d`."""
44
+ length = x.shape[-1]
45
+ n_frames = (length - kernel_size + padding_total) / stride + 1
46
+ ideal_length = (math.ceil(n_frames) - 1) * stride + (kernel_size - padding_total)
47
+ return ideal_length - length
48
+
49
+
50
+ def pad1d(
51
+ x: torch.Tensor,
52
+ paddings: tuple[int, int],
53
+ mode: str = "zeros",
54
+ value: float = 0.0,
55
+ ):
56
+ """Tiny wrapper around F.pad, just to allow for reflect padding on small input.
57
+ If this is the case, we insert extra 0 padding to the right
58
+ before the reflection happen.
59
+ """
60
+ length = x.shape[-1]
61
+ padding_left, padding_right = paddings
62
+ assert padding_left >= 0 and padding_right >= 0, (padding_left, padding_right)
63
+ if mode == "reflect":
64
+ max_pad = max(padding_left, padding_right)
65
+ extra_pad = 0
66
+ if length <= max_pad:
67
+ extra_pad = max_pad - length + 1
68
+ x = F.pad(x, (0, extra_pad))
69
+ padded = F.pad(x, paddings, mode, value)
70
+ end = padded.shape[-1] - extra_pad
71
+ return padded[..., :end]
72
+ else:
73
+ return F.pad(x, paddings, mode, value)
74
+
75
+
76
+ class FishConvNet(nn.Module):
77
+ def __init__(
78
+ self, in_channels, out_channels, kernel_size, dilation=1, stride=1, groups=1
79
+ ):
80
+ super(FishConvNet, self).__init__()
81
+ self.conv = nn.Conv1d(
82
+ in_channels,
83
+ out_channels,
84
+ kernel_size,
85
+ stride=stride,
86
+ dilation=dilation,
87
+ groups=groups,
88
+ )
89
+ self.stride = stride
90
+ self.kernel_size = (kernel_size - 1) * dilation + 1
91
+ self.dilation = dilation
92
+
93
+ def forward(self, x):
94
+ pad = self.kernel_size - self.stride
95
+ extra_padding = get_extra_padding_for_conv1d(
96
+ x, self.kernel_size, self.stride, pad
97
+ )
98
+ x = pad1d(x, (pad, extra_padding), mode="constant", value=0)
99
+ return self.conv(x).contiguous()
100
+
101
+ def weight_norm(self, name="weight", dim=0):
102
+ self.conv = weight_norm(self.conv, name=name, dim=dim)
103
+ return self
104
+
105
+ def remove_weight_norm(self):
106
+ self.conv = remove_parametrizations(self.conv)
107
+ return self
108
+
109
+
110
+ class FishTransConvNet(nn.Module):
111
+ def __init__(self, in_channels, out_channels, kernel_size, dilation=1, stride=1):
112
+ super(FishTransConvNet, self).__init__()
113
+ self.conv = nn.ConvTranspose1d(
114
+ in_channels, out_channels, kernel_size, stride=stride, dilation=dilation
115
+ )
116
+ self.stride = stride
117
+ self.kernel_size = kernel_size
118
+
119
+ def forward(self, x):
120
+ x = self.conv(x)
121
+ pad = self.kernel_size - self.stride
122
+ padding_right = math.ceil(pad)
123
+ padding_left = pad - padding_right
124
+ x = unpad1d(x, (padding_left, padding_right))
125
+ return x.contiguous()
126
+
127
+ def weight_norm(self, name="weight", dim=0):
128
+ self.conv = weight_norm(self.conv, name=name, dim=dim)
129
+ return self
130
+
131
+ def remove_weight_norm(self):
132
+ self.conv = remove_parametrizations(self.conv)
133
+ return self
134
+
135
+
136
  class ResBlock1(torch.nn.Module):
137
  def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5)):
138
  super().__init__()
139
 
140
  self.convs1 = nn.ModuleList(
141
  [
142
+ FishConvNet(
143
+ channels, channels, kernel_size, stride=1, dilation=dilation[0]
144
+ ).weight_norm(),
145
+ FishConvNet(
146
+ channels, channels, kernel_size, stride=1, dilation=dilation[1]
147
+ ).weight_norm(),
148
+ FishConvNet(
149
+ channels, channels, kernel_size, stride=1, dilation=dilation[2]
150
+ ).weight_norm(),
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
151
  ]
152
  )
153
  self.convs1.apply(init_weights)
154
 
155
  self.convs2 = nn.ModuleList(
156
  [
157
+ FishConvNet(
158
+ channels, channels, kernel_size, stride=1, dilation=dilation[0]
159
+ ).weight_norm(),
160
+ FishConvNet(
161
+ channels, channels, kernel_size, stride=1, dilation=dilation[1]
162
+ ).weight_norm(),
163
+ FishConvNet(
164
+ channels, channels, kernel_size, stride=1, dilation=dilation[2]
165
+ ).weight_norm(),
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
166
  ]
167
  )
168
  self.convs2.apply(init_weights)
 
183
  remove_parametrizations(conv, tensor_name="weight")
184
 
185
 
186
+ class ParallelBlock(nn.Module):
187
  def __init__(
188
  self,
189
  channels: int,
 
217
  resblock_dilation_sizes: tuple[tuple[int]] = ((1, 3, 5), (1, 3, 5), (1, 3, 5)),
218
  num_mels: int = 128,
219
  upsample_initial_channel: int = 512,
 
220
  pre_conv_kernel_size: int = 7,
221
  post_conv_kernel_size: int = 7,
222
  post_activation: Callable = partial(nn.SiLU, inplace=True),
 
227
  prod(upsample_rates) == hop_length
228
  ), f"hop_length must be {prod(upsample_rates)}"
229
 
230
+ self.conv_pre = FishConvNet(
231
+ num_mels,
232
+ upsample_initial_channel,
233
+ pre_conv_kernel_size,
234
+ stride=1,
235
+ ).weight_norm()
 
 
 
236
 
237
  self.num_upsamples = len(upsample_rates)
238
  self.num_kernels = len(resblock_kernel_sizes)
239
 
240
  self.noise_convs = nn.ModuleList()
 
241
  self.ups = nn.ModuleList()
242
 
243
  for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
 
244
  self.ups.append(
245
+ FishTransConvNet(
246
+ upsample_initial_channel // (2**i),
247
+ upsample_initial_channel // (2 ** (i + 1)),
248
+ k,
249
+ stride=u,
250
+ ).weight_norm()
 
 
 
251
  )
252
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
253
  self.resblocks = nn.ModuleList()
254
  for i in range(len(self.ups)):
255
  ch = upsample_initial_channel // (2 ** (i + 1))
256
  self.resblocks.append(
257
+ ParallelBlock(ch, resblock_kernel_sizes, resblock_dilation_sizes)
258
  )
259
 
260
  self.activation_post = post_activation()
261
+ self.conv_post = FishConvNet(
262
+ ch, 1, post_conv_kernel_size, stride=1
263
+ ).weight_norm()
 
 
 
 
 
 
264
  self.ups.apply(init_weights)
265
  self.conv_post.apply(init_weights)
266
 
267
+ def forward(self, x):
268
  x = self.conv_pre(x)
269
 
270
  for i in range(self.num_upsamples):
271
  x = F.silu(x, inplace=True)
272
  x = self.ups[i](x)
273
 
 
 
 
274
  if self.training and self.checkpointing:
275
  x = checkpoint(
276
  self.resblocks[i],
 
393
  ):
394
  super().__init__()
395
 
396
+ self.dwconv = FishConvNet(
397
  dim,
398
  dim,
399
  kernel_size=kernel_size,
400
+ # padding=int(dilation * (kernel_size - 1) / 2),
401
  groups=dim,
402
  ) # depthwise conv
403
  self.norm = LayerNorm(dim, eps=1e-6)
 
450
 
451
  self.downsample_layers = nn.ModuleList()
452
  stem = nn.Sequential(
453
+ FishConvNet(
454
  input_channels,
455
  dims[0],
456
+ kernel_size=7,
457
+ # padding=3,
458
+ # padding_mode="replicate",
459
+ # padding_mode="zeros",
460
  ),
461
  LayerNorm(dims[0], eps=1e-6, data_format="channels_first"),
462
  )
 
521
  self.head = head
522
  self.quantizer = quantizer
523
  self.spec_transform = spec_transform
524
+ self.downsample_factor = math.prod(self.quantizer.downsample_factor)
525
 
526
  def forward(self, x: torch.Tensor, template=None, mask=None) -> torch.Tensor:
527
  if self.spec_transform is not None:
 
559
 
560
  # Encode
561
  encoded_features = self.backbone(mels) * mel_masks_float_conv
562
+ feature_lengths = mel_lengths // self.downsample_factor
563
 
564
  return self.quantizer.encode(encoded_features), feature_lengths
565
 
566
  def decode(self, indices, feature_lengths) -> torch.Tensor:
567
+ mel_masks = sequence_mask(
568
+ feature_lengths * self.downsample_factor,
569
+ indices.shape[2] * self.downsample_factor,
570
+ )
571
  mel_masks_float_conv = mel_masks[:, None, :].float()
572
+ audio_lengths = (
573
+ feature_lengths * self.downsample_factor * self.spec_transform.hop_length
574
+ )
575
 
576
  audio_masks = sequence_mask(
577
+ audio_lengths,
578
+ indices.shape[2] * self.downsample_factor * self.spec_transform.hop_length,
579
  )
580
  audio_masks_float_conv = audio_masks[:, None, :].float()
581
 
582
  z = self.quantizer.decode(indices) * mel_masks_float_conv
583
  x = self.head(z) * audio_masks_float_conv
584
 
585
+ return x, audio_lengths
586
 
587
  def remove_parametrizations(self):
588
  if hasattr(self.backbone, "remove_parametrizations"):
 
594
  @property
595
  def device(self):
596
  return next(self.parameters()).device
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
fish_speech/models/vqgan/modules/fsq.py CHANGED
@@ -6,7 +6,7 @@ import torch.nn.functional as F
6
  from einops import rearrange
7
  from vector_quantize_pytorch import GroupedResidualFSQ
8
 
9
- from .firefly import ConvNeXtBlock
10
 
11
 
12
  @dataclass
@@ -20,7 +20,7 @@ class DownsampleFiniteScalarQuantize(nn.Module):
20
  def __init__(
21
  self,
22
  input_dim: int = 512,
23
- n_codebooks: int = 1,
24
  n_groups: int = 1,
25
  levels: tuple[int] = (8, 5, 5, 5), # Approximate 2**10
26
  downsample_factor: tuple[int] = (2, 2),
@@ -46,7 +46,7 @@ class DownsampleFiniteScalarQuantize(nn.Module):
46
  self.downsample = nn.Sequential(
47
  *[
48
  nn.Sequential(
49
- nn.Conv1d(
50
  all_dims[idx],
51
  all_dims[idx + 1],
52
  kernel_size=factor,
@@ -61,7 +61,7 @@ class DownsampleFiniteScalarQuantize(nn.Module):
61
  self.upsample = nn.Sequential(
62
  *[
63
  nn.Sequential(
64
- nn.ConvTranspose1d(
65
  all_dims[idx + 1],
66
  all_dims[idx],
67
  kernel_size=factor,
@@ -114,26 +114,3 @@ class DownsampleFiniteScalarQuantize(nn.Module):
114
  z_q = self.residual_fsq.get_output_from_indices(indices)
115
  z_q = self.upsample(z_q.mT)
116
  return z_q
117
-
118
- # def from_latents(self, latents: torch.Tensor):
119
- # z_q, z_p, codes = super().from_latents(latents)
120
- # z_q = self.upsample(z_q)
121
- # return z_q, z_p, codes
122
-
123
-
124
- if __name__ == "__main__":
125
- rvq = DownsampleFiniteScalarQuantize(
126
- n_codebooks=1,
127
- downsample_factor=(2, 2),
128
- )
129
- x = torch.randn(16, 512, 80)
130
-
131
- result = rvq(x)
132
- print(rvq)
133
- print(result.latents.shape, result.codes.shape, result.z.shape)
134
-
135
- # y = rvq.from_codes(result.codes)
136
- # print(y[0].shape)
137
-
138
- # y = rvq.from_latents(result.latents)
139
- # print(y[0].shape)
 
6
  from einops import rearrange
7
  from vector_quantize_pytorch import GroupedResidualFSQ
8
 
9
+ from .firefly import ConvNeXtBlock, FishConvNet, FishTransConvNet
10
 
11
 
12
  @dataclass
 
20
  def __init__(
21
  self,
22
  input_dim: int = 512,
23
+ n_codebooks: int = 9,
24
  n_groups: int = 1,
25
  levels: tuple[int] = (8, 5, 5, 5), # Approximate 2**10
26
  downsample_factor: tuple[int] = (2, 2),
 
46
  self.downsample = nn.Sequential(
47
  *[
48
  nn.Sequential(
49
+ FishConvNet(
50
  all_dims[idx],
51
  all_dims[idx + 1],
52
  kernel_size=factor,
 
61
  self.upsample = nn.Sequential(
62
  *[
63
  nn.Sequential(
64
+ FishTransConvNet(
65
  all_dims[idx + 1],
66
  all_dims[idx],
67
  kernel_size=factor,
 
114
  z_q = self.residual_fsq.get_output_from_indices(indices)
115
  z_q = self.upsample(z_q.mT)
116
  return z_q
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
fish_speech/scheduler.py CHANGED
@@ -4,11 +4,14 @@ import math
4
  def get_cosine_schedule_with_warmup_lr_lambda(
5
  current_step: int,
6
  *,
7
- num_warmup_steps: int,
8
  num_training_steps: int,
9
  num_cycles: float = 0.5,
10
  final_lr_ratio: float = 0.0,
11
  ):
 
 
 
12
  if current_step < num_warmup_steps:
13
  return float(current_step) / float(max(1, num_warmup_steps))
14
 
@@ -20,3 +23,18 @@ def get_cosine_schedule_with_warmup_lr_lambda(
20
  final_lr_ratio,
21
  0.5 * (1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress)),
22
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4
  def get_cosine_schedule_with_warmup_lr_lambda(
5
  current_step: int,
6
  *,
7
+ num_warmup_steps: int | float,
8
  num_training_steps: int,
9
  num_cycles: float = 0.5,
10
  final_lr_ratio: float = 0.0,
11
  ):
12
+ if 0 < num_warmup_steps < 1: # float mode
13
+ num_warmup_steps = int(num_warmup_steps * num_training_steps)
14
+
15
  if current_step < num_warmup_steps:
16
  return float(current_step) / float(max(1, num_warmup_steps))
17
 
 
23
  final_lr_ratio,
24
  0.5 * (1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress)),
25
  )
26
+
27
+
28
+ def get_constant_schedule_with_warmup_lr_lambda(
29
+ current_step: int,
30
+ *,
31
+ num_warmup_steps: int | float,
32
+ num_training_steps: int | None = None,
33
+ ):
34
+ if 0 < num_warmup_steps < 1: # float mode
35
+ num_warmup_steps = int(num_warmup_steps * num_training_steps)
36
+
37
+ if current_step < num_warmup_steps:
38
+ return float(current_step) / float(max(1, num_warmup_steps))
39
+
40
+ return 1.0
fish_speech/text/clean.py CHANGED
@@ -64,6 +64,6 @@ def clean_text(text):
64
 
65
  # Replace all chinese symbols with their english counterparts
66
  text = REPLACE_SYMBOL_REGEX.sub(lambda x: SYMBOLS_MAPPING[x.group()], text)
67
- text = REMOVE_UNKNOWN_SYMBOL_REGEX.sub("", text)
68
 
69
  return text
 
64
 
65
  # Replace all chinese symbols with their english counterparts
66
  text = REPLACE_SYMBOL_REGEX.sub(lambda x: SYMBOLS_MAPPING[x.group()], text)
67
+ # text = REMOVE_UNKNOWN_SYMBOL_REGEX.sub("", text)
68
 
69
  return text
fish_speech/train.py CHANGED
@@ -1,4 +1,5 @@
1
  import os
 
2
  from typing import Optional
3
 
4
  import hydra
@@ -7,6 +8,7 @@ import pyrootutils
7
  import torch
8
  from lightning import Callback, LightningDataModule, LightningModule, Trainer
9
  from lightning.pytorch.loggers import Logger
 
10
  from omegaconf import DictConfig, OmegaConf
11
 
12
  os.environ.pop("SLURM_NTASKS", None)
@@ -61,7 +63,9 @@ def train(cfg: DictConfig) -> tuple[dict, dict]:
61
 
62
  log.info(f"Instantiating trainer <{cfg.trainer._target_}>")
63
  trainer: Trainer = hydra.utils.instantiate(
64
- cfg.trainer, callbacks=callbacks, logger=logger
 
 
65
  )
66
 
67
  object_dict = {
 
1
  import os
2
+ import sys
3
  from typing import Optional
4
 
5
  import hydra
 
8
  import torch
9
  from lightning import Callback, LightningDataModule, LightningModule, Trainer
10
  from lightning.pytorch.loggers import Logger
11
+ from lightning.pytorch.strategies import DDPStrategy
12
  from omegaconf import DictConfig, OmegaConf
13
 
14
  os.environ.pop("SLURM_NTASKS", None)
 
63
 
64
  log.info(f"Instantiating trainer <{cfg.trainer._target_}>")
65
  trainer: Trainer = hydra.utils.instantiate(
66
+ cfg.trainer,
67
+ callbacks=callbacks,
68
+ logger=logger,
69
  )
70
 
71
  object_dict = {
fish_speech/utils/__init__.py CHANGED
@@ -1,4 +1,5 @@
1
  from .braceexpand import braceexpand
 
2
  from .file import get_latest_checkpoint
3
  from .instantiators import instantiate_callbacks, instantiate_loggers
4
  from .logger import RankedLogger
@@ -18,4 +19,5 @@ __all__ = [
18
  "task_wrapper",
19
  "braceexpand",
20
  "get_latest_checkpoint",
 
21
  ]
 
1
  from .braceexpand import braceexpand
2
+ from .context import autocast_exclude_mps
3
  from .file import get_latest_checkpoint
4
  from .instantiators import instantiate_callbacks, instantiate_loggers
5
  from .logger import RankedLogger
 
19
  "task_wrapper",
20
  "braceexpand",
21
  "get_latest_checkpoint",
22
+ "autocast_exclude_mps",
23
  ]
fish_speech/utils/context.py ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from contextlib import nullcontext
2
+
3
+ import torch
4
+
5
+
6
+ def autocast_exclude_mps(
7
+ device_type: str, dtype: torch.dtype
8
+ ) -> nullcontext | torch.autocast:
9
+ return (
10
+ nullcontext()
11
+ if torch.backends.mps.is_available()
12
+ else torch.autocast(device_type, dtype)
13
+ )
fish_speech/utils/file.py CHANGED
@@ -1,55 +1,5 @@
1
  import os
2
- from glob import glob
3
  from pathlib import Path
4
- from typing import Union
5
-
6
- from loguru import logger
7
- from natsort import natsorted
8
-
9
- AUDIO_EXTENSIONS = {
10
- ".mp3",
11
- ".wav",
12
- ".flac",
13
- ".ogg",
14
- ".m4a",
15
- ".wma",
16
- ".aac",
17
- ".aiff",
18
- ".aif",
19
- ".aifc",
20
- }
21
-
22
-
23
- def list_files(
24
- path: Union[Path, str],
25
- extensions: set[str] = None,
26
- recursive: bool = False,
27
- sort: bool = True,
28
- ) -> list[Path]:
29
- """List files in a directory.
30
-
31
- Args:
32
- path (Path): Path to the directory.
33
- extensions (set, optional): Extensions to filter. Defaults to None.
34
- recursive (bool, optional): Whether to search recursively. Defaults to False.
35
- sort (bool, optional): Whether to sort the files. Defaults to True.
36
-
37
- Returns:
38
- list: List of files.
39
- """
40
-
41
- if isinstance(path, str):
42
- path = Path(path)
43
-
44
- if not path.exists():
45
- raise FileNotFoundError(f"Directory {path} does not exist.")
46
-
47
- files = [file for ext in extensions for file in path.rglob(f"*{ext}")]
48
-
49
- if sort:
50
- files = natsorted(files)
51
-
52
- return files
53
 
54
 
55
  def get_latest_checkpoint(path: Path | str) -> Path | None:
@@ -64,56 +14,3 @@ def get_latest_checkpoint(path: Path | str) -> Path | None:
64
  return None
65
 
66
  return ckpts[-1]
67
-
68
-
69
- def load_filelist(path: Path | str) -> list[tuple[Path, str, str, str]]:
70
- """
71
- Load a Bert-VITS2 style filelist.
72
- """
73
-
74
- files = set()
75
- results = []
76
- count_duplicated, count_not_found = 0, 0
77
-
78
- LANGUAGE_TO_LANGUAGES = {
79
- "zh": ["zh", "en"],
80
- "jp": ["jp", "en"],
81
- "en": ["en"],
82
- }
83
-
84
- with open(path, "r", encoding="utf-8") as f:
85
- for line in f.readlines():
86
- splits = line.strip().split("|", maxsplit=3)
87
- if len(splits) != 4:
88
- logger.warning(f"Invalid line: {line}")
89
- continue
90
-
91
- filename, speaker, language, text = splits
92
- file = Path(filename)
93
- language = language.strip().lower()
94
-
95
- if language == "ja":
96
- language = "jp"
97
-
98
- assert language in ["zh", "jp", "en"], f"Invalid language {language}"
99
- languages = LANGUAGE_TO_LANGUAGES[language]
100
-
101
- if file in files:
102
- logger.warning(f"Duplicated file: {file}")
103
- count_duplicated += 1
104
- continue
105
-
106
- if not file.exists():
107
- logger.warning(f"File not found: {file}")
108
- count_not_found += 1
109
- continue
110
-
111
- results.append((file, speaker, languages, text))
112
-
113
- if count_duplicated > 0:
114
- logger.warning(f"Total duplicated files: {count_duplicated}")
115
-
116
- if count_not_found > 0:
117
- logger.warning(f"Total files not found: {count_not_found}")
118
-
119
- return results
 
1
  import os
 
2
  from pathlib import Path
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3
 
4
 
5
  def get_latest_checkpoint(path: Path | str) -> Path | None:
 
14
  return None
15
 
16
  return ckpts[-1]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
fish_speech/webui/css/style.css ADDED
@@ -0,0 +1,161 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ :root {
2
+ --my-200: #80eeee;
3
+ --my-50: #ecfdf5;
4
+ --water-width: 300px;
5
+ --water-heigh: 300px;
6
+ }
7
+
8
+
9
+ /* general styled components */
10
+ .tools {
11
+ align-items: center;
12
+ justify-content: center;
13
+ }
14
+
15
+ .gradio-button {
16
+ max-width: 2.2em;
17
+ min-width: 2.2em !important;
18
+ height: 2.4em;
19
+ align-self: end;
20
+ line-height: 1em;
21
+ border-radius: 0.5em;
22
+
23
+ }
24
+
25
+ .gradio-button.secondary-down, .gradio-button.secondary-down:hover{
26
+ box-shadow: 1px 1px 1px rgba(0,0,0,0.25) inset, 0px 0px 3px rgba(0,0,0,0.15) inset;
27
+ }
28
+
29
+ /* replace original footer with ours */
30
+ a{
31
+ font-weight: bold;
32
+ cursor: pointer;
33
+ color: #030C14 !important;
34
+ }
35
+
36
+ footer {
37
+ display: none !important;
38
+ }
39
+
40
+ #footer{
41
+ text-align: center;
42
+ }
43
+
44
+ #footer div{
45
+ display: inline-block;
46
+ }
47
+
48
+ #footer .versions{
49
+ font-size: 85%;
50
+ opacity: 0.85;
51
+ }
52
+
53
+ /*@keyframes moveBackground {*/
54
+ /* 0% {*/
55
+ /* background-position: 0 0;*/
56
+ /* }*/
57
+ /* 100% {*/
58
+ /* background-position: -100px 100px;*/
59
+ /* }*/
60
+ /*}*/
61
+ @keyframes moveJellyBackground {
62
+ 0% {
63
+ background-position: 0% 50%;
64
+ }
65
+ 50% {
66
+ background-position: 100% 50%;
67
+ }
68
+ 100% {
69
+ background-position: 0% 50%;
70
+ }
71
+ }
72
+
73
+ .gradio-container {
74
+ position: absolute;
75
+ z-index: 10;
76
+ }
77
+
78
+
79
+ .quan {
80
+ position: absolute;
81
+ bottom: 0;
82
+ width: var(--water-width);
83
+ height: var(--water-heigh);
84
+ border-radius: 0;
85
+ /*border: 3px solid rgb(246, 247, 248);*/
86
+ /*box-shadow: 0 0 0 3px rgb(41, 134, 196);*/
87
+ z-index: 0;
88
+
89
+ }
90
+
91
+ .quan:last-child {
92
+ margin-right: 0;
93
+ }
94
+
95
+ .shui {
96
+ position: absolute;
97
+ top: 0;
98
+ left: 0;
99
+ width: 100%;
100
+ height: 100%;
101
+ background-color: rgb(23, 106, 201);
102
+ border-radius: 0;
103
+ overflow: hidden;
104
+ z-index: 0;
105
+ }
106
+
107
+ .shui::after {
108
+
109
+ content: '';
110
+ position: absolute;
111
+ top: 20%;
112
+ left: 50%;
113
+ width: 150%;
114
+ height: 150%;
115
+ border-radius: 40%;
116
+ background-image: radial-gradient(circle at 0% 50%, #dcfcf1, var(--my-50) 50%);
117
+ animation: shi 5s linear infinite;
118
+ }
119
+
120
+ @keyframes shi {
121
+ 0% {
122
+ transform: translate(-50%, -65%) rotate(0deg);
123
+ }
124
+ 100% {
125
+ transform: translate(-50%, -65%) rotate(360deg);
126
+ }
127
+ }
128
+
129
+ .shui::before {
130
+ content: '';
131
+ position: absolute;
132
+ top: 20%;
133
+ left: 50%;
134
+ width: 150%;
135
+ height: 150%;
136
+ border-radius: 42%;
137
+ background-color: rgb(240, 228, 228, 0.2);
138
+ animation: xu 7s linear infinite;
139
+ }
140
+
141
+ @keyframes xu {
142
+ 0% {
143
+ transform: translate(-50%, -60%) rotate(0deg);
144
+ }
145
+ 100% {
146
+ transform: translate(-50%, -60%) rotate(360deg);
147
+ }
148
+ }
149
+
150
+ fieldset.data_src div.wrap label {
151
+ background: #f8bffee0 !important;
152
+ }
153
+
154
+ .scrollable-component {
155
+ max-height: 100px;
156
+ overflow-y: auto;
157
+ }
158
+
159
+ #file_accordion {
160
+ max-height: 220px !important;
161
+ }
fish_speech/webui/html/footer.html ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ <div style="color: rgba(25,255,205,0.7) !important;">
2
+ <a href="{api_docs}">API</a>
3
+  • 
4
+ <a href="https://github.com/fishaudio/fish-speech">Github</a>
5
+  • 
6
+ <a href="https://gradio.app">Gradio</a>
7
+ </div>
8
+ <br />
9
+ <div class="versions" style="color: rgba(25,255,205,0.7) !important;">
10
+ {versions}
11
+ </div>
fish_speech/webui/js/animate.js ADDED
@@ -0,0 +1,69 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ function createGradioAnimation() {
3
+ const params = new URLSearchParams(window.location.search);
4
+ if (!params.has('__theme')) {
5
+ params.set('__theme', 'light');
6
+ window.location.search = params.toString();
7
+ }
8
+
9
+ var gradioApp = document.querySelector('gradio-app');
10
+ if (gradioApp) {
11
+
12
+ document.documentElement.style.setProperty('--my-200', '#80eeee');
13
+ document.documentElement.style.setProperty('--my-50', '#ecfdf5');
14
+
15
+ // gradioApp.style.position = 'relative';
16
+ // gradioApp.style.backgroundSize = '200% 200%';
17
+ // gradioApp.style.animation = 'moveJellyBackground 10s ease infinite';
18
+ // gradioApp.style.backgroundImage = 'radial-gradient(circle at 0% 50%, var(--my-200), var(--my-50) 50%)';
19
+ // gradioApp.style.display = 'flex';
20
+ // gradioApp.style.justifyContent = 'flex-start';
21
+ // gradioApp.style.flexWrap = 'nowrap';
22
+ // gradioApp.style.overflowX = 'auto';
23
+
24
+ // for (let i = 0; i < 6; i++) {
25
+ // var quan = document.createElement('div');
26
+ // quan.className = 'quan';
27
+ // gradioApp.insertBefore(quan, gradioApp.firstChild);
28
+ // quan.id = 'quan' + i.toString();
29
+ // quan.style.left = 'calc(var(--water-width) * ' + i.toString() + ')';
30
+ // var quanContainer = document.querySelector('.quan');
31
+ // if (quanContainer) {
32
+ // var shui = document.createElement('div');
33
+ // shui.className = 'shui';
34
+ // quanContainer.insertBefore(shui, quanContainer.firstChild)
35
+ // }
36
+ // }
37
+ }
38
+
39
+ var container = document.createElement('div');
40
+ container.id = 'gradio-animation';
41
+ container.style.fontSize = '2em';
42
+ container.style.fontFamily = 'Maiandra GD, ui-monospace, monospace';
43
+ container.style.fontWeight = 'bold';
44
+ container.style.textAlign = 'center';
45
+ container.style.marginBottom = '20px';
46
+
47
+ var text = 'Welcome to Fish-Speech!';
48
+ for (var i = 0; i < text.length; i++) {
49
+ (function(i){
50
+ setTimeout(function(){
51
+ var letter = document.createElement('span');
52
+ letter.style.opacity = '0';
53
+ letter.style.transition = 'opacity 0.5s';
54
+ letter.innerText = text[i];
55
+
56
+ container.appendChild(letter);
57
+
58
+ setTimeout(function() {
59
+ letter.style.opacity = '1';
60
+ }, 50);
61
+ }, i * 200);
62
+ })(i);
63
+ }
64
+
65
+ var gradioContainer = document.querySelector('.gradio-container');
66
+ gradioContainer.insertBefore(container, gradioContainer.firstChild);
67
+
68
+ return 'Animation created';
69
+ }
fish_speech/webui/launch_utils.py ADDED
@@ -0,0 +1,120 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import importlib.util
2
+ import os
3
+ import subprocess
4
+ import sys
5
+ from functools import lru_cache
6
+ from pathlib import Path
7
+ from typing import Iterable
8
+
9
+ import gradio as gr
10
+ from gradio.themes.base import Base
11
+ from gradio.themes.utils import colors, fonts, sizes
12
+
13
+ GIT = (
14
+ (Path(os.environ.get("GIT_HOME", "")) / "git").resolve()
15
+ if sys.platform == "win32"
16
+ else "git"
17
+ )
18
+ GIT = str(GIT)
19
+
20
+
21
+ def is_module_installed(module_name: str) -> bool:
22
+ spec = importlib.util.find_spec(module_name)
23
+ return spec is not None
24
+
25
+
26
+ @lru_cache()
27
+ def commit_hash():
28
+ try:
29
+ return subprocess.check_output(
30
+ [GIT, "log", "-1", "--format='%h %s'"], shell=False, encoding="utf8"
31
+ ).strip()
32
+ except Exception:
33
+ return "<none>"
34
+
35
+
36
+ def versions_html():
37
+ import torch
38
+
39
+ python_version = ".".join([str(x) for x in sys.version_info[0:3]])
40
+ commit = commit_hash()
41
+ hash = commit.strip("'").split(" ")[0]
42
+
43
+ return f"""
44
+ version: <a href="https://github.com/fishaudio/fish-speech/commit/{hash}">{hash}</a>
45
+ &#x2000;•&#x2000;
46
+ python: <span title="{sys.version}">{python_version}</span>
47
+ &#x2000;•&#x2000;
48
+ torch: {getattr(torch, '__long_version__',torch.__version__)}
49
+ &#x2000;•&#x2000;
50
+ gradio: {gr.__version__}
51
+ &#x2000;•&#x2000;
52
+ author: <a href="https://github.com/fishaudio">fishaudio</a>
53
+ """
54
+
55
+
56
+ def version_check(commit):
57
+ try:
58
+ import requests
59
+
60
+ commits = requests.get(
61
+ "https://api.github.com/repos/fishaudio/fish-speech/branches/main"
62
+ ).json()
63
+ if commit != "<none>" and commits["commit"]["sha"] != commit:
64
+ print("--------------------------------------------------------")
65
+ print("| You are not up to date with the most recent release. |")
66
+ print("| Consider running `git pull` to update. |")
67
+ print("--------------------------------------------------------")
68
+ elif commits["commit"]["sha"] == commit:
69
+ print("You are up to date with the most recent release.")
70
+ else:
71
+ print("Not a git clone, can't perform version check.")
72
+ except Exception as e:
73
+ print("version check failed", e)
74
+
75
+
76
+ class Seafoam(Base):
77
+ def __init__(
78
+ self,
79
+ *,
80
+ primary_hue: colors.Color | str = colors.emerald,
81
+ secondary_hue: colors.Color | str = colors.blue,
82
+ neutral_hue: colors.Color | str = colors.blue,
83
+ spacing_size: sizes.Size | str = sizes.spacing_md,
84
+ radius_size: sizes.Size | str = sizes.radius_md,
85
+ text_size: sizes.Size | str = sizes.text_lg,
86
+ font: fonts.Font | str | Iterable[fonts.Font | str] = (
87
+ fonts.GoogleFont("Quicksand"),
88
+ "ui-sans-serif",
89
+ "sans-serif",
90
+ ),
91
+ font_mono: fonts.Font | str | Iterable[fonts.Font | str] = (
92
+ fonts.GoogleFont("IBM Plex Mono"),
93
+ "ui-monospace",
94
+ "monospace",
95
+ ),
96
+ ):
97
+ super().__init__(
98
+ primary_hue=primary_hue,
99
+ secondary_hue=secondary_hue,
100
+ neutral_hue=neutral_hue,
101
+ spacing_size=spacing_size,
102
+ radius_size=radius_size,
103
+ text_size=text_size,
104
+ font=font,
105
+ font_mono=font_mono,
106
+ )
107
+ super().set(
108
+ button_primary_background_fill="linear-gradient(90deg, *primary_300, *secondary_400)",
109
+ button_primary_background_fill_hover="linear-gradient(90deg, *primary_200, *secondary_300)",
110
+ button_primary_text_color="white",
111
+ button_primary_background_fill_dark="linear-gradient(90deg, *primary_600, *secondary_800)",
112
+ slider_color="*secondary_300",
113
+ slider_color_dark="*secondary_600",
114
+ block_title_text_weight="600",
115
+ block_border_width="3px",
116
+ block_shadow="*shadow_drop_lg",
117
+ button_shadow="*shadow_drop_lg",
118
+ button_small_padding="0px",
119
+ button_large_padding="3px",
120
+ )
fish_speech/webui/manage.py ADDED
@@ -0,0 +1,1237 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import datetime
4
+ import html
5
+ import json
6
+ import os
7
+ import platform
8
+ import shutil
9
+ import signal
10
+ import subprocess
11
+ import sys
12
+ from pathlib import Path
13
+
14
+ import gradio as gr
15
+ import psutil
16
+ import yaml
17
+ from loguru import logger
18
+ from tqdm import tqdm
19
+
20
+ PYTHON = os.path.join(os.environ.get("PYTHON_FOLDERPATH", ""), "python")
21
+ sys.path.insert(0, "")
22
+ print(sys.path)
23
+ cur_work_dir = Path(os.getcwd()).resolve()
24
+ print("You are in ", str(cur_work_dir))
25
+
26
+ from fish_speech.i18n import i18n
27
+ from fish_speech.webui.launch_utils import Seafoam, is_module_installed, versions_html
28
+
29
+ config_path = cur_work_dir / "fish_speech" / "configs"
30
+ vqgan_yml_path = config_path / "firefly_gan_vq.yaml"
31
+ llama_yml_path = config_path / "text2semantic_finetune.yaml"
32
+
33
+ env = os.environ.copy()
34
+ env["no_proxy"] = "127.0.0.1, localhost, 0.0.0.0"
35
+
36
+ seafoam = Seafoam()
37
+
38
+
39
+ def build_html_error_message(error):
40
+ return f"""
41
+ <div style="color: red; font-weight: bold;">
42
+ {html.escape(error)}
43
+ </div>
44
+ """
45
+
46
+
47
+ def build_html_ok_message(msg):
48
+ return f"""
49
+ <div style="color: green; font-weight: bold;">
50
+ {html.escape(msg)}
51
+ </div>
52
+ """
53
+
54
+
55
+ def build_html_href(link, desc, msg):
56
+ return f"""
57
+ <span style="color: green; font-weight: bold; display: inline-block">
58
+ {html.escape(msg)}
59
+ <a href="{link}">{desc}</a>
60
+ </span>
61
+ """
62
+
63
+
64
+ def load_data_in_raw(path):
65
+ with open(path, "r", encoding="utf-8") as file:
66
+ data = file.read()
67
+ return str(data)
68
+
69
+
70
+ def kill_proc_tree(pid, including_parent=True):
71
+ try:
72
+ parent = psutil.Process(pid)
73
+ except psutil.NoSuchProcess:
74
+ # Process already terminated
75
+ return
76
+
77
+ children = parent.children(recursive=True)
78
+ for child in children:
79
+ try:
80
+ os.kill(child.pid, signal.SIGTERM) # or signal.SIGKILL
81
+ except OSError:
82
+ pass
83
+ if including_parent:
84
+ try:
85
+ os.kill(parent.pid, signal.SIGTERM) # or signal.SIGKILL
86
+ except OSError:
87
+ pass
88
+
89
+
90
+ system = platform.system()
91
+ p_label = None
92
+ p_infer = None
93
+ p_tensorboard = None
94
+
95
+
96
+ def kill_process(pid):
97
+ if system == "Windows":
98
+ cmd = "taskkill /t /f /pid %s" % pid
99
+ # os.system(cmd)
100
+ subprocess.run(cmd)
101
+ else:
102
+ kill_proc_tree(pid)
103
+
104
+
105
+ def change_label(if_label):
106
+ global p_label
107
+ if if_label == True and p_label is None:
108
+ url = "http://localhost:3000"
109
+ remote_url = "https://text-labeler.pages.dev/"
110
+ try:
111
+ p_label = subprocess.Popen(
112
+ [
113
+ (
114
+ "asr-label-linux-x64"
115
+ if sys.platform == "linux"
116
+ else "asr-label-win-x64.exe"
117
+ )
118
+ ]
119
+ )
120
+ except FileNotFoundError:
121
+ logger.warning("asr-label execution not found!")
122
+
123
+ yield build_html_href(
124
+ link=remote_url,
125
+ desc=i18n("Optional online ver"),
126
+ msg=i18n("Opened labeler in browser"),
127
+ )
128
+
129
+ elif if_label == False and p_label is not None:
130
+ kill_process(p_label.pid)
131
+ p_label = None
132
+ yield build_html_ok_message("Nothing")
133
+
134
+
135
+ def clean_infer_cache():
136
+ import tempfile
137
+
138
+ temp_dir = Path(tempfile.gettempdir())
139
+ gradio_dir = str(temp_dir / "gradio")
140
+ try:
141
+ shutil.rmtree(gradio_dir)
142
+ logger.info(f"Deleted cached audios: {gradio_dir}")
143
+ except PermissionError:
144
+ logger.info(f"Permission denied: Unable to delete {gradio_dir}")
145
+ except FileNotFoundError:
146
+ logger.info(f"{gradio_dir} was not found")
147
+ except Exception as e:
148
+ logger.info(f"An error occurred: {e}")
149
+
150
+
151
+ def change_infer(
152
+ if_infer,
153
+ host,
154
+ port,
155
+ infer_decoder_model,
156
+ infer_decoder_config,
157
+ infer_llama_model,
158
+ infer_compile,
159
+ ):
160
+ global p_infer
161
+ if if_infer == True and p_infer == None:
162
+ env = os.environ.copy()
163
+
164
+ env["GRADIO_SERVER_NAME"] = host
165
+ env["GRADIO_SERVER_PORT"] = port
166
+ # 启动第二个进程
167
+ url = f"http://{host}:{port}"
168
+ yield build_html_ok_message(
169
+ i18n("Inferring interface is launched at {}").format(url)
170
+ )
171
+
172
+ clean_infer_cache()
173
+
174
+ p_infer = subprocess.Popen(
175
+ [
176
+ PYTHON,
177
+ "tools/webui.py",
178
+ "--decoder-checkpoint-path",
179
+ infer_decoder_model,
180
+ "--decoder-config-name",
181
+ infer_decoder_config,
182
+ "--llama-checkpoint-path",
183
+ infer_llama_model,
184
+ ]
185
+ + (["--compile"] if infer_compile == "Yes" else []),
186
+ env=env,
187
+ )
188
+
189
+ elif if_infer == False and p_infer is not None:
190
+ kill_process(p_infer.pid)
191
+ p_infer = None
192
+ yield build_html_error_message(i18n("Infer interface is closed"))
193
+
194
+
195
+ js = load_data_in_raw("fish_speech/webui/js/animate.js")
196
+ css = load_data_in_raw("fish_speech/webui/css/style.css")
197
+
198
+ data_pre_output = (cur_work_dir / "data").resolve()
199
+ default_model_output = (cur_work_dir / "results").resolve()
200
+ default_filelist = data_pre_output / "detect.list"
201
+ data_pre_output.mkdir(parents=True, exist_ok=True)
202
+
203
+ items = []
204
+ dict_items = {}
205
+
206
+
207
+ def load_yaml_data_in_fact(yml_path):
208
+ with open(yml_path, "r", encoding="utf-8") as file:
209
+ yml = yaml.safe_load(file)
210
+ return yml
211
+
212
+
213
+ def write_yaml_data_in_fact(yml, yml_path):
214
+ with open(yml_path, "w", encoding="utf-8") as file:
215
+ yaml.safe_dump(yml, file, allow_unicode=True)
216
+ return yml
217
+
218
+
219
+ def generate_tree(directory, depth=0, max_depth=None, prefix=""):
220
+ if max_depth is not None and depth > max_depth:
221
+ return ""
222
+
223
+ tree_str = ""
224
+ files = []
225
+ directories = []
226
+ for item in os.listdir(directory):
227
+ if os.path.isdir(os.path.join(directory, item)):
228
+ directories.append(item)
229
+ else:
230
+ files.append(item)
231
+
232
+ entries = directories + files
233
+ for i, entry in enumerate(entries):
234
+ connector = "├── " if i < len(entries) - 1 else "└── "
235
+ tree_str += f"{prefix}{connector}{entry}<br />"
236
+ if i < len(directories):
237
+ extension = "│ " if i < len(entries) - 1 else " "
238
+ tree_str += generate_tree(
239
+ os.path.join(directory, entry),
240
+ depth + 1,
241
+ max_depth,
242
+ prefix=prefix + extension,
243
+ )
244
+ return tree_str
245
+
246
+
247
+ def new_explorer(data_path, max_depth):
248
+ return gr.Markdown(
249
+ elem_classes=["scrollable-component"],
250
+ value=generate_tree(data_path, max_depth=max_depth),
251
+ )
252
+
253
+
254
+ def add_item(
255
+ folder: str,
256
+ method: str,
257
+ label_lang: str,
258
+ if_initial_prompt: bool,
259
+ initial_prompt: str | None,
260
+ ):
261
+ folder = folder.strip(" ").strip('"')
262
+
263
+ folder_path = Path(folder)
264
+
265
+ if folder and folder not in items and data_pre_output not in folder_path.parents:
266
+ if folder_path.is_dir():
267
+ items.append(folder)
268
+ dict_items[folder] = dict(
269
+ type="folder",
270
+ method=method,
271
+ label_lang=label_lang,
272
+ initial_prompt=initial_prompt if if_initial_prompt else None,
273
+ )
274
+ elif folder:
275
+ err = folder
276
+ return gr.Checkboxgroup(choices=items), build_html_error_message(
277
+ i18n("Invalid path: {}").format(err)
278
+ )
279
+
280
+ formatted_data = json.dumps(dict_items, ensure_ascii=False, indent=4)
281
+ logger.info("After Adding: " + formatted_data)
282
+ gr.Info(formatted_data)
283
+ return gr.Checkboxgroup(choices=items), build_html_ok_message(
284
+ i18n("Added path successfully!")
285
+ )
286
+
287
+
288
+ def remove_items(selected_items):
289
+ global items, dict_items
290
+ to_remove = [item for item in items if item in selected_items]
291
+ for item in to_remove:
292
+ del dict_items[item]
293
+ items = [item for item in items if item in dict_items.keys()]
294
+ formatted_data = json.dumps(dict_items, ensure_ascii=False, indent=4)
295
+ logger.info(formatted_data)
296
+ gr.Warning("After Removing: " + formatted_data)
297
+ return gr.Checkboxgroup(choices=items, value=[]), build_html_ok_message(
298
+ i18n("Removed path successfully!")
299
+ )
300
+
301
+
302
+ def show_selected(options):
303
+ selected_options = ", ".join(options)
304
+
305
+ if options:
306
+ return i18n("Selected: {}").format(selected_options)
307
+ else:
308
+ return i18n("No selected options")
309
+
310
+
311
+ from pydub import AudioSegment
312
+
313
+
314
+ def convert_to_mono_in_place(audio_path: Path):
315
+ audio = AudioSegment.from_file(audio_path)
316
+ if audio.channels > 1:
317
+ mono_audio = audio.set_channels(1)
318
+ mono_audio.export(audio_path, format=audio_path.suffix[1:])
319
+ logger.info(f"Convert {audio_path} successfully")
320
+
321
+
322
+ def list_copy(list_file_path, method):
323
+ wav_root = data_pre_output
324
+ lst = []
325
+ with list_file_path.open("r", encoding="utf-8") as file:
326
+ for line in tqdm(file, desc="Processing audio/transcript"):
327
+ wav_path, speaker_name, language, text = line.strip().split("|")
328
+ original_wav_path = Path(wav_path)
329
+ target_wav_path = (
330
+ wav_root / original_wav_path.parent.name / original_wav_path.name
331
+ )
332
+ lst.append(f"{target_wav_path}|{speaker_name}|{language}|{text}")
333
+ if target_wav_path.is_file():
334
+ continue
335
+ target_wav_path.parent.mkdir(parents=True, exist_ok=True)
336
+ if method == i18n("Copy"):
337
+ shutil.copy(original_wav_path, target_wav_path)
338
+ else:
339
+ shutil.move(original_wav_path, target_wav_path.parent)
340
+ convert_to_mono_in_place(target_wav_path)
341
+ original_lab_path = original_wav_path.with_suffix(".lab")
342
+ target_lab_path = (
343
+ wav_root
344
+ / original_wav_path.parent.name
345
+ / original_wav_path.with_suffix(".lab").name
346
+ )
347
+ if target_lab_path.is_file():
348
+ continue
349
+ if method == i18n("Copy"):
350
+ shutil.copy(original_lab_path, target_lab_path)
351
+ else:
352
+ shutil.move(original_lab_path, target_lab_path.parent)
353
+
354
+ if method == i18n("Move"):
355
+ with list_file_path.open("w", encoding="utf-8") as file:
356
+ file.writelines("\n".join(lst))
357
+
358
+ del lst
359
+ return build_html_ok_message(i18n("Use filelist"))
360
+
361
+
362
+ def check_files(data_path: str, max_depth: int, label_model: str, label_device: str):
363
+ global dict_items
364
+ data_path = Path(data_path)
365
+ gr.Warning("Pre-processing begins...")
366
+ for item, content in dict_items.items():
367
+ item_path = Path(item)
368
+ tar_path = data_path / item_path.name
369
+
370
+ if content["type"] == "folder" and item_path.is_dir():
371
+ if content["method"] == i18n("Copy"):
372
+ os.makedirs(tar_path, exist_ok=True)
373
+ shutil.copytree(
374
+ src=str(item_path), dst=str(tar_path), dirs_exist_ok=True
375
+ )
376
+ elif not tar_path.is_dir():
377
+ shutil.move(src=str(item_path), dst=str(tar_path))
378
+
379
+ for suf in ["wav", "flac", "mp3"]:
380
+ for audio_path in tar_path.glob(f"**/*.{suf}"):
381
+ convert_to_mono_in_place(audio_path)
382
+
383
+ cur_lang = content["label_lang"]
384
+ initial_prompt = content["initial_prompt"]
385
+
386
+ transcribe_cmd = [
387
+ PYTHON,
388
+ "tools/whisper_asr.py",
389
+ "--model-size",
390
+ label_model,
391
+ "--device",
392
+ label_device,
393
+ "--audio-dir",
394
+ tar_path,
395
+ "--save-dir",
396
+ tar_path,
397
+ "--language",
398
+ cur_lang,
399
+ ]
400
+
401
+ if initial_prompt is not None:
402
+ transcribe_cmd += ["--initial-prompt", initial_prompt]
403
+
404
+ if cur_lang != "IGNORE":
405
+ try:
406
+ gr.Warning("Begin To Transcribe")
407
+ subprocess.run(
408
+ transcribe_cmd,
409
+ env=env,
410
+ )
411
+ except Exception:
412
+ print("Transcription error occurred")
413
+
414
+ elif content["type"] == "file" and item_path.is_file():
415
+ list_copy(item_path, content["method"])
416
+
417
+ return build_html_ok_message(i18n("Move files successfully")), new_explorer(
418
+ data_path, max_depth=max_depth
419
+ )
420
+
421
+
422
+ def generate_folder_name():
423
+ now = datetime.datetime.now()
424
+ folder_name = now.strftime("%Y%m%d_%H%M%S")
425
+ return folder_name
426
+
427
+
428
+ def train_process(
429
+ data_path: str,
430
+ option: str,
431
+ # llama config
432
+ llama_ckpt,
433
+ llama_base_config,
434
+ llama_lr,
435
+ llama_maxsteps,
436
+ llama_data_num_workers,
437
+ llama_data_batch_size,
438
+ llama_data_max_length,
439
+ llama_precision,
440
+ llama_check_interval,
441
+ llama_grad_batches,
442
+ llama_use_speaker,
443
+ llama_use_lora,
444
+ ):
445
+
446
+ backend = "nccl" if sys.platform == "linux" else "gloo"
447
+
448
+ new_project = generate_folder_name()
449
+ print("New Project Name: ", new_project)
450
+
451
+ if option == "VQGAN":
452
+ msg = "Skipped VQGAN Training."
453
+ gr.Warning(msg)
454
+ logger.info(msg)
455
+
456
+ if option == "LLAMA":
457
+ msg = "LLAMA Training begins..."
458
+ gr.Warning(msg)
459
+ logger.info(msg)
460
+ subprocess.run(
461
+ [
462
+ PYTHON,
463
+ "tools/vqgan/extract_vq.py",
464
+ str(data_pre_output),
465
+ "--num-workers",
466
+ "1",
467
+ "--batch-size",
468
+ "16",
469
+ "--config-name",
470
+ "firefly_gan_vq",
471
+ "--checkpoint-path",
472
+ "checkpoints/fish-speech-1.4/firefly-gan-vq-fsq-8x1024-21hz-generator.pth",
473
+ ]
474
+ )
475
+
476
+ subprocess.run(
477
+ [
478
+ PYTHON,
479
+ "tools/llama/build_dataset.py",
480
+ "--input",
481
+ str(data_pre_output),
482
+ "--text-extension",
483
+ ".lab",
484
+ "--num-workers",
485
+ "16",
486
+ ]
487
+ )
488
+ ckpt_path = "checkpoints/fish-speech-1.4/model.pth"
489
+ lora_prefix = "lora_" if llama_use_lora else ""
490
+ llama_name = lora_prefix + "text2semantic_" + new_project
491
+ latest = next(
492
+ iter(
493
+ sorted(
494
+ [
495
+ str(p.relative_to("results"))
496
+ for p in Path("results").glob(lora_prefix + "text2sem*/")
497
+ ],
498
+ reverse=True,
499
+ )
500
+ ),
501
+ llama_name,
502
+ )
503
+ project = (
504
+ llama_name
505
+ if llama_ckpt == i18n("new")
506
+ else (
507
+ latest
508
+ if llama_ckpt == i18n("latest")
509
+ else Path(llama_ckpt).relative_to("results")
510
+ )
511
+ )
512
+ logger.info(project)
513
+
514
+ if llama_check_interval > llama_maxsteps:
515
+ llama_check_interval = llama_maxsteps
516
+
517
+ train_cmd = [
518
+ PYTHON,
519
+ "fish_speech/train.py",
520
+ "--config-name",
521
+ "text2semantic_finetune",
522
+ f"project={project}",
523
+ f"trainer.strategy.process_group_backend={backend}",
524
+ f"train_dataset.proto_files={str(['data/quantized-dataset-ft'])}",
525
+ f"val_dataset.proto_files={str(['data/quantized-dataset-ft'])}",
526
+ f"model.optimizer.lr={llama_lr}",
527
+ f"trainer.max_steps={llama_maxsteps}",
528
+ f"data.num_workers={llama_data_num_workers}",
529
+ f"data.batch_size={llama_data_batch_size}",
530
+ f"max_length={llama_data_max_length}",
531
+ f"trainer.precision={llama_precision}",
532
+ f"trainer.val_check_interval={llama_check_interval}",
533
+ f"trainer.accumulate_grad_batches={llama_grad_batches}",
534
+ f"train_dataset.interactive_prob={llama_use_speaker}",
535
+ ] + ([f"[email protected]_config=r_8_alpha_16"] if llama_use_lora else [])
536
+ logger.info(train_cmd)
537
+ subprocess.run(train_cmd)
538
+
539
+ return build_html_ok_message(i18n("Training stopped"))
540
+
541
+
542
+ def tensorboard_process(
543
+ if_tensorboard: bool,
544
+ tensorboard_dir: str,
545
+ host: str,
546
+ port: str,
547
+ ):
548
+ global p_tensorboard
549
+ if if_tensorboard == True and p_tensorboard == None:
550
+ url = f"http://{host}:{port}"
551
+ yield build_html_ok_message(
552
+ i18n("Tensorboard interface is launched at {}").format(url)
553
+ )
554
+ prefix = ["tensorboard"]
555
+ if Path("fishenv").exists():
556
+ prefix = ["fishenv/env/python.exe", "fishenv/env/Scripts/tensorboard.exe"]
557
+
558
+ p_tensorboard = subprocess.Popen(
559
+ prefix
560
+ + [
561
+ "--logdir",
562
+ tensorboard_dir,
563
+ "--host",
564
+ host,
565
+ "--port",
566
+ port,
567
+ "--reload_interval",
568
+ "120",
569
+ ]
570
+ )
571
+ elif if_tensorboard == False and p_tensorboard != None:
572
+ kill_process(p_tensorboard.pid)
573
+ p_tensorboard = None
574
+ yield build_html_error_message(i18n("Tensorboard interface is closed"))
575
+
576
+
577
+ def fresh_tb_dir():
578
+ return gr.Dropdown(
579
+ choices=[str(p) for p in Path("results").glob("**/tensorboard/")]
580
+ )
581
+
582
+
583
+ def list_decoder_models():
584
+ paths = [str(p) for p in Path("checkpoints").glob("fish*/firefly*.pth")]
585
+ if not paths:
586
+ logger.warning("No decoder model found")
587
+ return paths
588
+
589
+
590
+ def list_llama_models():
591
+ choices = [str(p.parent) for p in Path("checkpoints").glob("merged*/*model*.pth")]
592
+ choices += [str(p.parent) for p in Path("checkpoints").glob("fish*/*model*.pth")]
593
+ choices += [str(p.parent) for p in Path("checkpoints").glob("fs*/*model*.pth")]
594
+ choices = sorted(choices, reverse=True)
595
+ if not choices:
596
+ logger.warning("No LLaMA model found")
597
+ return choices
598
+
599
+
600
+ def list_lora_llama_models():
601
+ choices = sorted(
602
+ [str(p) for p in Path("results").glob("lora*/**/*.ckpt")], reverse=True
603
+ )
604
+ if not choices:
605
+ logger.warning("No LoRA LLaMA model found")
606
+ return choices
607
+
608
+
609
+ def fresh_decoder_model():
610
+ return gr.Dropdown(choices=list_decoder_models())
611
+
612
+
613
+ def fresh_llama_ckpt(llama_use_lora):
614
+ return gr.Dropdown(
615
+ choices=[i18n("latest"), i18n("new")]
616
+ + (
617
+ [str(p) for p in Path("results").glob("text2sem*/")]
618
+ if not llama_use_lora
619
+ else [str(p) for p in Path("results").glob("lora_*/")]
620
+ )
621
+ )
622
+
623
+
624
+ def fresh_llama_model():
625
+ return gr.Dropdown(choices=list_llama_models())
626
+
627
+
628
+ def llama_lora_merge(llama_weight, lora_llama_config, lora_weight, llama_lora_output):
629
+ if (
630
+ lora_weight is None
631
+ or not Path(lora_weight).exists()
632
+ or not Path(llama_weight).exists()
633
+ ):
634
+ return build_html_error_message(
635
+ i18n(
636
+ "Path error, please check the model file exists in the corresponding path"
637
+ )
638
+ )
639
+ gr.Warning("Merging begins...")
640
+ merge_cmd = [
641
+ PYTHON,
642
+ "tools/llama/merge_lora.py",
643
+ "--lora-config",
644
+ "r_8_alpha_16",
645
+ "--lora-weight",
646
+ lora_weight,
647
+ "--output",
648
+ llama_lora_output + "_" + generate_folder_name(),
649
+ ]
650
+ logger.info(merge_cmd)
651
+ subprocess.run(merge_cmd)
652
+ return build_html_ok_message(i18n("Merge successfully"))
653
+
654
+
655
+ def llama_quantify(llama_weight, quantify_mode):
656
+ if llama_weight is None or not Path(llama_weight).exists():
657
+ return build_html_error_message(
658
+ i18n(
659
+ "Path error, please check the model file exists in the corresponding path"
660
+ )
661
+ )
662
+
663
+ gr.Warning("Quantifying begins...")
664
+
665
+ now = generate_folder_name()
666
+ quantify_cmd = [
667
+ PYTHON,
668
+ "tools/llama/quantize.py",
669
+ "--checkpoint-path",
670
+ llama_weight,
671
+ "--mode",
672
+ quantify_mode,
673
+ "--timestamp",
674
+ now,
675
+ ]
676
+ logger.info(quantify_cmd)
677
+ subprocess.run(quantify_cmd)
678
+ if quantify_mode == "int8":
679
+ quantize_path = str(
680
+ Path(os.getcwd()) / "checkpoints" / f"fs-1.2-{quantify_mode}-{now}"
681
+ )
682
+ else:
683
+ quantize_path = str(
684
+ Path(os.getcwd()) / "checkpoints" / f"fs-1.2-{quantify_mode}-g128-{now}"
685
+ )
686
+ return build_html_ok_message(
687
+ i18n("Quantify successfully") + f"Path: {quantize_path}"
688
+ )
689
+
690
+
691
+ init_vqgan_yml = load_yaml_data_in_fact(vqgan_yml_path)
692
+ init_llama_yml = load_yaml_data_in_fact(llama_yml_path)
693
+
694
+ with gr.Blocks(
695
+ head="<style>\n" + css + "\n</style>",
696
+ js=js,
697
+ theme=seafoam,
698
+ analytics_enabled=False,
699
+ title="Fish Speech",
700
+ ) as demo:
701
+ with gr.Row():
702
+ with gr.Column():
703
+ with gr.Tab("\U0001F4D6 " + i18n("Data Preprocessing")):
704
+ with gr.Row():
705
+ textbox = gr.Textbox(
706
+ label="\U0000270F "
707
+ + i18n("Input Audio & Source Path for Transcription"),
708
+ info=i18n("Speaker is identified by the folder name"),
709
+ interactive=True,
710
+ )
711
+ with gr.Row(equal_height=False):
712
+ with gr.Column():
713
+ output_radio = gr.Radio(
714
+ label="\U0001F4C1 "
715
+ + i18n("Select source file processing method"),
716
+ choices=[i18n("Copy"), i18n("Move")],
717
+ value=i18n("Copy"),
718
+ interactive=True,
719
+ )
720
+ with gr.Column():
721
+ error = gr.HTML(label=i18n("Error Message"))
722
+ if_label = gr.Checkbox(
723
+ label=i18n("Open Labeler WebUI"), scale=0, show_label=True
724
+ )
725
+
726
+ with gr.Row():
727
+ label_device = gr.Dropdown(
728
+ label=i18n("Labeling Device"),
729
+ info=i18n(
730
+ "It is recommended to use CUDA, if you have low configuration, use CPU"
731
+ ),
732
+ choices=["cpu", "cuda"],
733
+ value="cuda",
734
+ interactive=True,
735
+ )
736
+ label_model = gr.Dropdown(
737
+ label=i18n("Whisper Model"),
738
+ info=i18n("Faster Whisper, Up to 5g GPU memory usage"),
739
+ choices=["large-v3", "medium"],
740
+ value="large-v3",
741
+ interactive=True,
742
+ )
743
+ label_radio = gr.Dropdown(
744
+ label=i18n("Optional Label Language"),
745
+ info=i18n(
746
+ "If there is no corresponding text for the audio, apply ASR for assistance, support .txt or .lab format"
747
+ ),
748
+ choices=[
749
+ (i18n("Chinese"), "zh"),
750
+ (i18n("English"), "en"),
751
+ (i18n("Japanese"), "ja"),
752
+ (i18n("Disabled"), "IGNORE"),
753
+ (i18n("auto"), "auto"),
754
+ ],
755
+ value="IGNORE",
756
+ interactive=True,
757
+ )
758
+
759
+ with gr.Row():
760
+ if_initial_prompt = gr.Checkbox(
761
+ value=False,
762
+ label=i18n("Enable Initial Prompt"),
763
+ min_width=120,
764
+ scale=0,
765
+ )
766
+ initial_prompt = gr.Textbox(
767
+ label=i18n("Initial Prompt"),
768
+ info=i18n(
769
+ "Initial prompt can provide contextual or vocabulary-specific guidance to the model."
770
+ ),
771
+ placeholder="This audio introduces the basic concepts and applications of artificial intelligence and machine learning.",
772
+ interactive=False,
773
+ )
774
+
775
+ with gr.Row():
776
+ add_button = gr.Button(
777
+ "\U000027A1 " + i18n("Add to Processing Area"),
778
+ variant="primary",
779
+ )
780
+ remove_button = gr.Button(
781
+ "\U000026D4 " + i18n("Remove Selected Data")
782
+ )
783
+
784
+ with gr.Tab("\U0001F6E0 " + i18n("Training Configuration")):
785
+ with gr.Row():
786
+ model_type_radio = gr.Radio(
787
+ label=i18n(
788
+ "Select the model to be trained (Depending on the Tab page you are on)"
789
+ ),
790
+ interactive=False,
791
+ choices=["VQGAN", "LLAMA"],
792
+ value="VQGAN",
793
+ )
794
+ with gr.Row():
795
+ with gr.Tabs():
796
+ with gr.Tab(label=i18n("VQGAN Configuration")) as vqgan_page:
797
+ gr.HTML("You don't need to train this model!")
798
+
799
+ with gr.Tab(label=i18n("LLAMA Configuration")) as llama_page:
800
+ with gr.Row(equal_height=False):
801
+ llama_use_lora = gr.Checkbox(
802
+ label=i18n("Use LoRA"),
803
+ info=i18n(
804
+ "Use LoRA can save GPU memory, but may reduce the quality of the model"
805
+ ),
806
+ value=True,
807
+ interactive=True,
808
+ )
809
+ llama_ckpt = gr.Dropdown(
810
+ label=i18n("Select LLAMA ckpt"),
811
+ choices=[i18n("latest"), i18n("new")]
812
+ + [
813
+ str(p)
814
+ for p in Path("results").glob("text2sem*/")
815
+ ]
816
+ + [str(p) for p in Path("results").glob("lora*/")],
817
+ value=i18n("latest"),
818
+ interactive=True,
819
+ )
820
+ with gr.Row(equal_height=False):
821
+ llama_lr_slider = gr.Slider(
822
+ label=i18n("Initial Learning Rate"),
823
+ info=i18n(
824
+ "lr smaller -> usually train slower but more stable"
825
+ ),
826
+ interactive=True,
827
+ minimum=1e-5,
828
+ maximum=1e-4,
829
+ step=1e-5,
830
+ value=5e-5,
831
+ )
832
+ llama_maxsteps_slider = gr.Slider(
833
+ label=i18n("Maximum Training Steps"),
834
+ info=i18n(
835
+ "recommend: max_steps = num_audios // batch_size * (2 to 5)"
836
+ ),
837
+ interactive=True,
838
+ minimum=1,
839
+ maximum=10000,
840
+ step=1,
841
+ value=50,
842
+ )
843
+ with gr.Row(equal_height=False):
844
+ llama_base_config = gr.Dropdown(
845
+ label=i18n("Model Size"),
846
+ choices=[
847
+ "text2semantic_finetune",
848
+ ],
849
+ value="text2semantic_finetune",
850
+ )
851
+ llama_data_num_workers_slider = gr.Slider(
852
+ label=i18n("Number of Workers"),
853
+ minimum=1,
854
+ maximum=32,
855
+ step=1,
856
+ value=4,
857
+ )
858
+ with gr.Row(equal_height=False):
859
+ llama_data_batch_size_slider = gr.Slider(
860
+ label=i18n("Batch Size"),
861
+ interactive=True,
862
+ minimum=1,
863
+ maximum=32,
864
+ step=1,
865
+ value=4,
866
+ )
867
+ llama_data_max_length_slider = gr.Slider(
868
+ label=i18n("Maximum Length per Sample"),
869
+ interactive=True,
870
+ minimum=1024,
871
+ maximum=4096,
872
+ step=128,
873
+ value=1024,
874
+ )
875
+ with gr.Row(equal_height=False):
876
+ llama_precision_dropdown = gr.Dropdown(
877
+ label=i18n("Precision"),
878
+ info=i18n(
879
+ "bf16-true is recommended for 30+ series GPU, 16-mixed is recommended for 10+ series GPU"
880
+ ),
881
+ interactive=True,
882
+ choices=["32", "bf16-true", "16-mixed"],
883
+ value="bf16-true",
884
+ )
885
+ llama_check_interval_slider = gr.Slider(
886
+ label=i18n("Save model every n steps"),
887
+ info=i18n(
888
+ "make sure that it's not greater than max_steps"
889
+ ),
890
+ interactive=True,
891
+ minimum=1,
892
+ maximum=1000,
893
+ step=1,
894
+ value=50,
895
+ )
896
+ with gr.Row(equal_height=False):
897
+ llama_grad_batches = gr.Slider(
898
+ label=i18n("Accumulate Gradient Batches"),
899
+ interactive=True,
900
+ minimum=1,
901
+ maximum=20,
902
+ step=1,
903
+ value=init_llama_yml["trainer"][
904
+ "accumulate_grad_batches"
905
+ ],
906
+ )
907
+ llama_use_speaker = gr.Slider(
908
+ label=i18n(
909
+ "Probability of applying Speaker Condition"
910
+ ),
911
+ interactive=True,
912
+ minimum=0.1,
913
+ maximum=1.0,
914
+ step=0.05,
915
+ value=init_llama_yml["train_dataset"][
916
+ "interactive_prob"
917
+ ],
918
+ )
919
+
920
+ with gr.Tab(label=i18n("Merge LoRA"), id=4):
921
+ with gr.Row(equal_height=False):
922
+ llama_weight = gr.Dropdown(
923
+ label=i18n("Base LLAMA Model"),
924
+ info=i18n(
925
+ "Type the path or select from the dropdown"
926
+ ),
927
+ choices=[
928
+ "checkpoints/fish-speech-1.4/model.pth",
929
+ ],
930
+ value="checkpoints/fish-speech-1.4/model.pth",
931
+ allow_custom_value=True,
932
+ interactive=True,
933
+ )
934
+ with gr.Row(equal_height=False):
935
+ lora_weight = gr.Dropdown(
936
+ label=i18n("LoRA Model to be merged"),
937
+ info=i18n(
938
+ "Type the path or select from the dropdown"
939
+ ),
940
+ choices=[
941
+ str(p)
942
+ for p in Path("results").glob("lora*/**/*.ckpt")
943
+ ],
944
+ allow_custom_value=True,
945
+ interactive=True,
946
+ )
947
+ lora_llama_config = gr.Dropdown(
948
+ label=i18n("LLAMA Model Config"),
949
+ info=i18n(
950
+ "Type the path or select from the dropdown"
951
+ ),
952
+ choices=[
953
+ "text2semantic_finetune",
954
+ ],
955
+ value="text2semantic_finetune",
956
+ allow_custom_value=True,
957
+ )
958
+ with gr.Row(equal_height=False):
959
+ llama_lora_output = gr.Dropdown(
960
+ label=i18n("Output Path"),
961
+ info=i18n(
962
+ "Type the path or select from the dropdown"
963
+ ),
964
+ value="checkpoints/merged",
965
+ choices=["checkpoints/merged"],
966
+ allow_custom_value=True,
967
+ interactive=True,
968
+ )
969
+ with gr.Row(equal_height=False):
970
+ llama_lora_merge_btn = gr.Button(
971
+ value=i18n("Merge"), variant="primary"
972
+ )
973
+
974
+ with gr.Tab(label=i18n("Model Quantization"), id=5):
975
+ with gr.Row(equal_height=False):
976
+ llama_weight_to_quantify = gr.Dropdown(
977
+ label=i18n("Base LLAMA Model"),
978
+ info=i18n(
979
+ "Type the path or select from the dropdown"
980
+ ),
981
+ choices=list_llama_models(),
982
+ value="checkpoints/fish-speech-1.4",
983
+ allow_custom_value=True,
984
+ interactive=True,
985
+ )
986
+ quantify_mode = gr.Dropdown(
987
+ label=i18n("Post-quantification Precision"),
988
+ info=i18n(
989
+ "The lower the quantitative precision, the more the effectiveness may decrease, but the greater the efficiency will increase"
990
+ ),
991
+ choices=["int8", "int4"],
992
+ value="int8",
993
+ allow_custom_value=False,
994
+ interactive=True,
995
+ )
996
+ with gr.Row(equal_height=False):
997
+ llama_quantify_btn = gr.Button(
998
+ value=i18n("Quantify"), variant="primary"
999
+ )
1000
+
1001
+ with gr.Tab(label="Tensorboard", id=6):
1002
+ with gr.Row(equal_height=False):
1003
+ tb_host = gr.Textbox(
1004
+ label=i18n("Tensorboard Host"), value="127.0.0.1"
1005
+ )
1006
+ tb_port = gr.Textbox(
1007
+ label=i18n("Tensorboard Port"), value="11451"
1008
+ )
1009
+ with gr.Row(equal_height=False):
1010
+ tb_dir = gr.Dropdown(
1011
+ label=i18n("Tensorboard Log Path"),
1012
+ allow_custom_value=True,
1013
+ choices=[
1014
+ str(p)
1015
+ for p in Path("results").glob("**/tensorboard/")
1016
+ ],
1017
+ )
1018
+ with gr.Row(equal_height=False):
1019
+ if_tb = gr.Checkbox(
1020
+ label=i18n("Open Tensorboard"),
1021
+ )
1022
+
1023
+ with gr.Tab("\U0001F9E0 " + i18n("Inference Configuration")):
1024
+ with gr.Column():
1025
+ with gr.Row():
1026
+ with gr.Accordion(
1027
+ label="\U0001F5A5 "
1028
+ + i18n("Inference Server Configuration"),
1029
+ open=False,
1030
+ ):
1031
+ with gr.Row():
1032
+ infer_host_textbox = gr.Textbox(
1033
+ label=i18n("WebUI Host"), value="127.0.0.1"
1034
+ )
1035
+ infer_port_textbox = gr.Textbox(
1036
+ label=i18n("WebUI Port"), value="7862"
1037
+ )
1038
+ with gr.Row():
1039
+ infer_decoder_model = gr.Dropdown(
1040
+ label=i18n("Decoder Model Path"),
1041
+ info=i18n(
1042
+ "Type the path or select from the dropdown"
1043
+ ),
1044
+ choices=list_decoder_models(),
1045
+ value="checkpoints/fish-speech-1.4/firefly-gan-vq-fsq-8x1024-21hz-generator.pth",
1046
+ allow_custom_value=True,
1047
+ )
1048
+ infer_decoder_config = gr.Dropdown(
1049
+ label=i18n("Decoder Model Config"),
1050
+ info=i18n("Changing with the Model Path"),
1051
+ value="firefly_gan_vq",
1052
+ choices=[
1053
+ "firefly_gan_vq",
1054
+ ],
1055
+ allow_custom_value=True,
1056
+ )
1057
+ with gr.Row():
1058
+ infer_llama_model = gr.Dropdown(
1059
+ label=i18n("LLAMA Model Path"),
1060
+ info=i18n(
1061
+ "Type the path or select from the dropdown"
1062
+ ),
1063
+ value="checkpoints/fish-speech-1.4",
1064
+ choices=list_llama_models(),
1065
+ allow_custom_value=True,
1066
+ )
1067
+
1068
+ with gr.Row():
1069
+ infer_compile = gr.Radio(
1070
+ label=i18n("Compile Model"),
1071
+ info=i18n(
1072
+ "Compile the model can significantly reduce the inference time, but will increase cold start time"
1073
+ ),
1074
+ choices=["Yes", "No"],
1075
+ value=(
1076
+ "Yes" if (sys.platform == "linux") else "No"
1077
+ ),
1078
+ interactive=is_module_installed("triton"),
1079
+ )
1080
+
1081
+ with gr.Row():
1082
+ infer_checkbox = gr.Checkbox(
1083
+ label=i18n("Open Inference Server")
1084
+ )
1085
+ infer_error = gr.HTML(label=i18n("Inference Server Error"))
1086
+
1087
+ with gr.Column():
1088
+ train_error = gr.HTML(label=i18n("Training Error"))
1089
+ checkbox_group = gr.CheckboxGroup(
1090
+ label="\U0001F4CA " + i18n("Data Source"),
1091
+ info=i18n(
1092
+ "The path of the input folder on the left or the filelist. Whether checked or not, it will be used for subsequent training in this list."
1093
+ ),
1094
+ elem_classes=["data_src"],
1095
+ )
1096
+ train_box = gr.Textbox(
1097
+ label=i18n("Data Preprocessing Path"),
1098
+ value=str(data_pre_output),
1099
+ interactive=False,
1100
+ )
1101
+ model_box = gr.Textbox(
1102
+ label="\U0001F4BE " + i18n("Model Output Path"),
1103
+ value=str(default_model_output),
1104
+ interactive=False,
1105
+ )
1106
+
1107
+ with gr.Accordion(
1108
+ i18n(
1109
+ "View the status of the preprocessing folder (use the slider to control the depth of the tree)"
1110
+ ),
1111
+ elem_classes=["scrollable-component"],
1112
+ elem_id="file_accordion",
1113
+ ):
1114
+ tree_slider = gr.Slider(
1115
+ minimum=0,
1116
+ maximum=3,
1117
+ value=0,
1118
+ step=1,
1119
+ show_label=False,
1120
+ container=False,
1121
+ )
1122
+ file_markdown = new_explorer(str(data_pre_output), 0)
1123
+ with gr.Row(equal_height=False):
1124
+ admit_btn = gr.Button(
1125
+ "\U00002705 " + i18n("File Preprocessing"),
1126
+ variant="primary",
1127
+ )
1128
+ fresh_btn = gr.Button("\U0001F503", scale=0, min_width=80)
1129
+ help_button = gr.Button("\U00002753", scale=0, min_width=80) # question
1130
+ train_btn = gr.Button(i18n("Start Training"), variant="primary")
1131
+
1132
+ footer = load_data_in_raw("fish_speech/webui/html/footer.html")
1133
+ footer = footer.format(
1134
+ versions=versions_html(),
1135
+ api_docs="https://speech.fish.audio/inference/#http-api",
1136
+ )
1137
+ gr.HTML(footer, elem_id="footer")
1138
+ vqgan_page.select(lambda: "VQGAN", None, model_type_radio)
1139
+ llama_page.select(lambda: "LLAMA", None, model_type_radio)
1140
+ add_button.click(
1141
+ fn=add_item,
1142
+ inputs=[textbox, output_radio, label_radio, if_initial_prompt, initial_prompt],
1143
+ outputs=[checkbox_group, error],
1144
+ )
1145
+ remove_button.click(
1146
+ fn=remove_items, inputs=[checkbox_group], outputs=[checkbox_group, error]
1147
+ )
1148
+ checkbox_group.change(fn=show_selected, inputs=checkbox_group, outputs=[error])
1149
+ help_button.click(
1150
+ fn=None,
1151
+ js='() => { window.open("https://speech.fish.audio/", "newwindow", "height=100, width=400, '
1152
+ 'toolbar=no, menubar=no, scrollbars=no, resizable=no, location=no, status=no")}',
1153
+ )
1154
+ if_label.change(fn=change_label, inputs=[if_label], outputs=[error])
1155
+ if_initial_prompt.change(
1156
+ fn=lambda x: gr.Textbox(value="", interactive=x),
1157
+ inputs=[if_initial_prompt],
1158
+ outputs=[initial_prompt],
1159
+ )
1160
+ train_btn.click(
1161
+ fn=train_process,
1162
+ inputs=[
1163
+ train_box,
1164
+ model_type_radio,
1165
+ # llama config
1166
+ llama_ckpt,
1167
+ llama_base_config,
1168
+ llama_lr_slider,
1169
+ llama_maxsteps_slider,
1170
+ llama_data_num_workers_slider,
1171
+ llama_data_batch_size_slider,
1172
+ llama_data_max_length_slider,
1173
+ llama_precision_dropdown,
1174
+ llama_check_interval_slider,
1175
+ llama_grad_batches,
1176
+ llama_use_speaker,
1177
+ llama_use_lora,
1178
+ ],
1179
+ outputs=[train_error],
1180
+ )
1181
+ if_tb.change(
1182
+ fn=tensorboard_process,
1183
+ inputs=[if_tb, tb_dir, tb_host, tb_port],
1184
+ outputs=[train_error],
1185
+ )
1186
+ tb_dir.change(fn=fresh_tb_dir, inputs=[], outputs=[tb_dir])
1187
+ infer_decoder_model.change(
1188
+ fn=fresh_decoder_model, inputs=[], outputs=[infer_decoder_model]
1189
+ )
1190
+ infer_llama_model.change(
1191
+ fn=fresh_llama_model, inputs=[], outputs=[infer_llama_model]
1192
+ )
1193
+ llama_weight.change(fn=fresh_llama_model, inputs=[], outputs=[llama_weight])
1194
+ admit_btn.click(
1195
+ fn=check_files,
1196
+ inputs=[train_box, tree_slider, label_model, label_device],
1197
+ outputs=[error, file_markdown],
1198
+ )
1199
+ fresh_btn.click(
1200
+ fn=new_explorer, inputs=[train_box, tree_slider], outputs=[file_markdown]
1201
+ )
1202
+ llama_use_lora.change(
1203
+ fn=fresh_llama_ckpt, inputs=[llama_use_lora], outputs=[llama_ckpt]
1204
+ )
1205
+ llama_ckpt.change(
1206
+ fn=fresh_llama_ckpt, inputs=[llama_use_lora], outputs=[llama_ckpt]
1207
+ )
1208
+ lora_weight.change(
1209
+ fn=lambda: gr.Dropdown(choices=list_lora_llama_models()),
1210
+ inputs=[],
1211
+ outputs=[lora_weight],
1212
+ )
1213
+ llama_lora_merge_btn.click(
1214
+ fn=llama_lora_merge,
1215
+ inputs=[llama_weight, lora_llama_config, lora_weight, llama_lora_output],
1216
+ outputs=[train_error],
1217
+ )
1218
+ llama_quantify_btn.click(
1219
+ fn=llama_quantify,
1220
+ inputs=[llama_weight_to_quantify, quantify_mode],
1221
+ outputs=[train_error],
1222
+ )
1223
+ infer_checkbox.change(
1224
+ fn=change_infer,
1225
+ inputs=[
1226
+ infer_checkbox,
1227
+ infer_host_textbox,
1228
+ infer_port_textbox,
1229
+ infer_decoder_model,
1230
+ infer_decoder_config,
1231
+ infer_llama_model,
1232
+ infer_compile,
1233
+ ],
1234
+ outputs=[infer_error],
1235
+ )
1236
+
1237
+ demo.launch(inbrowser=True)
requirements.txt CHANGED
@@ -24,4 +24,5 @@ resampy>=0.4.3
24
  spaces>=0.26.1
25
  einx[torch]==0.2.0
26
  opencc
27
- faster-whisper
 
 
24
  spaces>=0.26.1
25
  einx[torch]==0.2.0
26
  opencc
27
+ faster-whisper
28
+ ormsgpack
tools/api.py CHANGED
@@ -3,21 +3,26 @@ import io
3
  import json
4
  import queue
5
  import random
 
6
  import traceback
7
  import wave
8
  from argparse import ArgumentParser
9
  from http import HTTPStatus
10
  from pathlib import Path
11
- from typing import Annotated, Literal, Optional
12
 
13
- import librosa
14
  import numpy as np
 
15
  import pyrootutils
16
  import soundfile as sf
17
  import torch
 
 
18
  from kui.asgi import (
19
  Body,
 
20
  HTTPException,
 
21
  HttpView,
22
  JSONResponse,
23
  Kui,
@@ -26,13 +31,16 @@ from kui.asgi import (
26
  )
27
  from kui.asgi.routing import MultimethodRoutes
28
  from loguru import logger
29
- from pydantic import BaseModel, Field
30
 
31
  pyrootutils.setup_root(__file__, indicator=".project-root", pythonpath=True)
32
 
33
  # from fish_speech.models.vqgan.lit_module import VQGAN
34
  from fish_speech.models.vqgan.modules.firefly import FireflyArchitecture
35
- from tools.auto_rerank import batch_asr, calculate_wer, is_chinese, load_model
 
 
 
36
  from tools.llama.generate import (
37
  GenerateRequest,
38
  GenerateResponse,
@@ -80,13 +88,21 @@ async def other_exception_handler(exc: "Exception"):
80
 
81
  def load_audio(reference_audio, sr):
82
  if len(reference_audio) > 255 or not Path(reference_audio).exists():
83
- try:
84
- audio_data = base64.b64decode(reference_audio)
85
- reference_audio = io.BytesIO(audio_data)
86
- except base64.binascii.Error:
87
- raise ValueError("Invalid path or base64 string")
88
 
89
- audio, _ = librosa.load(reference_audio, sr=sr, mono=True)
 
 
 
 
 
 
 
 
 
 
 
90
  return audio
91
 
92
 
@@ -132,7 +148,7 @@ def decode_vq_tokens(
132
  return decoder_model.decode(
133
  indices=codes[None],
134
  feature_lengths=feature_lengths,
135
- ).squeeze()
136
 
137
  raise ValueError(f"Unknown model type: {type(decoder_model)}")
138
 
@@ -140,58 +156,6 @@ def decode_vq_tokens(
140
  routes = MultimethodRoutes(base_class=HttpView)
141
 
142
 
143
- def get_random_paths(base_path, data, speaker, emotion):
144
- if base_path and data and speaker and emotion and (Path(base_path).exists()):
145
- if speaker in data and emotion in data[speaker]:
146
- files = data[speaker][emotion]
147
- lab_files = [f for f in files if f.endswith(".lab")]
148
- wav_files = [f for f in files if f.endswith(".wav")]
149
-
150
- if lab_files and wav_files:
151
- selected_lab = random.choice(lab_files)
152
- selected_wav = random.choice(wav_files)
153
-
154
- lab_path = Path(base_path) / speaker / emotion / selected_lab
155
- wav_path = Path(base_path) / speaker / emotion / selected_wav
156
- if lab_path.exists() and wav_path.exists():
157
- return lab_path, wav_path
158
-
159
- return None, None
160
-
161
-
162
- def load_json(json_file):
163
- if not json_file:
164
- logger.info("Not using a json file")
165
- return None
166
- try:
167
- with open(json_file, "r", encoding="utf-8") as file:
168
- data = json.load(file)
169
- except FileNotFoundError:
170
- logger.warning(f"ref json not found: {json_file}")
171
- data = None
172
- except Exception as e:
173
- logger.warning(f"Loading json failed: {e}")
174
- data = None
175
- return data
176
-
177
-
178
- class InvokeRequest(BaseModel):
179
- text: str = "你说的对, 但是原神是一款由米哈游自主研发的开放世界手游."
180
- reference_text: Optional[str] = None
181
- reference_audio: Optional[str] = None
182
- max_new_tokens: int = 1024
183
- chunk_length: Annotated[int, Field(ge=0, le=500, strict=True)] = 100
184
- top_p: Annotated[float, Field(ge=0.1, le=1.0, strict=True)] = 0.7
185
- repetition_penalty: Annotated[float, Field(ge=0.9, le=2.0, strict=True)] = 1.2
186
- temperature: Annotated[float, Field(ge=0.1, le=1.0, strict=True)] = 0.7
187
- emotion: Optional[str] = None
188
- format: Literal["wav", "mp3", "flac"] = "wav"
189
- streaming: bool = False
190
- ref_json: Optional[str] = "ref_data.json"
191
- ref_base: Optional[str] = "ref_data"
192
- speaker: Optional[str] = None
193
-
194
-
195
  def get_content_type(audio_format):
196
  if audio_format == "wav":
197
  return "audio/wav"
@@ -204,35 +168,52 @@ def get_content_type(audio_format):
204
 
205
 
206
  @torch.inference_mode()
207
- def inference(req: InvokeRequest):
208
- # Parse reference audio aka prompt
209
- prompt_tokens = None
210
-
211
- ref_data = load_json(req.ref_json)
212
- ref_base = req.ref_base
213
-
214
- lab_path, wav_path = get_random_paths(ref_base, ref_data, req.speaker, req.emotion)
215
-
216
- if lab_path and wav_path:
217
- with open(lab_path, "r", encoding="utf-8") as lab_file:
218
- ref_text = lab_file.read()
219
- req.reference_audio = wav_path
220
- req.reference_text = ref_text
221
- logger.info("ref_path: " + str(wav_path))
222
- logger.info("ref_text: " + ref_text)
223
-
224
- # Parse reference audio aka prompt
225
- prompt_tokens = encode_reference(
226
- decoder_model=decoder_model,
227
- reference_audio=req.reference_audio,
228
- enable_reference_audio=req.reference_audio is not None,
229
- )
230
- logger.info(f"ref_text: {req.reference_text}")
 
 
 
 
 
 
 
 
 
 
 
 
 
231
  # LLAMA Inference
232
  request = dict(
233
  device=decoder_model.device,
234
  max_new_tokens=req.max_new_tokens,
235
- text=req.text,
 
 
 
 
236
  top_p=req.top_p,
237
  repetition_penalty=req.repetition_penalty,
238
  temperature=req.temperature,
@@ -241,7 +222,7 @@ def inference(req: InvokeRequest):
241
  chunk_length=req.chunk_length,
242
  max_length=2048,
243
  prompt_tokens=prompt_tokens,
244
- prompt_text=req.reference_text,
245
  )
246
 
247
  response_queue = queue.Queue()
@@ -266,7 +247,7 @@ def inference(req: InvokeRequest):
266
  if result.action == "next":
267
  break
268
 
269
- with torch.autocast(
270
  device_type=decoder_model.device.type, dtype=args.precision
271
  ):
272
  fake_audios = decode_vq_tokens(
@@ -294,40 +275,7 @@ def inference(req: InvokeRequest):
294
  yield fake_audios
295
 
296
 
297
- def auto_rerank_inference(req: InvokeRequest, use_auto_rerank: bool = True):
298
- if not use_auto_rerank:
299
- # 如果不使用 auto_rerank,直接调用原始的 inference 函数
300
- return inference(req)
301
-
302
- zh_model, en_model = load_model()
303
- max_attempts = 5
304
- best_wer = float("inf")
305
- best_audio = None
306
-
307
- for attempt in range(max_attempts):
308
- # 调用原始的 inference 函数
309
- audio_generator = inference(req)
310
- fake_audios = next(audio_generator)
311
-
312
- asr_result = batch_asr(
313
- zh_model if is_chinese(req.text) else en_model, [fake_audios], 44100
314
- )[0]
315
- wer = calculate_wer(req.text, asr_result["text"])
316
-
317
- if wer <= 0.1 and not asr_result["huge_gap"]:
318
- return fake_audios
319
-
320
- if wer < best_wer:
321
- best_wer = wer
322
- best_audio = fake_audios
323
-
324
- if attempt == max_attempts - 1:
325
- break
326
-
327
- return best_audio
328
-
329
-
330
- async def inference_async(req: InvokeRequest):
331
  for chunk in inference(req):
332
  yield chunk
333
 
@@ -336,9 +284,9 @@ async def buffer_to_async_generator(buffer):
336
  yield buffer
337
 
338
 
339
- @routes.http.post("/v1/invoke")
340
  async def api_invoke_model(
341
- req: Annotated[InvokeRequest, Body(exclusive=True)],
342
  ):
343
  """
344
  Invoke model and generate audio
@@ -397,19 +345,19 @@ def parse_args():
397
  parser.add_argument(
398
  "--llama-checkpoint-path",
399
  type=str,
400
- default="checkpoints/fish-speech-1.2-sft",
401
  )
402
  parser.add_argument(
403
  "--decoder-checkpoint-path",
404
  type=str,
405
- default="checkpoints/fish-speech-1.2-sft/firefly-gan-vq-fsq-4x1024-42hz-generator.pth",
406
  )
407
  parser.add_argument("--decoder-config-name", type=str, default="firefly_gan_vq")
408
  parser.add_argument("--device", type=str, default="cuda")
409
  parser.add_argument("--half", action="store_true")
410
  parser.add_argument("--compile", action="store_true")
411
  parser.add_argument("--max-text-length", type=int, default=0)
412
- parser.add_argument("--listen", type=str, default="127.0.0.1:8000")
413
  parser.add_argument("--workers", type=int, default=1)
414
  parser.add_argument("--use-auto-rerank", type=bool, default=True)
415
 
@@ -423,18 +371,30 @@ openapi = OpenAPI(
423
  },
424
  ).routes
425
 
 
 
 
 
 
 
 
 
 
 
 
 
426
  app = Kui(
427
  routes=routes + openapi[1:], # Remove the default route
428
  exception_handlers={
429
  HTTPException: http_execption_handler,
430
  Exception: other_exception_handler,
431
  },
 
432
  cors_config={},
433
  )
434
 
435
 
436
  if __name__ == "__main__":
437
- import threading
438
 
439
  import uvicorn
440
 
@@ -461,18 +421,16 @@ if __name__ == "__main__":
461
  # Dry run to check if the model is loaded correctly and avoid the first-time latency
462
  list(
463
  inference(
464
- InvokeRequest(
465
  text="Hello world.",
466
- reference_text=None,
467
- reference_audio=None,
468
  max_new_tokens=0,
469
  top_p=0.7,
470
  repetition_penalty=1.2,
471
  temperature=0.7,
472
  emotion=None,
473
  format="wav",
474
- ref_base=None,
475
- ref_json=None,
476
  )
477
  )
478
  )
 
3
  import json
4
  import queue
5
  import random
6
+ import sys
7
  import traceback
8
  import wave
9
  from argparse import ArgumentParser
10
  from http import HTTPStatus
11
  from pathlib import Path
12
+ from typing import Annotated, Any, Literal, Optional
13
 
 
14
  import numpy as np
15
+ import ormsgpack
16
  import pyrootutils
17
  import soundfile as sf
18
  import torch
19
+ import torchaudio
20
+ from baize.datastructures import ContentType
21
  from kui.asgi import (
22
  Body,
23
+ FactoryClass,
24
  HTTPException,
25
+ HttpRequest,
26
  HttpView,
27
  JSONResponse,
28
  Kui,
 
31
  )
32
  from kui.asgi.routing import MultimethodRoutes
33
  from loguru import logger
34
+ from pydantic import BaseModel, Field, conint
35
 
36
  pyrootutils.setup_root(__file__, indicator=".project-root", pythonpath=True)
37
 
38
  # from fish_speech.models.vqgan.lit_module import VQGAN
39
  from fish_speech.models.vqgan.modules.firefly import FireflyArchitecture
40
+ from fish_speech.text.chn_text_norm.text import Text as ChnNormedText
41
+ from fish_speech.utils import autocast_exclude_mps
42
+ from tools.commons import ServeReferenceAudio, ServeTTSRequest
43
+ from tools.file import AUDIO_EXTENSIONS, audio_to_bytes, list_files, read_ref_text
44
  from tools.llama.generate import (
45
  GenerateRequest,
46
  GenerateResponse,
 
88
 
89
  def load_audio(reference_audio, sr):
90
  if len(reference_audio) > 255 or not Path(reference_audio).exists():
91
+ audio_data = reference_audio
92
+ reference_audio = io.BytesIO(audio_data)
 
 
 
93
 
94
+ waveform, original_sr = torchaudio.load(
95
+ reference_audio, backend="sox" if sys.platform == "linux" else "soundfile"
96
+ )
97
+
98
+ if waveform.shape[0] > 1:
99
+ waveform = torch.mean(waveform, dim=0, keepdim=True)
100
+
101
+ if original_sr != sr:
102
+ resampler = torchaudio.transforms.Resample(orig_freq=original_sr, new_freq=sr)
103
+ waveform = resampler(waveform)
104
+
105
+ audio = waveform.squeeze().numpy()
106
  return audio
107
 
108
 
 
148
  return decoder_model.decode(
149
  indices=codes[None],
150
  feature_lengths=feature_lengths,
151
+ )[0].squeeze()
152
 
153
  raise ValueError(f"Unknown model type: {type(decoder_model)}")
154
 
 
156
  routes = MultimethodRoutes(base_class=HttpView)
157
 
158
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
159
  def get_content_type(audio_format):
160
  if audio_format == "wav":
161
  return "audio/wav"
 
168
 
169
 
170
  @torch.inference_mode()
171
+ def inference(req: ServeTTSRequest):
172
+
173
+ idstr: str | None = req.reference_id
174
+ if idstr is not None:
175
+ ref_folder = Path("references") / idstr
176
+ ref_folder.mkdir(parents=True, exist_ok=True)
177
+ ref_audios = list_files(
178
+ ref_folder, AUDIO_EXTENSIONS, recursive=True, sort=False
179
+ )
180
+ prompt_tokens = [
181
+ encode_reference(
182
+ decoder_model=decoder_model,
183
+ reference_audio=audio_to_bytes(str(ref_audio)),
184
+ enable_reference_audio=True,
185
+ )
186
+ for ref_audio in ref_audios
187
+ ]
188
+ prompt_texts = [
189
+ read_ref_text(str(ref_audio.with_suffix(".lab")))
190
+ for ref_audio in ref_audios
191
+ ]
192
+
193
+ else:
194
+ # Parse reference audio aka prompt
195
+ refs = req.references
196
+ if refs is None:
197
+ refs = []
198
+ prompt_tokens = [
199
+ encode_reference(
200
+ decoder_model=decoder_model,
201
+ reference_audio=ref.audio,
202
+ enable_reference_audio=True,
203
+ )
204
+ for ref in refs
205
+ ]
206
+ prompt_texts = [ref.text for ref in refs]
207
+
208
  # LLAMA Inference
209
  request = dict(
210
  device=decoder_model.device,
211
  max_new_tokens=req.max_new_tokens,
212
+ text=(
213
+ req.text
214
+ if not req.normalize
215
+ else ChnNormedText(raw_text=req.text).normalize()
216
+ ),
217
  top_p=req.top_p,
218
  repetition_penalty=req.repetition_penalty,
219
  temperature=req.temperature,
 
222
  chunk_length=req.chunk_length,
223
  max_length=2048,
224
  prompt_tokens=prompt_tokens,
225
+ prompt_text=prompt_texts,
226
  )
227
 
228
  response_queue = queue.Queue()
 
247
  if result.action == "next":
248
  break
249
 
250
+ with autocast_exclude_mps(
251
  device_type=decoder_model.device.type, dtype=args.precision
252
  ):
253
  fake_audios = decode_vq_tokens(
 
275
  yield fake_audios
276
 
277
 
278
+ async def inference_async(req: ServeTTSRequest):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
279
  for chunk in inference(req):
280
  yield chunk
281
 
 
284
  yield buffer
285
 
286
 
287
+ @routes.http.post("/v1/tts")
288
  async def api_invoke_model(
289
+ req: Annotated[ServeTTSRequest, Body(exclusive=True)],
290
  ):
291
  """
292
  Invoke model and generate audio
 
345
  parser.add_argument(
346
  "--llama-checkpoint-path",
347
  type=str,
348
+ default="checkpoints/fish-speech-1.4",
349
  )
350
  parser.add_argument(
351
  "--decoder-checkpoint-path",
352
  type=str,
353
+ default="checkpoints/fish-speech-1.4/firefly-gan-vq-fsq-8x1024-21hz-generator.pth",
354
  )
355
  parser.add_argument("--decoder-config-name", type=str, default="firefly_gan_vq")
356
  parser.add_argument("--device", type=str, default="cuda")
357
  parser.add_argument("--half", action="store_true")
358
  parser.add_argument("--compile", action="store_true")
359
  parser.add_argument("--max-text-length", type=int, default=0)
360
+ parser.add_argument("--listen", type=str, default="127.0.0.1:8080")
361
  parser.add_argument("--workers", type=int, default=1)
362
  parser.add_argument("--use-auto-rerank", type=bool, default=True)
363
 
 
371
  },
372
  ).routes
373
 
374
+
375
+ class MsgPackRequest(HttpRequest):
376
+ async def data(self) -> Annotated[Any, ContentType("application/msgpack")]:
377
+ if self.content_type == "application/msgpack":
378
+ return ormsgpack.unpackb(await self.body)
379
+
380
+ raise HTTPException(
381
+ HTTPStatus.UNSUPPORTED_MEDIA_TYPE,
382
+ headers={"Accept": "application/msgpack"},
383
+ )
384
+
385
+
386
  app = Kui(
387
  routes=routes + openapi[1:], # Remove the default route
388
  exception_handlers={
389
  HTTPException: http_execption_handler,
390
  Exception: other_exception_handler,
391
  },
392
+ factory_class=FactoryClass(http=MsgPackRequest),
393
  cors_config={},
394
  )
395
 
396
 
397
  if __name__ == "__main__":
 
398
 
399
  import uvicorn
400
 
 
421
  # Dry run to check if the model is loaded correctly and avoid the first-time latency
422
  list(
423
  inference(
424
+ ServeTTSRequest(
425
  text="Hello world.",
426
+ references=[],
427
+ reference_id=None,
428
  max_new_tokens=0,
429
  top_p=0.7,
430
  repetition_penalty=1.2,
431
  temperature=0.7,
432
  emotion=None,
433
  format="wav",
 
 
434
  )
435
  )
436
  )
tools/commons.py ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Annotated, Literal, Optional
2
+
3
+ from pydantic import BaseModel, Field, conint
4
+
5
+
6
+ class ServeReferenceAudio(BaseModel):
7
+ audio: bytes
8
+ text: str
9
+
10
+
11
+ class ServeTTSRequest(BaseModel):
12
+ text: str
13
+ chunk_length: Annotated[int, conint(ge=100, le=300, strict=True)] = 200
14
+ # Audio format
15
+ format: Literal["wav", "pcm", "mp3"] = "wav"
16
+ mp3_bitrate: Literal[64, 128, 192] = 128
17
+ # References audios for in-context learning
18
+ references: list[ServeReferenceAudio] = []
19
+ # Reference id
20
+ # For example, if you want use https://fish.audio/m/7f92f8afb8ec43bf81429cc1c9199cb1/
21
+ # Just pass 7f92f8afb8ec43bf81429cc1c9199cb1
22
+ reference_id: str | None = None
23
+ # Normalize text for en & zh, this increase stability for numbers
24
+ normalize: bool = True
25
+ mp3_bitrate: Optional[int] = 64
26
+ opus_bitrate: Optional[int] = -1000
27
+ # Balance mode will reduce latency to 300ms, but may decrease stability
28
+ latency: Literal["normal", "balanced"] = "normal"
29
+ # not usually used below
30
+ streaming: bool = False
31
+ emotion: Optional[str] = None
32
+ max_new_tokens: int = 1024
33
+ top_p: Annotated[float, Field(ge=0.1, le=1.0, strict=True)] = 0.7
34
+ repetition_penalty: Annotated[float, Field(ge=0.9, le=2.0, strict=True)] = 1.2
35
+ temperature: Annotated[float, Field(ge=0.1, le=1.0, strict=True)] = 0.7
tools/download_models.py ADDED
@@ -0,0 +1,55 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+
3
+ from huggingface_hub import hf_hub_download
4
+
5
+
6
+ # Download
7
+ def check_and_download_files(repo_id, file_list, local_dir):
8
+ os.makedirs(local_dir, exist_ok=True)
9
+ for file in file_list:
10
+ file_path = os.path.join(local_dir, file)
11
+ if not os.path.exists(file_path):
12
+ print(f"{file} 不存在,从 Hugging Face 仓库下载...")
13
+ hf_hub_download(
14
+ repo_id=repo_id,
15
+ filename=file,
16
+ resume_download=True,
17
+ local_dir=local_dir,
18
+ local_dir_use_symlinks=False,
19
+ )
20
+ else:
21
+ print(f"{file} 已存在,跳过下载。")
22
+
23
+
24
+ # 1st
25
+ repo_id_1 = "fishaudio/fish-speech-1.4"
26
+ local_dir_1 = "./checkpoints/fish-speech-1.4"
27
+ files_1 = [
28
+ "model.pth",
29
+ "README.md",
30
+ "special_tokens_map.json",
31
+ "tokenizer_config.json",
32
+ "tokenizer.json",
33
+ "config.json",
34
+ "firefly-gan-vq-fsq-8x1024-21hz-generator.pth",
35
+ ]
36
+
37
+ # 3rd
38
+ repo_id_3 = "fishaudio/fish-speech-1"
39
+ local_dir_3 = "./"
40
+ files_3 = [
41
+ "ffmpeg.exe",
42
+ "ffprobe.exe",
43
+ ]
44
+
45
+ # 4th
46
+ repo_id_4 = "SpicyqSama007/fish-speech-packed"
47
+ local_dir_4 = "./"
48
+ files_4 = [
49
+ "asr-label-win-x64.exe",
50
+ ]
51
+
52
+ check_and_download_files(repo_id_1, files_1, local_dir_1)
53
+
54
+ check_and_download_files(repo_id_3, files_3, local_dir_3)
55
+ check_and_download_files(repo_id_4, files_4, local_dir_4)
tools/extract_model.py ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import click
2
+ import torch
3
+ from loguru import logger
4
+
5
+
6
+ @click.command()
7
+ @click.argument("model_path")
8
+ @click.argument("output_path")
9
+ def main(model_path, output_path):
10
+ if model_path == output_path:
11
+ logger.error("Model path and output path are the same")
12
+ return
13
+
14
+ logger.info(f"Loading model from {model_path}")
15
+ state_dict = torch.load(model_path, map_location="cpu")["state_dict"]
16
+ torch.save(state_dict, output_path)
17
+ logger.info(f"Model saved to {output_path}")
18
+
19
+
20
+ if __name__ == "__main__":
21
+ main()
tools/file.py ADDED
@@ -0,0 +1,125 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import base64
2
+ from pathlib import Path
3
+ from typing import Union
4
+
5
+ from loguru import logger
6
+ from natsort import natsorted
7
+
8
+ AUDIO_EXTENSIONS = {
9
+ ".mp3",
10
+ ".wav",
11
+ ".flac",
12
+ ".ogg",
13
+ ".m4a",
14
+ ".wma",
15
+ ".aac",
16
+ ".aiff",
17
+ ".aif",
18
+ ".aifc",
19
+ }
20
+
21
+ VIDEO_EXTENSIONS = {
22
+ ".mp4",
23
+ ".avi",
24
+ }
25
+
26
+
27
+ def audio_to_bytes(file_path):
28
+ if not file_path or not Path(file_path).exists():
29
+ return None
30
+ with open(file_path, "rb") as wav_file:
31
+ wav = wav_file.read()
32
+ return wav
33
+
34
+
35
+ def read_ref_text(ref_text):
36
+ path = Path(ref_text)
37
+ if path.exists() and path.is_file():
38
+ with path.open("r", encoding="utf-8") as file:
39
+ return file.read()
40
+ return ref_text
41
+
42
+
43
+ def list_files(
44
+ path: Union[Path, str],
45
+ extensions: set[str] = None,
46
+ recursive: bool = False,
47
+ sort: bool = True,
48
+ ) -> list[Path]:
49
+ """List files in a directory.
50
+
51
+ Args:
52
+ path (Path): Path to the directory.
53
+ extensions (set, optional): Extensions to filter. Defaults to None.
54
+ recursive (bool, optional): Whether to search recursively. Defaults to False.
55
+ sort (bool, optional): Whether to sort the files. Defaults to True.
56
+
57
+ Returns:
58
+ list: List of files.
59
+ """
60
+
61
+ if isinstance(path, str):
62
+ path = Path(path)
63
+
64
+ if not path.exists():
65
+ raise FileNotFoundError(f"Directory {path} does not exist.")
66
+
67
+ files = [file for ext in extensions for file in path.rglob(f"*{ext}")]
68
+
69
+ if sort:
70
+ files = natsorted(files)
71
+
72
+ return files
73
+
74
+
75
+ def load_filelist(path: Path | str) -> list[tuple[Path, str, str, str]]:
76
+ """
77
+ Load a Bert-VITS2 style filelist.
78
+ """
79
+
80
+ files = set()
81
+ results = []
82
+ count_duplicated, count_not_found = 0, 0
83
+
84
+ LANGUAGE_TO_LANGUAGES = {
85
+ "zh": ["zh", "en"],
86
+ "jp": ["jp", "en"],
87
+ "en": ["en"],
88
+ }
89
+
90
+ with open(path, "r", encoding="utf-8") as f:
91
+ for line in f.readlines():
92
+ splits = line.strip().split("|", maxsplit=3)
93
+ if len(splits) != 4:
94
+ logger.warning(f"Invalid line: {line}")
95
+ continue
96
+
97
+ filename, speaker, language, text = splits
98
+ file = Path(filename)
99
+ language = language.strip().lower()
100
+
101
+ if language == "ja":
102
+ language = "jp"
103
+
104
+ assert language in ["zh", "jp", "en"], f"Invalid language {language}"
105
+ languages = LANGUAGE_TO_LANGUAGES[language]
106
+
107
+ if file in files:
108
+ logger.warning(f"Duplicated file: {file}")
109
+ count_duplicated += 1
110
+ continue
111
+
112
+ if not file.exists():
113
+ logger.warning(f"File not found: {file}")
114
+ count_not_found += 1
115
+ continue
116
+
117
+ results.append((file, speaker, languages, text))
118
+
119
+ if count_duplicated > 0:
120
+ logger.warning(f"Total duplicated files: {count_duplicated}")
121
+
122
+ if count_not_found > 0:
123
+ logger.warning(f"Total files not found: {count_not_found}")
124
+
125
+ return results
tools/llama/build_dataset.py CHANGED
@@ -13,7 +13,7 @@ from tqdm import tqdm
13
 
14
  from fish_speech.datasets.protos.text_data_pb2 import Semantics, Sentence, TextData
15
  from fish_speech.datasets.protos.text_data_stream import pack_pb_stream
16
- from fish_speech.utils.file import load_filelist
17
 
18
  # To avoid CPU overload
19
  os.environ["MKL_NUM_THREADS"] = "1"
 
13
 
14
  from fish_speech.datasets.protos.text_data_pb2 import Semantics, Sentence, TextData
15
  from fish_speech.datasets.protos.text_data_stream import pack_pb_stream
16
+ from tools.file import load_filelist
17
 
18
  # To avoid CPU overload
19
  os.environ["MKL_NUM_THREADS"] = "1"
tools/llama/generate.py CHANGED
@@ -2,6 +2,7 @@ import os
2
  import queue
3
  import threading
4
  import time
 
5
  from dataclasses import dataclass
6
  from pathlib import Path
7
  from typing import Literal, Optional, Tuple, Union
@@ -93,15 +94,20 @@ def decode_one_token_ar(
93
  **sampling_kwargs,
94
  ) -> torch.Tensor:
95
  x = model.forward_generate(x, input_pos)
 
 
 
 
 
 
96
  codebooks = [
97
  sample(
98
  x.logits,
99
- previous_tokens=(
100
- previous_tokens[0] if previous_tokens is not None else None
101
- ), # Disable repetition penalty for the token codebook
102
- **sampling_kwargs,
103
  )[0]
104
  ]
 
105
  x = x.hidden_states
106
 
107
  # Cleanup the cache
@@ -136,11 +142,16 @@ def decode_one_token_naive(
136
  ) -> torch.Tensor:
137
  x = model.forward_generate(x, input_pos)
138
 
 
 
 
 
 
139
  codebooks = [
140
  sample(
141
- x.token_logits,
142
  previous_tokens=None, # Disable repetition penalty for the token codebook
143
- **sampling_kwargs,
144
  )[0]
145
  ]
146
 
@@ -181,8 +192,12 @@ def decode_n_tokens(
181
  else:
182
  window = previous_tokens[:, i - win_size : i]
183
 
184
- with torch.backends.cuda.sdp_kernel(
185
- enable_flash=False, enable_mem_efficient=False, enable_math=True
 
 
 
 
186
  ): # Actually better for Inductor to codegen attention here
187
  next_token = decode_one_token(
188
  model=model,
@@ -356,7 +371,10 @@ def load_model(checkpoint_path, device, precision, compile=False):
356
  if compile:
357
  logger.info("Compiling function...")
358
  decode_one_token = torch.compile(
359
- decode_one_token, mode="reduce-overhead", fullgraph=True
 
 
 
360
  )
361
 
362
  return model.eval(), decode_one_token
@@ -604,7 +622,7 @@ def launch_thread_safe_queue(
604
  @click.option(
605
  "--checkpoint-path",
606
  type=click.Path(path_type=Path, exists=True),
607
- default="checkpoints/fish-speech-1.2-sft",
608
  )
609
  @click.option("--device", type=str, default="cuda")
610
  @click.option("--compile/--no-compile", default=False)
 
2
  import queue
3
  import threading
4
  import time
5
+ from contextlib import nullcontext
6
  from dataclasses import dataclass
7
  from pathlib import Path
8
  from typing import Literal, Optional, Tuple, Union
 
94
  **sampling_kwargs,
95
  ) -> torch.Tensor:
96
  x = model.forward_generate(x, input_pos)
97
+
98
+ sampling_kwargs_main = sampling_kwargs.copy()
99
+ sampling_kwargs_main["temperature"] = 0.1
100
+ sampling_kwargs_main["top_p"] = 0.1
101
+ sampling_kwargs_main["repetition_penalty"] = 1.0
102
+
103
  codebooks = [
104
  sample(
105
  x.logits,
106
+ previous_tokens=None, # Disable repetition penalty for the token codebook
107
+ **sampling_kwargs_main,
 
 
108
  )[0]
109
  ]
110
+
111
  x = x.hidden_states
112
 
113
  # Cleanup the cache
 
142
  ) -> torch.Tensor:
143
  x = model.forward_generate(x, input_pos)
144
 
145
+ sampling_kwargs_main = sampling_kwargs.copy()
146
+ sampling_kwargs_main["temperature"] = 0.1
147
+ sampling_kwargs_main["top_p"] = 0.1
148
+ sampling_kwargs_main["repetition_penalty"] = 1.0
149
+
150
  codebooks = [
151
  sample(
152
+ x.logits,
153
  previous_tokens=None, # Disable repetition penalty for the token codebook
154
+ **sampling_kwargs_main,
155
  )[0]
156
  ]
157
 
 
192
  else:
193
  window = previous_tokens[:, i - win_size : i]
194
 
195
+ with (
196
+ torch.backends.cuda.sdp_kernel(
197
+ enable_flash=False, enable_mem_efficient=False, enable_math=True
198
+ )
199
+ if torch.cuda.is_available()
200
+ else nullcontext()
201
  ): # Actually better for Inductor to codegen attention here
202
  next_token = decode_one_token(
203
  model=model,
 
371
  if compile:
372
  logger.info("Compiling function...")
373
  decode_one_token = torch.compile(
374
+ decode_one_token,
375
+ fullgraph=True,
376
+ backend="inductor" if torch.cuda.is_available() else "aot_eager",
377
+ mode="reduce-overhead" if torch.cuda.is_available() else None,
378
  )
379
 
380
  return model.eval(), decode_one_token
 
622
  @click.option(
623
  "--checkpoint-path",
624
  type=click.Path(path_type=Path, exists=True),
625
+ default="checkpoints/fish-speech-1.4",
626
  )
627
  @click.option("--device", type=str, default="cuda")
628
  @click.option("--compile/--no-compile", default=False)
tools/llama/merge_lora.py CHANGED
@@ -15,7 +15,7 @@ from fish_speech.models.text2semantic.lora import get_merged_state_dict
15
 
16
  @click.command()
17
  @click.option("--lora-config", type=str, default="r_8_alpha_16")
18
- @click.option("--base-weight", type=str, default="checkpoints/fish-speech-1.2-sft")
19
  @click.option("--lora-weight", type=str, required=True)
20
  @click.option("--output", type=str, required=True)
21
  def merge(lora_config, base_weight, lora_weight, output):
 
15
 
16
  @click.command()
17
  @click.option("--lora-config", type=str, default="r_8_alpha_16")
18
+ @click.option("--base-weight", type=str, default="checkpoints/fish-speech-1.4")
19
  @click.option("--lora-weight", type=str, required=True)
20
  @click.option("--output", type=str, required=True)
21
  def merge(lora_config, base_weight, lora_weight, output):
tools/llama/quantize.py CHANGED
@@ -428,7 +428,7 @@ def generate_folder_name():
428
  @click.option(
429
  "--checkpoint-path",
430
  type=click.Path(path_type=Path, exists=True),
431
- default="checkpoints/fish-speech-1.2-sft",
432
  )
433
  @click.option(
434
  "--mode", type=str, default="int8", help="type of quantization to perform"
@@ -451,7 +451,7 @@ def quantize(checkpoint_path: Path, mode: str, groupsize: int, timestamp: str) -
451
  precision=precision,
452
  compile=False,
453
  )
454
- vq_model = "firefly-gan-vq-fsq-4x1024-42hz-generator.pth"
455
  now = timestamp if timestamp != "None" else generate_folder_name()
456
 
457
  if mode == "int8":
 
428
  @click.option(
429
  "--checkpoint-path",
430
  type=click.Path(path_type=Path, exists=True),
431
+ default="checkpoints/fish-speech-1.4",
432
  )
433
  @click.option(
434
  "--mode", type=str, default="int8", help="type of quantization to perform"
 
451
  precision=precision,
452
  compile=False,
453
  )
454
+ vq_model = "firefly-gan-vq-fsq-8x1024-21hz-generator.pth"
455
  now = timestamp if timestamp != "None" else generate_folder_name()
456
 
457
  if mode == "int8":
tools/msgpack_api.py ADDED
@@ -0,0 +1,34 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import httpx
2
+ import ormsgpack
3
+
4
+ from tools.commons import ServeReferenceAudio, ServeTTSRequest
5
+
6
+ # priority: ref_id > references
7
+ request = ServeTTSRequest(
8
+ text="你说的对, 但是原神是一款由米哈游自主研发的开放世界手游.",
9
+ # reference_id="114514",
10
+ references=[
11
+ ServeReferenceAudio(
12
+ audio=open("lengyue.wav", "rb").read(),
13
+ text=open("lengyue.lab", "r", encoding="utf-8").read(),
14
+ )
15
+ ],
16
+ streaming=True,
17
+ )
18
+
19
+ with (
20
+ httpx.Client() as client,
21
+ open("hello.wav", "wb") as f,
22
+ ):
23
+ with client.stream(
24
+ "POST",
25
+ "http://127.0.0.1:8080/v1/tts",
26
+ content=ormsgpack.packb(request, option=ormsgpack.OPT_SERIALIZE_PYDANTIC),
27
+ headers={
28
+ "authorization": "Bearer YOUR_API_KEY",
29
+ "content-type": "application/msgpack",
30
+ },
31
+ timeout=None,
32
+ ) as response:
33
+ for chunk in response.iter_bytes():
34
+ f.write(chunk)
tools/post_api.py ADDED
@@ -0,0 +1,205 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import base64
3
+ import wave
4
+
5
+ import ormsgpack
6
+ import pyaudio
7
+ import requests
8
+ from pydub import AudioSegment
9
+ from pydub.playback import play
10
+
11
+ from tools.commons import ServeReferenceAudio, ServeTTSRequest
12
+ from tools.file import audio_to_bytes, read_ref_text
13
+
14
+
15
+ def parse_args():
16
+
17
+ parser = argparse.ArgumentParser(
18
+ description="Send a WAV file and text to a server and receive synthesized audio."
19
+ )
20
+
21
+ parser.add_argument(
22
+ "--url",
23
+ "-u",
24
+ type=str,
25
+ default="http://127.0.0.1:8080/v1/tts",
26
+ help="URL of the server",
27
+ )
28
+ parser.add_argument(
29
+ "--text", "-t", type=str, required=True, help="Text to be synthesized"
30
+ )
31
+ parser.add_argument(
32
+ "--reference_id",
33
+ "-id",
34
+ type=str,
35
+ default=None,
36
+ help="ID of the reference model o be used for the speech",
37
+ )
38
+ parser.add_argument(
39
+ "--reference_audio",
40
+ "-ra",
41
+ type=str,
42
+ nargs="+",
43
+ default=None,
44
+ help="Path to the WAV file",
45
+ )
46
+ parser.add_argument(
47
+ "--reference_text",
48
+ "-rt",
49
+ type=str,
50
+ nargs="+",
51
+ default=None,
52
+ help="Reference text for voice synthesis",
53
+ )
54
+ parser.add_argument(
55
+ "--output",
56
+ "-o",
57
+ type=str,
58
+ default="generated_audio",
59
+ help="Output audio file name",
60
+ )
61
+ parser.add_argument(
62
+ "--play",
63
+ type=bool,
64
+ default=True,
65
+ help="Whether to play audio after receiving data",
66
+ )
67
+ parser.add_argument("--normalize", type=bool, default=True)
68
+ parser.add_argument(
69
+ "--format", type=str, choices=["wav", "mp3", "flac"], default="wav"
70
+ )
71
+ parser.add_argument("--mp3_bitrate", type=int, default=64)
72
+ parser.add_argument("--opus_bitrate", type=int, default=-1000)
73
+ parser.add_argument("--latency", type=str, default="normal", help="延迟选项")
74
+ parser.add_argument(
75
+ "--max_new_tokens",
76
+ type=int,
77
+ default=1024,
78
+ help="Maximum new tokens to generate",
79
+ )
80
+ parser.add_argument(
81
+ "--chunk_length", type=int, default=100, help="Chunk length for synthesis"
82
+ )
83
+ parser.add_argument(
84
+ "--top_p", type=float, default=0.7, help="Top-p sampling for synthesis"
85
+ )
86
+ parser.add_argument(
87
+ "--repetition_penalty",
88
+ type=float,
89
+ default=1.2,
90
+ help="Repetition penalty for synthesis",
91
+ )
92
+ parser.add_argument(
93
+ "--temperature", type=float, default=0.7, help="Temperature for sampling"
94
+ )
95
+ parser.add_argument(
96
+ "--speaker", type=str, default=None, help="Speaker ID for voice synthesis"
97
+ )
98
+ parser.add_argument("--emotion", type=str, default=None, help="Speaker's Emotion")
99
+ parser.add_argument(
100
+ "--streaming", type=bool, default=False, help="Enable streaming response"
101
+ )
102
+ parser.add_argument(
103
+ "--channels", type=int, default=1, help="Number of audio channels"
104
+ )
105
+ parser.add_argument("--rate", type=int, default=44100, help="Sample rate for audio")
106
+
107
+ return parser.parse_args()
108
+
109
+
110
+ if __name__ == "__main__":
111
+
112
+ args = parse_args()
113
+
114
+ idstr: str | None = args.reference_id
115
+ # priority: ref_id > [{text, audio},...]
116
+ if idstr is None:
117
+ ref_audios = args.reference_audio
118
+ ref_texts = args.reference_text
119
+ if ref_audios is None:
120
+ byte_audios = []
121
+ else:
122
+ byte_audios = [audio_to_bytes(ref_audio) for ref_audio in ref_audios]
123
+ if ref_texts is None:
124
+ ref_texts = []
125
+ else:
126
+ ref_texts = [read_ref_text(ref_text) for ref_text in ref_texts]
127
+ else:
128
+ byte_audios = []
129
+ ref_texts = []
130
+ pass # in api.py
131
+
132
+ data = {
133
+ "text": args.text,
134
+ "references": [
135
+ ServeReferenceAudio(audio=ref_audio, text=ref_text)
136
+ for ref_text, ref_audio in zip(ref_texts, byte_audios)
137
+ ],
138
+ "reference_id": idstr,
139
+ "normalize": args.normalize,
140
+ "format": args.format,
141
+ "mp3_bitrate": args.mp3_bitrate,
142
+ "opus_bitrate": args.opus_bitrate,
143
+ "max_new_tokens": args.max_new_tokens,
144
+ "chunk_length": args.chunk_length,
145
+ "top_p": args.top_p,
146
+ "repetition_penalty": args.repetition_penalty,
147
+ "temperature": args.temperature,
148
+ "speaker": args.speaker,
149
+ "emotion": args.emotion,
150
+ "streaming": args.streaming,
151
+ }
152
+
153
+ pydantic_data = ServeTTSRequest(**data)
154
+
155
+ response = requests.post(
156
+ args.url,
157
+ data=ormsgpack.packb(pydantic_data, option=ormsgpack.OPT_SERIALIZE_PYDANTIC),
158
+ stream=args.streaming,
159
+ headers={
160
+ "authorization": "Bearer YOUR_API_KEY",
161
+ "content-type": "application/msgpack",
162
+ },
163
+ )
164
+
165
+ if response.status_code == 200:
166
+ if args.streaming:
167
+ p = pyaudio.PyAudio()
168
+ audio_format = pyaudio.paInt16 # Assuming 16-bit PCM format
169
+ stream = p.open(
170
+ format=audio_format, channels=args.channels, rate=args.rate, output=True
171
+ )
172
+
173
+ wf = wave.open(f"{args.output}.wav", "wb")
174
+ wf.setnchannels(args.channels)
175
+ wf.setsampwidth(p.get_sample_size(audio_format))
176
+ wf.setframerate(args.rate)
177
+
178
+ stream_stopped_flag = False
179
+
180
+ try:
181
+ for chunk in response.iter_content(chunk_size=1024):
182
+ if chunk:
183
+ stream.write(chunk)
184
+ wf.writeframesraw(chunk)
185
+ else:
186
+ if not stream_stopped_flag:
187
+ stream.stop_stream()
188
+ stream_stopped_flag = True
189
+ finally:
190
+ stream.close()
191
+ p.terminate()
192
+ wf.close()
193
+ else:
194
+ audio_content = response.content
195
+ audio_path = f"{args.output}.{args.format}"
196
+ with open(audio_path, "wb") as audio_file:
197
+ audio_file.write(audio_content)
198
+
199
+ audio = AudioSegment.from_file(audio_path, format=args.format)
200
+ if args.play:
201
+ play(audio)
202
+ print(f"Audio has been saved to '{audio_path}'.")
203
+ else:
204
+ print(f"Request failed with status code {response.status_code}")
205
+ print(response.json())
tools/sensevoice/README.md ADDED
@@ -0,0 +1,59 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # FunASR Command Line Interface
2
+
3
+ This tool provides a command-line interface for separating vocals from instrumental tracks, converting videos to audio, and performing speech-to-text transcription on the resulting audio files.
4
+
5
+ ## Requirements
6
+
7
+ - Python >= 3.10
8
+ - PyTorch <= 2.3.1
9
+ - ffmpeg, pydub, audio-separator[gpu].
10
+
11
+ ## Installation
12
+
13
+ Install the required packages:
14
+
15
+ ```bash
16
+ pip install -e .[stable]
17
+ ```
18
+
19
+ Make sure you have `ffmpeg` installed and available in your `PATH`.
20
+
21
+ ## Usage
22
+
23
+ ### Basic Usage
24
+
25
+ To run the tool with default settings:
26
+
27
+ ```bash
28
+ python tools/sensevoice/fun_asr.py --audio-dir <audio_directory> --save-dir <output_directory>
29
+ ```
30
+
31
+ ## Options
32
+
33
+ | Option | Description |
34
+ | :-----------------------: | :---------------------------------------------------------------------------: |
35
+ | --audio-dir | Directory containing audio or video files. |
36
+ | --save-dir | Directory to save processed audio files. |
37
+ | --device | Device to use for processing. Options: cuda (default) or cpu. |
38
+ | --language | Language of the transcription. Default is auto. |
39
+ | --max_single_segment_time | Maximum duration of a single audio segment in milliseconds. Default is 20000. |
40
+ | --punc | Enable punctuation prediction. |
41
+ | --denoise | Enable noise reduction (vocal separation). |
42
+
43
+ ## Example
44
+
45
+ To process audio files in the directory `path/to/audio` and save the output to `path/to/output`, with punctuation and noise reduction enabled:
46
+
47
+ ```bash
48
+ python tools/sensevoice/fun_asr.py --audio-dir path/to/audio --save-dir path/to/output --punc --denoise
49
+ ```
50
+
51
+ ## Additional Notes
52
+
53
+ - The tool supports `both audio and video files`. Videos will be converted to audio automatically.
54
+ - If the `--denoise` option is used, the tool will perform vocal separation to isolate the vocals from the instrumental tracks.
55
+ - The script will automatically create necessary directories in the `--save-dir`.
56
+
57
+ ## Troubleshooting
58
+
59
+ If you encounter any issues, make sure all dependencies are correctly installed and configured. For more detailed troubleshooting, refer to the documentation of each dependency.
tools/sensevoice/__init__.py ADDED
File without changes
tools/sensevoice/auto_model.py ADDED
@@ -0,0 +1,573 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ # -*- encoding: utf-8 -*-
3
+ # Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
4
+ # MIT License (https://opensource.org/licenses/MIT)
5
+
6
+ import copy
7
+ import json
8
+ import logging
9
+ import os.path
10
+ import random
11
+ import re
12
+ import string
13
+ import time
14
+
15
+ import numpy as np
16
+ import torch
17
+ from funasr.download.download_model_from_hub import download_model
18
+ from funasr.download.file import download_from_url
19
+ from funasr.register import tables
20
+ from funasr.train_utils.load_pretrained_model import load_pretrained_model
21
+ from funasr.train_utils.set_all_random_seed import set_all_random_seed
22
+ from funasr.utils import export_utils, misc
23
+ from funasr.utils.load_utils import load_audio_text_image_video, load_bytes
24
+ from funasr.utils.misc import deep_update
25
+ from funasr.utils.timestamp_tools import timestamp_sentence, timestamp_sentence_en
26
+ from tqdm import tqdm
27
+
28
+ from .vad_utils import merge_vad, slice_padding_audio_samples
29
+
30
+ try:
31
+ from funasr.models.campplus.cluster_backend import ClusterBackend
32
+ from funasr.models.campplus.utils import distribute_spk, postprocess, sv_chunk
33
+ except:
34
+ pass
35
+
36
+
37
+ def prepare_data_iterator(data_in, input_len=None, data_type=None, key=None):
38
+ """ """
39
+ data_list = []
40
+ key_list = []
41
+ filelist = [".scp", ".txt", ".json", ".jsonl", ".text"]
42
+
43
+ chars = string.ascii_letters + string.digits
44
+ if isinstance(data_in, str):
45
+ if data_in.startswith("http://") or data_in.startswith("https://"): # url
46
+ data_in = download_from_url(data_in)
47
+
48
+ if isinstance(data_in, str) and os.path.exists(
49
+ data_in
50
+ ): # wav_path; filelist: wav.scp, file.jsonl;text.txt;
51
+ _, file_extension = os.path.splitext(data_in)
52
+ file_extension = file_extension.lower()
53
+ if file_extension in filelist: # filelist: wav.scp, file.jsonl;text.txt;
54
+ with open(data_in, encoding="utf-8") as fin:
55
+ for line in fin:
56
+ key = "rand_key_" + "".join(random.choice(chars) for _ in range(13))
57
+ if data_in.endswith(
58
+ ".jsonl"
59
+ ): # file.jsonl: json.dumps({"source": data})
60
+ lines = json.loads(line.strip())
61
+ data = lines["source"]
62
+ key = data["key"] if "key" in data else key
63
+ else: # filelist, wav.scp, text.txt: id \t data or data
64
+ lines = line.strip().split(maxsplit=1)
65
+ data = lines[1] if len(lines) > 1 else lines[0]
66
+ key = lines[0] if len(lines) > 1 else key
67
+
68
+ data_list.append(data)
69
+ key_list.append(key)
70
+ else:
71
+ if key is None:
72
+ # key = "rand_key_" + "".join(random.choice(chars) for _ in range(13))
73
+ key = misc.extract_filename_without_extension(data_in)
74
+ data_list = [data_in]
75
+ key_list = [key]
76
+ elif isinstance(data_in, (list, tuple)):
77
+ if data_type is not None and isinstance(
78
+ data_type, (list, tuple)
79
+ ): # mutiple inputs
80
+ data_list_tmp = []
81
+ for data_in_i, data_type_i in zip(data_in, data_type):
82
+ key_list, data_list_i = prepare_data_iterator(
83
+ data_in=data_in_i, data_type=data_type_i
84
+ )
85
+ data_list_tmp.append(data_list_i)
86
+ data_list = []
87
+ for item in zip(*data_list_tmp):
88
+ data_list.append(item)
89
+ else:
90
+ # [audio sample point, fbank, text]
91
+ data_list = data_in
92
+ key_list = []
93
+ for data_i in data_in:
94
+ if isinstance(data_i, str) and os.path.exists(data_i):
95
+ key = misc.extract_filename_without_extension(data_i)
96
+ else:
97
+ if key is None:
98
+ key = "rand_key_" + "".join(
99
+ random.choice(chars) for _ in range(13)
100
+ )
101
+ key_list.append(key)
102
+
103
+ else: # raw text; audio sample point, fbank; bytes
104
+ if isinstance(data_in, bytes): # audio bytes
105
+ data_in = load_bytes(data_in)
106
+ if key is None:
107
+ key = "rand_key_" + "".join(random.choice(chars) for _ in range(13))
108
+ data_list = [data_in]
109
+ key_list = [key]
110
+
111
+ return key_list, data_list
112
+
113
+
114
+ class AutoModel:
115
+
116
+ def __init__(self, **kwargs):
117
+
118
+ try:
119
+ from funasr.utils.version_checker import check_for_update
120
+
121
+ print(
122
+ "Check update of funasr, and it would cost few times. You may disable it by set `disable_update=True` in AutoModel"
123
+ )
124
+ check_for_update(disable=kwargs.get("disable_update", False))
125
+ except:
126
+ pass
127
+
128
+ log_level = getattr(logging, kwargs.get("log_level", "INFO").upper())
129
+ logging.basicConfig(level=log_level)
130
+
131
+ model, kwargs = self.build_model(**kwargs)
132
+
133
+ # if vad_model is not None, build vad model else None
134
+ vad_model = kwargs.get("vad_model", None)
135
+ vad_kwargs = (
136
+ {} if kwargs.get("vad_kwargs", {}) is None else kwargs.get("vad_kwargs", {})
137
+ )
138
+ if vad_model is not None:
139
+ logging.info("Building VAD model.")
140
+ vad_kwargs["model"] = vad_model
141
+ vad_kwargs["model_revision"] = kwargs.get("vad_model_revision", "master")
142
+ vad_kwargs["device"] = kwargs["device"]
143
+ vad_model, vad_kwargs = self.build_model(**vad_kwargs)
144
+
145
+ # if punc_model is not None, build punc model else None
146
+ punc_model = kwargs.get("punc_model", None)
147
+ punc_kwargs = (
148
+ {}
149
+ if kwargs.get("punc_kwargs", {}) is None
150
+ else kwargs.get("punc_kwargs", {})
151
+ )
152
+ if punc_model is not None:
153
+ logging.info("Building punc model.")
154
+ punc_kwargs["model"] = punc_model
155
+ punc_kwargs["model_revision"] = kwargs.get("punc_model_revision", "master")
156
+ punc_kwargs["device"] = kwargs["device"]
157
+ punc_model, punc_kwargs = self.build_model(**punc_kwargs)
158
+
159
+ # if spk_model is not None, build spk model else None
160
+ spk_model = kwargs.get("spk_model", None)
161
+ spk_kwargs = (
162
+ {} if kwargs.get("spk_kwargs", {}) is None else kwargs.get("spk_kwargs", {})
163
+ )
164
+ if spk_model is not None:
165
+ logging.info("Building SPK model.")
166
+ spk_kwargs["model"] = spk_model
167
+ spk_kwargs["model_revision"] = kwargs.get("spk_model_revision", "master")
168
+ spk_kwargs["device"] = kwargs["device"]
169
+ spk_model, spk_kwargs = self.build_model(**spk_kwargs)
170
+ self.cb_model = ClusterBackend().to(kwargs["device"])
171
+ spk_mode = kwargs.get("spk_mode", "punc_segment")
172
+ if spk_mode not in ["default", "vad_segment", "punc_segment"]:
173
+ logging.error(
174
+ "spk_mode should be one of default, vad_segment and punc_segment."
175
+ )
176
+ self.spk_mode = spk_mode
177
+
178
+ self.kwargs = kwargs
179
+ self.model = model
180
+ self.vad_model = vad_model
181
+ self.vad_kwargs = vad_kwargs
182
+ self.punc_model = punc_model
183
+ self.punc_kwargs = punc_kwargs
184
+ self.spk_model = spk_model
185
+ self.spk_kwargs = spk_kwargs
186
+ self.model_path = kwargs.get("model_path")
187
+
188
+ @staticmethod
189
+ def build_model(**kwargs):
190
+ assert "model" in kwargs
191
+ if "model_conf" not in kwargs:
192
+ logging.info(
193
+ "download models from model hub: {}".format(kwargs.get("hub", "ms"))
194
+ )
195
+ kwargs = download_model(**kwargs)
196
+
197
+ set_all_random_seed(kwargs.get("seed", 0))
198
+
199
+ device = kwargs.get("device", "cuda")
200
+ if not torch.cuda.is_available() or kwargs.get("ngpu", 1) == 0:
201
+ device = "cpu"
202
+ kwargs["batch_size"] = 1
203
+ kwargs["device"] = device
204
+
205
+ torch.set_num_threads(kwargs.get("ncpu", 4))
206
+
207
+ # build tokenizer
208
+ tokenizer = kwargs.get("tokenizer", None)
209
+ if tokenizer is not None:
210
+ tokenizer_class = tables.tokenizer_classes.get(tokenizer)
211
+ tokenizer = tokenizer_class(**kwargs.get("tokenizer_conf", {}))
212
+ kwargs["token_list"] = (
213
+ tokenizer.token_list if hasattr(tokenizer, "token_list") else None
214
+ )
215
+ kwargs["token_list"] = (
216
+ tokenizer.get_vocab()
217
+ if hasattr(tokenizer, "get_vocab")
218
+ else kwargs["token_list"]
219
+ )
220
+ vocab_size = (
221
+ len(kwargs["token_list"]) if kwargs["token_list"] is not None else -1
222
+ )
223
+ if vocab_size == -1 and hasattr(tokenizer, "get_vocab_size"):
224
+ vocab_size = tokenizer.get_vocab_size()
225
+ else:
226
+ vocab_size = -1
227
+ kwargs["tokenizer"] = tokenizer
228
+
229
+ # build frontend
230
+ frontend = kwargs.get("frontend", None)
231
+ kwargs["input_size"] = None
232
+ if frontend is not None:
233
+ frontend_class = tables.frontend_classes.get(frontend)
234
+ frontend = frontend_class(**kwargs.get("frontend_conf", {}))
235
+ kwargs["input_size"] = (
236
+ frontend.output_size() if hasattr(frontend, "output_size") else None
237
+ )
238
+ kwargs["frontend"] = frontend
239
+ # build model
240
+ model_class = tables.model_classes.get(kwargs["model"])
241
+ assert model_class is not None, f'{kwargs["model"]} is not registered'
242
+ model_conf = {}
243
+ deep_update(model_conf, kwargs.get("model_conf", {}))
244
+ deep_update(model_conf, kwargs)
245
+ model = model_class(**model_conf, vocab_size=vocab_size)
246
+
247
+ # init_param
248
+ init_param = kwargs.get("init_param", None)
249
+ if init_param is not None:
250
+ if os.path.exists(init_param):
251
+ logging.info(f"Loading pretrained params from {init_param}")
252
+ load_pretrained_model(
253
+ model=model,
254
+ path=init_param,
255
+ ignore_init_mismatch=kwargs.get("ignore_init_mismatch", True),
256
+ oss_bucket=kwargs.get("oss_bucket", None),
257
+ scope_map=kwargs.get("scope_map", []),
258
+ excludes=kwargs.get("excludes", None),
259
+ )
260
+ else:
261
+ print(f"error, init_param does not exist!: {init_param}")
262
+
263
+ # fp16
264
+ if kwargs.get("fp16", False):
265
+ model.to(torch.float16)
266
+ elif kwargs.get("bf16", False):
267
+ model.to(torch.bfloat16)
268
+ model.to(device)
269
+
270
+ if not kwargs.get("disable_log", True):
271
+ tables.print()
272
+
273
+ return model, kwargs
274
+
275
+ def __call__(self, *args, **cfg):
276
+ kwargs = self.kwargs
277
+ deep_update(kwargs, cfg)
278
+ res = self.model(*args, kwargs)
279
+ return res
280
+
281
+ def generate(self, input, input_len=None, **cfg):
282
+ if self.vad_model is None:
283
+ return self.inference(input, input_len=input_len, **cfg)
284
+
285
+ else:
286
+ return self.inference_with_vad(input, input_len=input_len, **cfg)
287
+
288
+ def inference(
289
+ self, input, input_len=None, model=None, kwargs=None, key=None, **cfg
290
+ ):
291
+ kwargs = self.kwargs if kwargs is None else kwargs
292
+ if "cache" in kwargs:
293
+ kwargs.pop("cache")
294
+ deep_update(kwargs, cfg)
295
+ model = self.model if model is None else model
296
+ model.eval()
297
+
298
+ batch_size = kwargs.get("batch_size", 1)
299
+ # if kwargs.get("device", "cpu") == "cpu":
300
+ # batch_size = 1
301
+
302
+ key_list, data_list = prepare_data_iterator(
303
+ input, input_len=input_len, data_type=kwargs.get("data_type", None), key=key
304
+ )
305
+
306
+ speed_stats = {}
307
+ asr_result_list = []
308
+ num_samples = len(data_list)
309
+ disable_pbar = self.kwargs.get("disable_pbar", False)
310
+ pbar = (
311
+ tqdm(colour="blue", total=num_samples, dynamic_ncols=True)
312
+ if not disable_pbar
313
+ else None
314
+ )
315
+ time_speech_total = 0.0
316
+ time_escape_total = 0.0
317
+ for beg_idx in range(0, num_samples, batch_size):
318
+ end_idx = min(num_samples, beg_idx + batch_size)
319
+ data_batch = data_list[beg_idx:end_idx]
320
+ key_batch = key_list[beg_idx:end_idx]
321
+ batch = {"data_in": data_batch, "key": key_batch}
322
+
323
+ if (end_idx - beg_idx) == 1 and kwargs.get(
324
+ "data_type", None
325
+ ) == "fbank": # fbank
326
+ batch["data_in"] = data_batch[0]
327
+ batch["data_lengths"] = input_len
328
+
329
+ time1 = time.perf_counter()
330
+ with torch.no_grad():
331
+ res = model.inference(**batch, **kwargs)
332
+ if isinstance(res, (list, tuple)):
333
+ results = res[0] if len(res) > 0 else [{"text": ""}]
334
+ meta_data = res[1] if len(res) > 1 else {}
335
+ time2 = time.perf_counter()
336
+
337
+ asr_result_list.extend(results)
338
+
339
+ # batch_data_time = time_per_frame_s * data_batch_i["speech_lengths"].sum().item()
340
+ batch_data_time = meta_data.get("batch_data_time", -1)
341
+ time_escape = time2 - time1
342
+ speed_stats["load_data"] = meta_data.get("load_data", 0.0)
343
+ speed_stats["extract_feat"] = meta_data.get("extract_feat", 0.0)
344
+ speed_stats["forward"] = f"{time_escape:0.3f}"
345
+ speed_stats["batch_size"] = f"{len(results)}"
346
+ speed_stats["rtf"] = f"{(time_escape) / batch_data_time:0.3f}"
347
+ description = f"{speed_stats}, "
348
+ if pbar:
349
+ pbar.update(end_idx - beg_idx)
350
+ pbar.set_description(description)
351
+ time_speech_total += batch_data_time
352
+ time_escape_total += time_escape
353
+
354
+ if pbar:
355
+ # pbar.update(1)
356
+ pbar.set_description(f"rtf_avg: {time_escape_total/time_speech_total:0.3f}")
357
+ torch.cuda.empty_cache()
358
+ return asr_result_list
359
+
360
+ def vad(self, input, input_len=None, **cfg):
361
+ kwargs = self.kwargs
362
+ # step.1: compute the vad model
363
+ deep_update(self.vad_kwargs, cfg)
364
+ beg_vad = time.time()
365
+ res = self.inference(
366
+ input,
367
+ input_len=input_len,
368
+ model=self.vad_model,
369
+ kwargs=self.vad_kwargs,
370
+ **cfg,
371
+ )
372
+ end_vad = time.time()
373
+ # FIX(gcf): concat the vad clips for sense vocie model for better aed
374
+ if cfg.get("merge_vad", False):
375
+ for i in range(len(res)):
376
+ res[i]["value"] = merge_vad(
377
+ res[i]["value"], kwargs.get("merge_length_s", 15) * 1000
378
+ )
379
+ elapsed = end_vad - beg_vad
380
+ return elapsed, res
381
+
382
+ def inference_with_vadres(self, input, vad_res, input_len=None, **cfg):
383
+
384
+ kwargs = self.kwargs
385
+
386
+ # step.2 compute asr model
387
+ model = self.model
388
+ deep_update(kwargs, cfg)
389
+ batch_size = max(int(kwargs.get("batch_size_s", 300)) * 1000, 1)
390
+ batch_size_threshold_ms = int(kwargs.get("batch_size_threshold_s", 60)) * 1000
391
+ kwargs["batch_size"] = batch_size
392
+
393
+ key_list, data_list = prepare_data_iterator(
394
+ input, input_len=input_len, data_type=kwargs.get("data_type", None)
395
+ )
396
+ results_ret_list = []
397
+ time_speech_total_all_samples = 1e-6
398
+
399
+ beg_total = time.time()
400
+ pbar_total = (
401
+ tqdm(colour="red", total=len(vad_res), dynamic_ncols=True)
402
+ if not kwargs.get("disable_pbar", False)
403
+ else None
404
+ )
405
+
406
+ for i in range(len(vad_res)):
407
+ key = vad_res[i]["key"]
408
+ vadsegments = vad_res[i]["value"]
409
+ input_i = data_list[i]
410
+ fs = kwargs["frontend"].fs if hasattr(kwargs["frontend"], "fs") else 16000
411
+ speech = load_audio_text_image_video(
412
+ input_i, fs=fs, audio_fs=kwargs.get("fs", 16000)
413
+ )
414
+ speech_lengths = len(speech)
415
+ n = len(vadsegments)
416
+ data_with_index = [(vadsegments[i], i) for i in range(n)]
417
+ sorted_data = sorted(data_with_index, key=lambda x: x[0][1] - x[0][0])
418
+ results_sorted = []
419
+
420
+ if not len(sorted_data):
421
+ results_ret_list.append({"key": key, "text": "", "timestamp": []})
422
+ logging.info("decoding, utt: {}, empty speech".format(key))
423
+ continue
424
+
425
+ if len(sorted_data) > 0 and len(sorted_data[0]) > 0:
426
+ batch_size = max(
427
+ batch_size, sorted_data[0][0][1] - sorted_data[0][0][0]
428
+ )
429
+
430
+ if kwargs["device"] == "cpu":
431
+ batch_size = 0
432
+
433
+ beg_idx = 0
434
+ beg_asr_total = time.time()
435
+ time_speech_total_per_sample = speech_lengths / 16000
436
+ time_speech_total_all_samples += time_speech_total_per_sample
437
+
438
+ # pbar_sample = tqdm(colour="blue", total=n, dynamic_ncols=True)
439
+
440
+ all_segments = []
441
+ max_len_in_batch = 0
442
+ end_idx = 1
443
+
444
+ for j, _ in enumerate(range(0, n)):
445
+ # pbar_sample.update(1)
446
+ sample_length = sorted_data[j][0][1] - sorted_data[j][0][0]
447
+ potential_batch_length = max(max_len_in_batch, sample_length) * (
448
+ j + 1 - beg_idx
449
+ )
450
+ # batch_size_ms_cum += sorted_data[j][0][1] - sorted_data[j][0][0]
451
+ if (
452
+ j < n - 1
453
+ and sample_length < batch_size_threshold_ms
454
+ and potential_batch_length < batch_size
455
+ ):
456
+ max_len_in_batch = max(max_len_in_batch, sample_length)
457
+ end_idx += 1
458
+ continue
459
+
460
+ speech_j, speech_lengths_j, intervals = slice_padding_audio_samples(
461
+ speech, speech_lengths, sorted_data[beg_idx:end_idx]
462
+ )
463
+ results = self.inference(
464
+ speech_j, input_len=None, model=model, kwargs=kwargs, **cfg
465
+ )
466
+
467
+ for _b in range(len(speech_j)):
468
+ results[_b]["interval"] = intervals[_b]
469
+
470
+ if self.spk_model is not None:
471
+ # compose vad segments: [[start_time_sec, end_time_sec, speech], [...]]
472
+ for _b in range(len(speech_j)):
473
+ vad_segments = [
474
+ [
475
+ sorted_data[beg_idx:end_idx][_b][0][0] / 1000.0,
476
+ sorted_data[beg_idx:end_idx][_b][0][1] / 1000.0,
477
+ np.array(speech_j[_b]),
478
+ ]
479
+ ]
480
+ segments = sv_chunk(vad_segments)
481
+ all_segments.extend(segments)
482
+ speech_b = [i[2] for i in segments]
483
+ spk_res = self.inference(
484
+ speech_b,
485
+ input_len=None,
486
+ model=self.spk_model,
487
+ kwargs=kwargs,
488
+ **cfg,
489
+ )
490
+ results[_b]["spk_embedding"] = spk_res[0]["spk_embedding"]
491
+
492
+ beg_idx = end_idx
493
+ end_idx += 1
494
+ max_len_in_batch = sample_length
495
+ if len(results) < 1:
496
+ continue
497
+ results_sorted.extend(results)
498
+
499
+ # end_asr_total = time.time()
500
+ # time_escape_total_per_sample = end_asr_total - beg_asr_total
501
+ # pbar_sample.update(1)
502
+ # pbar_sample.set_description(f"rtf_avg_per_sample: {time_escape_total_per_sample / time_speech_total_per_sample:0.3f}, "
503
+ # f"time_speech_total_per_sample: {time_speech_total_per_sample: 0.3f}, "
504
+ # f"time_escape_total_per_sample: {time_escape_total_per_sample:0.3f}")
505
+
506
+ restored_data = [0] * n
507
+ for j in range(n):
508
+ index = sorted_data[j][1]
509
+ cur = results_sorted[j]
510
+ pattern = r"<\|([^|]+)\|>"
511
+ emotion_string = re.findall(pattern, cur["text"])
512
+ cur["text"] = re.sub(pattern, "", cur["text"])
513
+ cur["emo"] = "".join([f"<|{t}|>" for t in emotion_string])
514
+ if self.punc_model is not None and len(cur["text"].strip()) > 0:
515
+ deep_update(self.punc_kwargs, cfg)
516
+ punc_res = self.inference(
517
+ cur["text"],
518
+ model=self.punc_model,
519
+ kwargs=self.punc_kwargs,
520
+ **cfg,
521
+ )
522
+ cur["text"] = punc_res[0]["text"]
523
+
524
+ restored_data[index] = cur
525
+
526
+ end_asr_total = time.time()
527
+ time_escape_total_per_sample = end_asr_total - beg_asr_total
528
+ if pbar_total:
529
+ pbar_total.update(1)
530
+ pbar_total.set_description(
531
+ f"rtf_avg: {time_escape_total_per_sample / time_speech_total_per_sample:0.3f}, "
532
+ f"time_speech: {time_speech_total_per_sample: 0.3f}, "
533
+ f"time_escape: {time_escape_total_per_sample:0.3f}"
534
+ )
535
+
536
+ # end_total = time.time()
537
+ # time_escape_total_all_samples = end_total - beg_total
538
+ # print(f"rtf_avg_all: {time_escape_total_all_samples / time_speech_total_all_samples:0.3f}, "
539
+ # f"time_speech_all: {time_speech_total_all_samples: 0.3f}, "
540
+ # f"time_escape_all: {time_escape_total_all_samples:0.3f}")
541
+ return restored_data
542
+
543
+ def export(self, input=None, **cfg):
544
+ """
545
+
546
+ :param input:
547
+ :param type:
548
+ :param quantize:
549
+ :param fallback_num:
550
+ :param calib_num:
551
+ :param opset_version:
552
+ :param cfg:
553
+ :return:
554
+ """
555
+
556
+ device = cfg.get("device", "cpu")
557
+ model = self.model.to(device=device)
558
+ kwargs = self.kwargs
559
+ deep_update(kwargs, cfg)
560
+ kwargs["device"] = device
561
+ del kwargs["model"]
562
+ model.eval()
563
+
564
+ type = kwargs.get("type", "onnx")
565
+
566
+ key_list, data_list = prepare_data_iterator(
567
+ input, input_len=None, data_type=kwargs.get("data_type", None), key=None
568
+ )
569
+
570
+ with torch.no_grad():
571
+ export_dir = export_utils.export(model=model, data_in=data_list, **kwargs)
572
+
573
+ return export_dir
tools/sensevoice/fun_asr.py ADDED
@@ -0,0 +1,332 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gc
2
+ import os
3
+ import re
4
+
5
+ from audio_separator.separator import Separator
6
+
7
+ os.environ["MODELSCOPE_CACHE"] = "./.cache/funasr"
8
+ os.environ["UVR5_CACHE"] = "./.cache/uvr5-models"
9
+ import json
10
+ import subprocess
11
+ from pathlib import Path
12
+
13
+ import click
14
+ import torch
15
+ from loguru import logger
16
+ from pydub import AudioSegment
17
+ from silero_vad import get_speech_timestamps, load_silero_vad, read_audio
18
+ from tqdm import tqdm
19
+
20
+ from tools.file import AUDIO_EXTENSIONS, VIDEO_EXTENSIONS, list_files
21
+ from tools.sensevoice.auto_model import AutoModel
22
+
23
+
24
+ def uvr5_cli(
25
+ audio_dir: Path,
26
+ output_folder: Path,
27
+ audio_files: list[Path] | None = None,
28
+ output_format: str = "flac",
29
+ model: str = "BS-Roformer-Viperx-1297.ckpt",
30
+ ):
31
+ # ["BS-Roformer-Viperx-1297.ckpt", "BS-Roformer-Viperx-1296.ckpt", "BS-Roformer-Viperx-1053.ckpt", "Mel-Roformer-Viperx-1143.ckpt"]
32
+ sepr = Separator(
33
+ model_file_dir=os.environ["UVR5_CACHE"],
34
+ output_dir=output_folder,
35
+ output_format=output_format,
36
+ )
37
+ dictmodel = {
38
+ "BS-Roformer-Viperx-1297.ckpt": "model_bs_roformer_ep_317_sdr_12.9755.ckpt",
39
+ "BS-Roformer-Viperx-1296.ckpt": "model_bs_roformer_ep_368_sdr_12.9628.ckpt",
40
+ "BS-Roformer-Viperx-1053.ckpt": "model_bs_roformer_ep_937_sdr_10.5309.ckpt",
41
+ "Mel-Roformer-Viperx-1143.ckpt": "model_mel_band_roformer_ep_3005_sdr_11.4360.ckpt",
42
+ }
43
+ roformer_model = dictmodel[model]
44
+ sepr.load_model(roformer_model)
45
+ if audio_files is None:
46
+ audio_files = list_files(
47
+ path=audio_dir, extensions=AUDIO_EXTENSIONS, recursive=True
48
+ )
49
+ total_files = len(audio_files)
50
+
51
+ print(f"{total_files} audio files found")
52
+
53
+ res = []
54
+ for audio in tqdm(audio_files, desc="Denoising: "):
55
+ file_path = str(audio_dir / audio)
56
+ sep_out = sepr.separate(file_path)
57
+ if isinstance(sep_out, str):
58
+ res.append(sep_out)
59
+ elif isinstance(sep_out, list):
60
+ res.extend(sep_out)
61
+ del sepr
62
+ gc.collect()
63
+ if torch.cuda.is_available():
64
+ torch.cuda.empty_cache()
65
+
66
+ return res, roformer_model
67
+
68
+
69
+ def get_sample_rate(media_path: Path):
70
+ result = subprocess.run(
71
+ [
72
+ "ffprobe",
73
+ "-v",
74
+ "quiet",
75
+ "-print_format",
76
+ "json",
77
+ "-show_streams",
78
+ str(media_path),
79
+ ],
80
+ capture_output=True,
81
+ text=True,
82
+ check=True,
83
+ )
84
+ media_info = json.loads(result.stdout)
85
+ for stream in media_info.get("streams", []):
86
+ if stream.get("codec_type") == "audio":
87
+ return stream.get("sample_rate")
88
+ return "44100" # Default sample rate if not found
89
+
90
+
91
+ def convert_to_mono(src_path: Path, out_path: Path, out_fmt: str = "wav"):
92
+ sr = get_sample_rate(src_path)
93
+ out_path.parent.mkdir(parents=True, exist_ok=True)
94
+ if src_path.resolve() == out_path.resolve():
95
+ output = str(out_path.with_stem(out_path.stem + f"_{sr}"))
96
+ else:
97
+ output = str(out_path)
98
+ subprocess.run(
99
+ [
100
+ "ffmpeg",
101
+ "-loglevel",
102
+ "error",
103
+ "-i",
104
+ str(src_path),
105
+ "-acodec",
106
+ "pcm_s16le" if out_fmt == "wav" else "flac",
107
+ "-ar",
108
+ sr,
109
+ "-ac",
110
+ "1",
111
+ "-y",
112
+ output,
113
+ ],
114
+ check=True,
115
+ )
116
+ return out_path
117
+
118
+
119
+ def convert_video_to_audio(video_path: Path, audio_dir: Path):
120
+ cur_dir = audio_dir / video_path.relative_to(audio_dir).parent
121
+ vocals = [
122
+ p
123
+ for p in cur_dir.glob(f"{video_path.stem}_(Vocals)*.*")
124
+ if p.suffix in AUDIO_EXTENSIONS
125
+ ]
126
+ if len(vocals) > 0:
127
+ return vocals[0]
128
+ audio_path = cur_dir / f"{video_path.stem}.wav"
129
+ convert_to_mono(video_path, audio_path)
130
+ return audio_path
131
+
132
+
133
+ @click.command()
134
+ @click.option("--audio-dir", required=True, help="Directory containing audio files")
135
+ @click.option(
136
+ "--save-dir", required=True, help="Directory to save processed audio files"
137
+ )
138
+ @click.option("--device", default="cuda", help="Device to use [cuda / cpu]")
139
+ @click.option("--language", default="auto", help="Language of the transcription")
140
+ @click.option(
141
+ "--max_single_segment_time",
142
+ default=20000,
143
+ type=int,
144
+ help="Maximum of Output single audio duration(ms)",
145
+ )
146
+ @click.option("--fsmn-vad/--silero-vad", default=False)
147
+ @click.option("--punc/--no-punc", default=False)
148
+ @click.option("--denoise/--no-denoise", default=False)
149
+ @click.option("--save_emo/--no_save_emo", default=False)
150
+ def main(
151
+ audio_dir: str,
152
+ save_dir: str,
153
+ device: str,
154
+ language: str,
155
+ max_single_segment_time: int,
156
+ fsmn_vad: bool,
157
+ punc: bool,
158
+ denoise: bool,
159
+ save_emo: bool,
160
+ ):
161
+
162
+ audios_path = Path(audio_dir)
163
+ save_path = Path(save_dir)
164
+ save_path.mkdir(parents=True, exist_ok=True)
165
+
166
+ video_files = list_files(
167
+ path=audio_dir, extensions=VIDEO_EXTENSIONS, recursive=True
168
+ )
169
+ v2a_files = [convert_video_to_audio(p, audio_dir) for p in video_files]
170
+
171
+ if denoise:
172
+ VOCAL = "_(Vocals)"
173
+ original_files = [
174
+ p
175
+ for p in audios_path.glob("**/*")
176
+ if p.suffix in AUDIO_EXTENSIONS and VOCAL not in p.stem
177
+ ]
178
+
179
+ _, cur_model = uvr5_cli(
180
+ audio_dir=audio_dir, output_folder=audio_dir, audio_files=original_files
181
+ )
182
+ need_remove = [p for p in audios_path.glob("**/*(Instrumental)*")]
183
+ need_remove.extend(original_files)
184
+ for _ in need_remove:
185
+ _.unlink()
186
+ vocal_files = [
187
+ p
188
+ for p in audios_path.glob("**/*")
189
+ if p.suffix in AUDIO_EXTENSIONS and VOCAL in p.stem
190
+ ]
191
+ for f in vocal_files:
192
+ fn, ext = f.stem, f.suffix
193
+
194
+ v_pos = fn.find(VOCAL + "_" + cur_model.split(".")[0])
195
+ if v_pos != -1:
196
+ new_fn = fn[: v_pos + len(VOCAL)]
197
+ new_f = f.with_name(new_fn + ext)
198
+ f = f.rename(new_f)
199
+ convert_to_mono(f, f, "flac")
200
+ f.unlink()
201
+
202
+ audio_files = list_files(
203
+ path=audio_dir, extensions=AUDIO_EXTENSIONS, recursive=True
204
+ )
205
+
206
+ logger.info("Loading / Downloading Funasr model...")
207
+
208
+ model_dir = "iic/SenseVoiceSmall"
209
+
210
+ vad_model = "fsmn-vad" if fsmn_vad else None
211
+ vad_kwargs = {"max_single_segment_time": max_single_segment_time}
212
+ punc_model = "ct-punc" if punc else None
213
+
214
+ manager = AutoModel(
215
+ model=model_dir,
216
+ trust_remote_code=False,
217
+ vad_model=vad_model,
218
+ vad_kwargs=vad_kwargs,
219
+ punc_model=punc_model,
220
+ device=device,
221
+ )
222
+
223
+ if not fsmn_vad and vad_model is None:
224
+ vad_model = load_silero_vad()
225
+
226
+ logger.info("Model loaded.")
227
+
228
+ pattern = re.compile(r"_\d{3}\.")
229
+
230
+ for file_path in tqdm(audio_files, desc="Processing audio file"):
231
+
232
+ if pattern.search(file_path.name):
233
+ # logger.info(f"Skipping {file_path} as it has already been processed.")
234
+ continue
235
+
236
+ file_stem = file_path.stem
237
+ file_suffix = file_path.suffix
238
+
239
+ rel_path = Path(file_path).relative_to(audio_dir)
240
+ (save_path / rel_path.parent).mkdir(parents=True, exist_ok=True)
241
+
242
+ audio = AudioSegment.from_file(file_path)
243
+
244
+ cfg = dict(
245
+ cache={},
246
+ language=language, # "zh", "en", "yue", "ja", "ko", "nospeech"
247
+ use_itn=False,
248
+ batch_size_s=60,
249
+ )
250
+
251
+ if fsmn_vad:
252
+ elapsed, vad_res = manager.vad(input=str(file_path), **cfg)
253
+ else:
254
+ wav = read_audio(
255
+ str(file_path)
256
+ ) # backend (sox, soundfile, or ffmpeg) required!
257
+ audio_key = file_path.stem
258
+ audio_val = []
259
+ speech_timestamps = get_speech_timestamps(
260
+ wav,
261
+ vad_model,
262
+ max_speech_duration_s=max_single_segment_time // 1000,
263
+ return_seconds=True,
264
+ )
265
+
266
+ audio_val = [
267
+ [int(timestamp["start"] * 1000), int(timestamp["end"] * 1000)]
268
+ for timestamp in speech_timestamps
269
+ ]
270
+ vad_res = []
271
+ vad_res.append(dict(key=audio_key, value=audio_val))
272
+
273
+ res = manager.inference_with_vadres(
274
+ input=str(file_path), vad_res=vad_res, **cfg
275
+ )
276
+
277
+ for i, info in enumerate(res):
278
+ [start_ms, end_ms] = info["interval"]
279
+ text = info["text"]
280
+ emo = info["emo"]
281
+ sliced_audio = audio[start_ms:end_ms]
282
+ audio_save_path = (
283
+ save_path / rel_path.parent / f"{file_stem}_{i:03d}{file_suffix}"
284
+ )
285
+ sliced_audio.export(audio_save_path, format=file_suffix[1:])
286
+ print(f"Exported {audio_save_path}: {text}")
287
+
288
+ transcript_save_path = (
289
+ save_path / rel_path.parent / f"{file_stem}_{i:03d}.lab"
290
+ )
291
+ with open(
292
+ transcript_save_path,
293
+ "w",
294
+ encoding="utf-8",
295
+ ) as f:
296
+ f.write(text)
297
+
298
+ if save_emo:
299
+ emo_save_path = save_path / rel_path.parent / f"{file_stem}_{i:03d}.emo"
300
+ with open(
301
+ emo_save_path,
302
+ "w",
303
+ encoding="utf-8",
304
+ ) as f:
305
+ f.write(emo)
306
+
307
+ if audios_path.resolve() == save_path.resolve():
308
+ file_path.unlink()
309
+
310
+
311
+ if __name__ == "__main__":
312
+ main()
313
+ exit(0)
314
+ from funasr.utils.postprocess_utils import rich_transcription_postprocess
315
+
316
+ # Load the audio file
317
+ audio_path = Path(r"D:\PythonProject\ok\1_output_(Vocals).wav")
318
+ model_dir = "iic/SenseVoiceSmall"
319
+ m, kwargs = SenseVoiceSmall.from_pretrained(model=model_dir, device="cuda:0")
320
+ m.eval()
321
+
322
+ res = m.inference(
323
+ data_in=f"{kwargs['model_path']}/example/zh.mp3",
324
+ language="auto", # "zh", "en", "yue", "ja", "ko", "nospeech"
325
+ use_itn=False,
326
+ ban_emo_unk=False,
327
+ **kwargs,
328
+ )
329
+
330
+ print(res)
331
+ text = rich_transcription_postprocess(res[0][0]["text"])
332
+ print(text)
tools/sensevoice/vad_utils.py ADDED
@@ -0,0 +1,61 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torch.nn.utils.rnn import pad_sequence
3
+
4
+
5
+ def slice_padding_fbank(speech, speech_lengths, vad_segments):
6
+ speech_list = []
7
+ speech_lengths_list = []
8
+ for i, segment in enumerate(vad_segments):
9
+
10
+ bed_idx = int(segment[0][0] * 16)
11
+ end_idx = min(int(segment[0][1] * 16), speech_lengths[0])
12
+ speech_i = speech[0, bed_idx:end_idx]
13
+ speech_lengths_i = end_idx - bed_idx
14
+ speech_list.append(speech_i)
15
+ speech_lengths_list.append(speech_lengths_i)
16
+ feats_pad = pad_sequence(speech_list, batch_first=True, padding_value=0.0)
17
+ speech_lengths_pad = torch.Tensor(speech_lengths_list).int()
18
+ return feats_pad, speech_lengths_pad
19
+
20
+
21
+ def slice_padding_audio_samples(speech, speech_lengths, vad_segments):
22
+ speech_list = []
23
+ speech_lengths_list = []
24
+ intervals = []
25
+ for i, segment in enumerate(vad_segments):
26
+ bed_idx = int(segment[0][0] * 16)
27
+ end_idx = min(int(segment[0][1] * 16), speech_lengths)
28
+ speech_i = speech[bed_idx:end_idx]
29
+ speech_lengths_i = end_idx - bed_idx
30
+ speech_list.append(speech_i)
31
+ speech_lengths_list.append(speech_lengths_i)
32
+ intervals.append([bed_idx // 16, end_idx // 16])
33
+
34
+ return speech_list, speech_lengths_list, intervals
35
+
36
+
37
+ def merge_vad(vad_result, max_length=15000, min_length=0):
38
+ new_result = []
39
+ if len(vad_result) <= 1:
40
+ return vad_result
41
+ time_step = [t[0] for t in vad_result] + [t[1] for t in vad_result]
42
+ time_step = sorted(list(set(time_step)))
43
+ if len(time_step) == 0:
44
+ return []
45
+ bg = 0
46
+ for i in range(len(time_step) - 1):
47
+ time = time_step[i]
48
+ if time_step[i + 1] - bg < max_length:
49
+ continue
50
+ if time - bg > min_length:
51
+ new_result.append([bg, time])
52
+ # if time - bg < max_length * 1.5:
53
+ # new_result.append([bg, time])
54
+ # else:
55
+ # split_num = int(time - bg) // max_length + 1
56
+ # spl_l = int(time - bg) // split_num
57
+ # for j in range(split_num):
58
+ # new_result.append([bg + j * spl_l, bg + (j + 1) * spl_l])
59
+ bg = time
60
+ new_result.append([bg, time_step[-1]])
61
+ return new_result
tools/smart_pad.py ADDED
@@ -0,0 +1,47 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import random
2
+ from multiprocessing import Pool
3
+ from pathlib import Path
4
+
5
+ import click
6
+ import librosa
7
+ import torch.nn.functional as F
8
+ import torchaudio
9
+ from tqdm import tqdm
10
+
11
+ from tools.file import AUDIO_EXTENSIONS, list_files
12
+
13
+ threshold = 10 ** (-50 / 20.0)
14
+
15
+
16
+ def process(file):
17
+ waveform, sample_rate = torchaudio.load(str(file), backend="sox")
18
+ loudness = librosa.feature.rms(
19
+ y=waveform.numpy().squeeze(), frame_length=2048, hop_length=512, center=True
20
+ )[0]
21
+ for i in range(len(loudness) - 1, 0, -1):
22
+ if loudness[i] > threshold:
23
+ break
24
+
25
+ silent_time = (len(loudness) - i) * 512 / sample_rate
26
+
27
+ if silent_time <= 0.3:
28
+ random_time = random.uniform(0.3, 0.7)
29
+ waveform = F.pad(
30
+ waveform, (0, int(random_time * sample_rate)), mode="constant", value=0
31
+ )
32
+
33
+ torchaudio.save(uri=str(file), src=waveform, sample_rate=sample_rate)
34
+
35
+
36
+ @click.command()
37
+ @click.argument("source", type=Path)
38
+ @click.option("--num-workers", type=int, default=12)
39
+ def main(source, num_workers):
40
+ files = list(list_files(source, AUDIO_EXTENSIONS, recursive=True))
41
+
42
+ with Pool(num_workers) as p:
43
+ list(tqdm(p.imap_unordered(process, files), total=len(files)))
44
+
45
+
46
+ if __name__ == "__main__":
47
+ main()
tools/vqgan/create_train_split.py CHANGED
@@ -7,7 +7,7 @@ from loguru import logger
7
  from pydub import AudioSegment
8
  from tqdm import tqdm
9
 
10
- from fish_speech.utils.file import AUDIO_EXTENSIONS, list_files, load_filelist
11
 
12
 
13
  @click.command()
 
7
  from pydub import AudioSegment
8
  from tqdm import tqdm
9
 
10
+ from tools.file import AUDIO_EXTENSIONS, list_files, load_filelist
11
 
12
 
13
  @click.command()
tools/vqgan/extract_vq.py CHANGED
@@ -17,7 +17,7 @@ from lightning import LightningModule
17
  from loguru import logger
18
  from omegaconf import OmegaConf
19
 
20
- from fish_speech.utils.file import AUDIO_EXTENSIONS, list_files, load_filelist
21
 
22
  # register eval resolver
23
  OmegaConf.register_new_resolver("eval", eval)
@@ -42,7 +42,7 @@ logger.add(sys.stderr, format=logger_format)
42
  @lru_cache(maxsize=1)
43
  def get_model(
44
  config_name: str = "firefly_gan_vq",
45
- checkpoint_path: str = "checkpoints/fish-speech-1.2-sft/firefly-gan-vq-fsq-4x1024-42hz-generator.pth",
46
  device: str | torch.device = "cuda",
47
  ):
48
  with initialize(version_base="1.3", config_path="../../fish_speech/configs"):
@@ -133,7 +133,7 @@ def process_batch(files: list[Path], model) -> float:
133
  @click.option("--config-name", default="firefly_gan_vq")
134
  @click.option(
135
  "--checkpoint-path",
136
- default="checkpoints/fish-speech-1.2-sft/firefly-gan-vq-fsq-4x1024-42hz-generator.pth",
137
  )
138
  @click.option("--batch-size", default=64)
139
  @click.option("--filelist", default=None, type=Path)
 
17
  from loguru import logger
18
  from omegaconf import OmegaConf
19
 
20
+ from tools.file import AUDIO_EXTENSIONS, list_files, load_filelist
21
 
22
  # register eval resolver
23
  OmegaConf.register_new_resolver("eval", eval)
 
42
  @lru_cache(maxsize=1)
43
  def get_model(
44
  config_name: str = "firefly_gan_vq",
45
+ checkpoint_path: str = "checkpoints/fish-speech-1.4/firefly-gan-vq-fsq-8x1024-21hz-generator.pth",
46
  device: str | torch.device = "cuda",
47
  ):
48
  with initialize(version_base="1.3", config_path="../../fish_speech/configs"):
 
133
  @click.option("--config-name", default="firefly_gan_vq")
134
  @click.option(
135
  "--checkpoint-path",
136
+ default="checkpoints/fish-speech-1.4/firefly-gan-vq-fsq-8x1024-21hz-generator.pth",
137
  )
138
  @click.option("--batch-size", default=64)
139
  @click.option("--filelist", default=None, type=Path)
tools/vqgan/inference.py CHANGED
@@ -11,7 +11,7 @@ from hydra.utils import instantiate
11
  from loguru import logger
12
  from omegaconf import OmegaConf
13
 
14
- from fish_speech.utils.file import AUDIO_EXTENSIONS
15
 
16
  # register eval resolver
17
  OmegaConf.register_new_resolver("eval", eval)
@@ -59,7 +59,7 @@ def load_model(config_name, checkpoint_path, device="cuda"):
59
  @click.option("--config-name", default="firefly_gan_vq")
60
  @click.option(
61
  "--checkpoint-path",
62
- default="checkpoints/fish-speech-1.2-sft/firefly-gan-vq-fsq-4x1024-42hz-generator.pth",
63
  )
64
  @click.option(
65
  "--device",
@@ -103,7 +103,9 @@ def main(input_path, output_path, config_name, checkpoint_path, device):
103
 
104
  # Restore
105
  feature_lengths = torch.tensor([indices.shape[1]], device=device)
106
- fake_audios = model.decode(indices=indices[None], feature_lengths=feature_lengths)
 
 
107
  audio_time = fake_audios.shape[-1] / model.spec_transform.sample_rate
108
 
109
  logger.info(
 
11
  from loguru import logger
12
  from omegaconf import OmegaConf
13
 
14
+ from tools.file import AUDIO_EXTENSIONS
15
 
16
  # register eval resolver
17
  OmegaConf.register_new_resolver("eval", eval)
 
59
  @click.option("--config-name", default="firefly_gan_vq")
60
  @click.option(
61
  "--checkpoint-path",
62
+ default="checkpoints/fish-speech-1.4/firefly-gan-vq-fsq-8x1024-21hz-generator.pth",
63
  )
64
  @click.option(
65
  "--device",
 
103
 
104
  # Restore
105
  feature_lengths = torch.tensor([indices.shape[1]], device=device)
106
+ fake_audios, _ = model.decode(
107
+ indices=indices[None], feature_lengths=feature_lengths
108
+ )
109
  audio_time = fake_audios.shape[-1] / model.spec_transform.sample_rate
110
 
111
  logger.info(
tools/webui.py ADDED
@@ -0,0 +1,619 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gc
2
+ import html
3
+ import io
4
+ import os
5
+ import queue
6
+ import wave
7
+ from argparse import ArgumentParser
8
+ from functools import partial
9
+ from pathlib import Path
10
+
11
+ import gradio as gr
12
+ import librosa
13
+ import numpy as np
14
+ import pyrootutils
15
+ import torch
16
+ from loguru import logger
17
+ from transformers import AutoTokenizer
18
+
19
+ pyrootutils.setup_root(__file__, indicator=".project-root", pythonpath=True)
20
+
21
+
22
+ from fish_speech.i18n import i18n
23
+ from fish_speech.text.chn_text_norm.text import Text as ChnNormedText
24
+ from fish_speech.utils import autocast_exclude_mps
25
+ from tools.api import decode_vq_tokens, encode_reference
26
+ from tools.auto_rerank import batch_asr, calculate_wer, is_chinese, load_model
27
+ from tools.llama.generate import (
28
+ GenerateRequest,
29
+ GenerateResponse,
30
+ WrappedGenerateResponse,
31
+ launch_thread_safe_queue,
32
+ )
33
+ from tools.vqgan.inference import load_model as load_decoder_model
34
+
35
+ # Make einx happy
36
+ os.environ["EINX_FILTER_TRACEBACK"] = "false"
37
+
38
+
39
+ HEADER_MD = f"""# Fish Speech
40
+
41
+ {i18n("A text-to-speech model based on VQ-GAN and Llama developed by [Fish Audio](https://fish.audio).")}
42
+
43
+ {i18n("You can find the source code [here](https://github.com/fishaudio/fish-speech) and models [here](https://huggingface.co/fishaudio/fish-speech-1.4).")}
44
+
45
+ {i18n("Related code and weights are released under CC BY-NC-SA 4.0 License.")}
46
+
47
+ {i18n("We are not responsible for any misuse of the model, please consider your local laws and regulations before using it.")}
48
+ """
49
+
50
+ TEXTBOX_PLACEHOLDER = i18n("Put your text here.")
51
+ SPACE_IMPORTED = False
52
+
53
+
54
+ def build_html_error_message(error):
55
+ return f"""
56
+ <div style="color: red;
57
+ font-weight: bold;">
58
+ {html.escape(str(error))}
59
+ </div>
60
+ """
61
+
62
+
63
+ @torch.inference_mode()
64
+ def inference(
65
+ text,
66
+ enable_reference_audio,
67
+ reference_audio,
68
+ reference_text,
69
+ max_new_tokens,
70
+ chunk_length,
71
+ top_p,
72
+ repetition_penalty,
73
+ temperature,
74
+ streaming=False,
75
+ ):
76
+ if args.max_gradio_length > 0 and len(text) > args.max_gradio_length:
77
+ return (
78
+ None,
79
+ None,
80
+ i18n("Text is too long, please keep it under {} characters.").format(
81
+ args.max_gradio_length
82
+ ),
83
+ )
84
+
85
+ # Parse reference audio aka prompt
86
+ prompt_tokens = encode_reference(
87
+ decoder_model=decoder_model,
88
+ reference_audio=reference_audio,
89
+ enable_reference_audio=enable_reference_audio,
90
+ )
91
+
92
+ # LLAMA Inference
93
+ request = dict(
94
+ device=decoder_model.device,
95
+ max_new_tokens=max_new_tokens,
96
+ text=text,
97
+ top_p=top_p,
98
+ repetition_penalty=repetition_penalty,
99
+ temperature=temperature,
100
+ compile=args.compile,
101
+ iterative_prompt=chunk_length > 0,
102
+ chunk_length=chunk_length,
103
+ max_length=2048,
104
+ prompt_tokens=prompt_tokens if enable_reference_audio else None,
105
+ prompt_text=reference_text if enable_reference_audio else None,
106
+ )
107
+
108
+ response_queue = queue.Queue()
109
+ llama_queue.put(
110
+ GenerateRequest(
111
+ request=request,
112
+ response_queue=response_queue,
113
+ )
114
+ )
115
+
116
+ if streaming:
117
+ yield wav_chunk_header(), None, None
118
+
119
+ segments = []
120
+
121
+ while True:
122
+ result: WrappedGenerateResponse = response_queue.get()
123
+ if result.status == "error":
124
+ yield None, None, build_html_error_message(result.response)
125
+ break
126
+
127
+ result: GenerateResponse = result.response
128
+ if result.action == "next":
129
+ break
130
+
131
+ with autocast_exclude_mps(
132
+ device_type=decoder_model.device.type, dtype=args.precision
133
+ ):
134
+ fake_audios = decode_vq_tokens(
135
+ decoder_model=decoder_model,
136
+ codes=result.codes,
137
+ )
138
+
139
+ fake_audios = fake_audios.float().cpu().numpy()
140
+ segments.append(fake_audios)
141
+
142
+ if streaming:
143
+ yield (fake_audios * 32768).astype(np.int16).tobytes(), None, None
144
+
145
+ if len(segments) == 0:
146
+ return (
147
+ None,
148
+ None,
149
+ build_html_error_message(
150
+ i18n("No audio generated, please check the input text.")
151
+ ),
152
+ )
153
+
154
+ # No matter streaming or not, we need to return the final audio
155
+ audio = np.concatenate(segments, axis=0)
156
+ yield None, (decoder_model.spec_transform.sample_rate, audio), None
157
+
158
+ if torch.cuda.is_available():
159
+ torch.cuda.empty_cache()
160
+ gc.collect()
161
+
162
+
163
+ def inference_with_auto_rerank(
164
+ text,
165
+ enable_reference_audio,
166
+ reference_audio,
167
+ reference_text,
168
+ max_new_tokens,
169
+ chunk_length,
170
+ top_p,
171
+ repetition_penalty,
172
+ temperature,
173
+ use_auto_rerank,
174
+ streaming=False,
175
+ ):
176
+
177
+ max_attempts = 2 if use_auto_rerank else 1
178
+ best_wer = float("inf")
179
+ best_audio = None
180
+ best_sample_rate = None
181
+
182
+ for attempt in range(max_attempts):
183
+ audio_generator = inference(
184
+ text,
185
+ enable_reference_audio,
186
+ reference_audio,
187
+ reference_text,
188
+ max_new_tokens,
189
+ chunk_length,
190
+ top_p,
191
+ repetition_penalty,
192
+ temperature,
193
+ streaming=False,
194
+ )
195
+
196
+ # 获取音频数据
197
+ for _ in audio_generator:
198
+ pass
199
+ _, (sample_rate, audio), message = _
200
+
201
+ if audio is None:
202
+ return None, None, message
203
+
204
+ if not use_auto_rerank:
205
+ return None, (sample_rate, audio), None
206
+
207
+ asr_result = batch_asr(asr_model, [audio], sample_rate)[0]
208
+ wer = calculate_wer(text, asr_result["text"])
209
+ if wer <= 0.3 and not asr_result["huge_gap"]:
210
+ return None, (sample_rate, audio), None
211
+
212
+ if wer < best_wer:
213
+ best_wer = wer
214
+ best_audio = audio
215
+ best_sample_rate = sample_rate
216
+
217
+ if attempt == max_attempts - 1:
218
+ break
219
+
220
+ return None, (best_sample_rate, best_audio), None
221
+
222
+
223
+ inference_stream = partial(inference, streaming=True)
224
+
225
+ n_audios = 4
226
+
227
+ global_audio_list = []
228
+ global_error_list = []
229
+
230
+
231
+ def inference_wrapper(
232
+ text,
233
+ enable_reference_audio,
234
+ reference_audio,
235
+ reference_text,
236
+ max_new_tokens,
237
+ chunk_length,
238
+ top_p,
239
+ repetition_penalty,
240
+ temperature,
241
+ batch_infer_num,
242
+ if_load_asr_model,
243
+ ):
244
+ audios = []
245
+ errors = []
246
+
247
+ for _ in range(batch_infer_num):
248
+ result = inference_with_auto_rerank(
249
+ text,
250
+ enable_reference_audio,
251
+ reference_audio,
252
+ reference_text,
253
+ max_new_tokens,
254
+ chunk_length,
255
+ top_p,
256
+ repetition_penalty,
257
+ temperature,
258
+ if_load_asr_model,
259
+ )
260
+
261
+ _, audio_data, error_message = result
262
+
263
+ audios.append(
264
+ gr.Audio(value=audio_data if audio_data else None, visible=True),
265
+ )
266
+ errors.append(
267
+ gr.HTML(value=error_message if error_message else None, visible=True),
268
+ )
269
+
270
+ for _ in range(batch_infer_num, n_audios):
271
+ audios.append(
272
+ gr.Audio(value=None, visible=False),
273
+ )
274
+ errors.append(
275
+ gr.HTML(value=None, visible=False),
276
+ )
277
+
278
+ return None, *audios, *errors
279
+
280
+
281
+ def wav_chunk_header(sample_rate=44100, bit_depth=16, channels=1):
282
+ buffer = io.BytesIO()
283
+
284
+ with wave.open(buffer, "wb") as wav_file:
285
+ wav_file.setnchannels(channels)
286
+ wav_file.setsampwidth(bit_depth // 8)
287
+ wav_file.setframerate(sample_rate)
288
+
289
+ wav_header_bytes = buffer.getvalue()
290
+ buffer.close()
291
+ return wav_header_bytes
292
+
293
+
294
+ def normalize_text(user_input, use_normalization):
295
+ if use_normalization:
296
+ return ChnNormedText(raw_text=user_input).normalize()
297
+ else:
298
+ return user_input
299
+
300
+
301
+ asr_model = None
302
+
303
+
304
+ def change_if_load_asr_model(if_load):
305
+ global asr_model
306
+
307
+ if if_load:
308
+ gr.Warning("Loading faster whisper model...")
309
+ if asr_model is None:
310
+ asr_model = load_model()
311
+ return gr.Checkbox(label="Unload faster whisper model", value=if_load)
312
+
313
+ if if_load is False:
314
+ gr.Warning("Unloading faster whisper model...")
315
+ del asr_model
316
+ asr_model = None
317
+ if torch.cuda.is_available():
318
+ torch.cuda.empty_cache()
319
+ gc.collect()
320
+ return gr.Checkbox(label="Load faster whisper model", value=if_load)
321
+
322
+
323
+ def change_if_auto_label(if_load, if_auto_label, enable_ref, ref_audio, ref_text):
324
+ if if_load and asr_model is not None:
325
+ if (
326
+ if_auto_label
327
+ and enable_ref
328
+ and ref_audio is not None
329
+ and ref_text.strip() == ""
330
+ ):
331
+ data, sample_rate = librosa.load(ref_audio)
332
+ res = batch_asr(asr_model, [data], sample_rate)[0]
333
+ ref_text = res["text"]
334
+ else:
335
+ gr.Warning("Whisper model not loaded!")
336
+
337
+ return gr.Textbox(value=ref_text)
338
+
339
+
340
+ def build_app():
341
+ with gr.Blocks(theme=gr.themes.Base()) as app:
342
+ gr.Markdown(HEADER_MD)
343
+
344
+ # Use light theme by default
345
+ app.load(
346
+ None,
347
+ None,
348
+ js="() => {const params = new URLSearchParams(window.location.search);if (!params.has('__theme')) {params.set('__theme', '%s');window.location.search = params.toString();}}"
349
+ % args.theme,
350
+ )
351
+
352
+ # Inference
353
+ with gr.Row():
354
+ with gr.Column(scale=3):
355
+ text = gr.Textbox(
356
+ label=i18n("Input Text"), placeholder=TEXTBOX_PLACEHOLDER, lines=10
357
+ )
358
+ refined_text = gr.Textbox(
359
+ label=i18n("Realtime Transform Text"),
360
+ placeholder=i18n(
361
+ "Normalization Result Preview (Currently Only Chinese)"
362
+ ),
363
+ lines=5,
364
+ interactive=False,
365
+ )
366
+
367
+ with gr.Row():
368
+ if_refine_text = gr.Checkbox(
369
+ label=i18n("Text Normalization"),
370
+ value=False,
371
+ scale=1,
372
+ )
373
+
374
+ if_load_asr_model = gr.Checkbox(
375
+ label=i18n("Load / Unload ASR model for auto-reranking"),
376
+ value=False,
377
+ scale=3,
378
+ )
379
+
380
+ with gr.Row():
381
+ with gr.Tab(label=i18n("Advanced Config")):
382
+ chunk_length = gr.Slider(
383
+ label=i18n("Iterative Prompt Length, 0 means off"),
384
+ minimum=50,
385
+ maximum=300,
386
+ value=200,
387
+ step=8,
388
+ )
389
+
390
+ max_new_tokens = gr.Slider(
391
+ label=i18n("Maximum tokens per batch, 0 means no limit"),
392
+ minimum=0,
393
+ maximum=2048,
394
+ value=1024, # 0 means no limit
395
+ step=8,
396
+ )
397
+
398
+ top_p = gr.Slider(
399
+ label="Top-P",
400
+ minimum=0.6,
401
+ maximum=0.9,
402
+ value=0.7,
403
+ step=0.01,
404
+ )
405
+
406
+ repetition_penalty = gr.Slider(
407
+ label=i18n("Repetition Penalty"),
408
+ minimum=1,
409
+ maximum=1.5,
410
+ value=1.2,
411
+ step=0.01,
412
+ )
413
+
414
+ temperature = gr.Slider(
415
+ label="Temperature",
416
+ minimum=0.6,
417
+ maximum=0.9,
418
+ value=0.7,
419
+ step=0.01,
420
+ )
421
+
422
+ with gr.Tab(label=i18n("Reference Audio")):
423
+ gr.Markdown(
424
+ i18n(
425
+ "5 to 10 seconds of reference audio, useful for specifying speaker."
426
+ )
427
+ )
428
+
429
+ enable_reference_audio = gr.Checkbox(
430
+ label=i18n("Enable Reference Audio"),
431
+ )
432
+ reference_audio = gr.Audio(
433
+ label=i18n("Reference Audio"),
434
+ type="filepath",
435
+ )
436
+ with gr.Row():
437
+ if_auto_label = gr.Checkbox(
438
+ label=i18n("Auto Labeling"),
439
+ min_width=100,
440
+ scale=0,
441
+ value=False,
442
+ )
443
+ reference_text = gr.Textbox(
444
+ label=i18n("Reference Text"),
445
+ lines=1,
446
+ placeholder="在一无所知中,梦里的一天结束了,一个新的「轮回」便会开始。",
447
+ value="",
448
+ )
449
+ with gr.Tab(label=i18n("Batch Inference")):
450
+ batch_infer_num = gr.Slider(
451
+ label="Batch infer nums",
452
+ minimum=1,
453
+ maximum=n_audios,
454
+ step=1,
455
+ value=1,
456
+ )
457
+
458
+ with gr.Column(scale=3):
459
+ for _ in range(n_audios):
460
+ with gr.Row():
461
+ error = gr.HTML(
462
+ label=i18n("Error Message"),
463
+ visible=True if _ == 0 else False,
464
+ )
465
+ global_error_list.append(error)
466
+ with gr.Row():
467
+ audio = gr.Audio(
468
+ label=i18n("Generated Audio"),
469
+ type="numpy",
470
+ interactive=False,
471
+ visible=True if _ == 0 else False,
472
+ )
473
+ global_audio_list.append(audio)
474
+
475
+ with gr.Row():
476
+ stream_audio = gr.Audio(
477
+ label=i18n("Streaming Audio"),
478
+ streaming=True,
479
+ autoplay=True,
480
+ interactive=False,
481
+ show_download_button=True,
482
+ )
483
+ with gr.Row():
484
+ with gr.Column(scale=3):
485
+ generate = gr.Button(
486
+ value="\U0001F3A7 " + i18n("Generate"), variant="primary"
487
+ )
488
+ generate_stream = gr.Button(
489
+ value="\U0001F3A7 " + i18n("Streaming Generate"),
490
+ variant="primary",
491
+ )
492
+
493
+ text.input(
494
+ fn=normalize_text, inputs=[text, if_refine_text], outputs=[refined_text]
495
+ )
496
+
497
+ if_load_asr_model.change(
498
+ fn=change_if_load_asr_model,
499
+ inputs=[if_load_asr_model],
500
+ outputs=[if_load_asr_model],
501
+ )
502
+
503
+ if_auto_label.change(
504
+ fn=lambda: gr.Textbox(value=""),
505
+ inputs=[],
506
+ outputs=[reference_text],
507
+ ).then(
508
+ fn=change_if_auto_label,
509
+ inputs=[
510
+ if_load_asr_model,
511
+ if_auto_label,
512
+ enable_reference_audio,
513
+ reference_audio,
514
+ reference_text,
515
+ ],
516
+ outputs=[reference_text],
517
+ )
518
+
519
+ # # Submit
520
+ generate.click(
521
+ inference_wrapper,
522
+ [
523
+ refined_text,
524
+ enable_reference_audio,
525
+ reference_audio,
526
+ reference_text,
527
+ max_new_tokens,
528
+ chunk_length,
529
+ top_p,
530
+ repetition_penalty,
531
+ temperature,
532
+ batch_infer_num,
533
+ if_load_asr_model,
534
+ ],
535
+ [stream_audio, *global_audio_list, *global_error_list],
536
+ concurrency_limit=1,
537
+ )
538
+
539
+ generate_stream.click(
540
+ inference_stream,
541
+ [
542
+ refined_text,
543
+ enable_reference_audio,
544
+ reference_audio,
545
+ reference_text,
546
+ max_new_tokens,
547
+ chunk_length,
548
+ top_p,
549
+ repetition_penalty,
550
+ temperature,
551
+ ],
552
+ [stream_audio, global_audio_list[0], global_error_list[0]],
553
+ concurrency_limit=10,
554
+ )
555
+ return app
556
+
557
+
558
+ def parse_args():
559
+ parser = ArgumentParser()
560
+ parser.add_argument(
561
+ "--llama-checkpoint-path",
562
+ type=Path,
563
+ default="checkpoints/fish-speech-1.4",
564
+ )
565
+ parser.add_argument(
566
+ "--decoder-checkpoint-path",
567
+ type=Path,
568
+ default="checkpoints/fish-speech-1.4/firefly-gan-vq-fsq-8x1024-21hz-generator.pth",
569
+ )
570
+ parser.add_argument("--decoder-config-name", type=str, default="firefly_gan_vq")
571
+ parser.add_argument("--device", type=str, default="cuda")
572
+ parser.add_argument("--half", action="store_true")
573
+ parser.add_argument("--compile", action="store_true")
574
+ parser.add_argument("--max-gradio-length", type=int, default=0)
575
+ parser.add_argument("--theme", type=str, default="light")
576
+
577
+ return parser.parse_args()
578
+
579
+
580
+ if __name__ == "__main__":
581
+ args = parse_args()
582
+ args.precision = torch.half if args.half else torch.bfloat16
583
+
584
+ logger.info("Loading Llama model...")
585
+ llama_queue = launch_thread_safe_queue(
586
+ checkpoint_path=args.llama_checkpoint_path,
587
+ device=args.device,
588
+ precision=args.precision,
589
+ compile=args.compile,
590
+ )
591
+ logger.info("Llama model loaded, loading VQ-GAN model...")
592
+
593
+ decoder_model = load_decoder_model(
594
+ config_name=args.decoder_config_name,
595
+ checkpoint_path=args.decoder_checkpoint_path,
596
+ device=args.device,
597
+ )
598
+
599
+ logger.info("Decoder model loaded, warming up...")
600
+
601
+ # Dry run to check if the model is loaded correctly and avoid the first-time latency
602
+ list(
603
+ inference(
604
+ text="Hello, world!",
605
+ enable_reference_audio=False,
606
+ reference_audio=None,
607
+ reference_text="",
608
+ max_new_tokens=0,
609
+ chunk_length=100,
610
+ top_p=0.7,
611
+ repetition_penalty=1.2,
612
+ temperature=0.7,
613
+ )
614
+ )
615
+
616
+ logger.info("Warming up done, launching the web UI...")
617
+
618
+ app = build_app()
619
+ app.launch(show_api=True)
tools/whisper_asr.py ADDED
@@ -0,0 +1,176 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Used to transcribe all audio files in one folder into another folder.
3
+ e.g.
4
+ Directory structure:
5
+ --pre_data_root
6
+ ----SP_1
7
+ ------01.wav
8
+ ------02.wav
9
+ ------......
10
+ ----SP_2
11
+ ------01.wav
12
+ ------02.wav
13
+ ------......
14
+ Use
15
+ python tools/whisper_asr.py --audio-dir pre_data_root/SP_1 --save-dir data/SP_1
16
+ to transcribe the first speaker.
17
+
18
+ Use
19
+ python tools/whisper_asr.py --audio-dir pre_data_root/SP_2 --save-dir data/SP_2
20
+ to transcribe the second speaker.
21
+
22
+ Note: Be aware of your audio sample rate, which defaults to 44.1kHz.
23
+ """
24
+
25
+ import re
26
+ from pathlib import Path
27
+
28
+ import click
29
+ import soundfile as sf
30
+ from faster_whisper import WhisperModel
31
+ from loguru import logger
32
+ from pydub import AudioSegment
33
+ from tqdm import tqdm
34
+
35
+ from tools.file import AUDIO_EXTENSIONS, list_files
36
+
37
+
38
+ @click.command()
39
+ @click.option("--model-size", default="large-v3", help="Size of the Whisper model")
40
+ @click.option(
41
+ "--compute-type",
42
+ default="float16",
43
+ help="Computation Precision of the Whisper model [float16 / int8_float16 / int8]",
44
+ )
45
+ @click.option("--audio-dir", required=True, help="Directory containing audio files")
46
+ @click.option(
47
+ "--save-dir", required=True, help="Directory to save processed audio files"
48
+ )
49
+ @click.option(
50
+ "--sample-rate",
51
+ default=44100,
52
+ type=int,
53
+ help="Output sample rate, default to input sample rate",
54
+ )
55
+ @click.option("--device", default="cuda", help="Device to use [cuda / cpu]")
56
+ @click.option("--language", default="auto", help="Language of the transcription")
57
+ @click.option("--initial-prompt", default=None, help="Initial prompt for transcribing")
58
+ def main(
59
+ model_size,
60
+ compute_type,
61
+ audio_dir,
62
+ save_dir,
63
+ sample_rate,
64
+ device,
65
+ language,
66
+ initial_prompt,
67
+ ):
68
+ logger.info("Loading / Downloading Faster Whisper model...")
69
+
70
+ model = WhisperModel(
71
+ model_size,
72
+ device=device,
73
+ compute_type=compute_type,
74
+ download_root="faster_whisper",
75
+ )
76
+
77
+ logger.info("Model loaded.")
78
+
79
+ save_path = Path(save_dir)
80
+ save_path.mkdir(parents=True, exist_ok=True)
81
+
82
+ audio_files = list_files(
83
+ path=audio_dir, extensions=AUDIO_EXTENSIONS, recursive=True
84
+ )
85
+
86
+ for file_path in tqdm(audio_files, desc="Processing audio file"):
87
+ file_stem = file_path.stem
88
+ file_suffix = file_path.suffix
89
+
90
+ rel_path = Path(file_path).relative_to(audio_dir)
91
+ (save_path / rel_path.parent).mkdir(parents=True, exist_ok=True)
92
+
93
+ audio = AudioSegment.from_file(file_path)
94
+
95
+ segments, info = model.transcribe(
96
+ file_path,
97
+ beam_size=5,
98
+ language=None if language == "auto" else language,
99
+ initial_prompt=initial_prompt,
100
+ )
101
+
102
+ print(
103
+ "Detected language '%s' with probability %f"
104
+ % (info.language, info.language_probability)
105
+ )
106
+ print("Total len(ms): ", len(audio))
107
+
108
+ whole_text = None
109
+ for segment in segments:
110
+ id, start, end, text = (
111
+ segment.id,
112
+ segment.start,
113
+ segment.end,
114
+ segment.text,
115
+ )
116
+ print("Segment %03d [%.2fs -> %.2fs] %s" % (id, start, end, text))
117
+ if not whole_text:
118
+ whole_text = text
119
+ else:
120
+ whole_text += ", " + text
121
+
122
+ whole_text += "."
123
+
124
+ audio_save_path = save_path / rel_path.parent / f"{file_stem}{file_suffix}"
125
+ audio.export(audio_save_path, format=file_suffix[1:])
126
+ print(f"Exported {audio_save_path}")
127
+
128
+ transcript_save_path = save_path / rel_path.parent / f"{file_stem}.lab"
129
+ with open(
130
+ transcript_save_path,
131
+ "w",
132
+ encoding="utf-8",
133
+ ) as f:
134
+ f.write(whole_text)
135
+
136
+
137
+ if __name__ == "__main__":
138
+ main()
139
+ exit(0)
140
+
141
+ audio = AudioSegment.from_wav(
142
+ r"D:\PythonProject\原神语音中文\胡桃\vo_hutao_draw_appear.wav"
143
+ )
144
+
145
+ model_size = "large-v3"
146
+
147
+ model = WhisperModel(
148
+ model_size,
149
+ device="cuda",
150
+ compute_type="float16",
151
+ download_root="faster_whisper",
152
+ )
153
+
154
+ segments, info = model.transcribe(
155
+ r"D:\PythonProject\原神语音中文\胡桃\vo_hutao_draw_appear.wav",
156
+ beam_size=5,
157
+ )
158
+
159
+ print(
160
+ "Detected language '%s' with probability %f"
161
+ % (info.language, info.language_probability)
162
+ )
163
+ print("Total len(ms): ", len(audio))
164
+
165
+ for i, segment in enumerate(segments):
166
+ print(
167
+ "Segment %03d [%.2fs -> %.2fs] %s"
168
+ % (i, segment.start, segment.end, segment.text)
169
+ )
170
+ start_ms = int(segment.start * 1000)
171
+ end_ms = int(segment.end * 1000)
172
+ segment_audio = audio[start_ms:end_ms]
173
+ segment_audio.export(f"segment_{i:03d}.wav", format="wav")
174
+ print(f"Exported segment_{i:03d}.wav")
175
+
176
+ print("All segments have been exported.")