Commit 1b0b842 (1 parent: 05fb637)
Aboubacar OUATTARA (kaira) committed: use custom tts

Files changed:
- .gitattributes +1 -0
- app.py +49 -13
- requirements.txt +8 -5
- tts.py +395 -0
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+*.wav filter=lfs diff=lfs merge=lfs -text
app.py CHANGED
@@ -1,4 +1,8 @@
 import concurrent
+import os
+import tempfile
+from typing import Optional, Tuple
+
 import spaces
 from transformers import pipeline
 import gradio as gr
@@ -7,6 +11,7 @@ import torchaudio
 from resemble_enhance.enhancer.inference import denoise, enhance
 
 from flore200_codes import flores_codes
+from tts import BambaraTTS
 
 # Check if CUDA is available
 device = "cuda" if torch.cuda.is_available() else "cpu"
@@ -16,8 +21,8 @@ translation_model = "oza75/nllb-600M-mt-french-bambara"
 translator = pipeline("translation", model=translation_model, max_length=512)
 
 # Text-to-Speech pipeline
-tts_model = "oza75/bambara-tts"
-tts = …
+tts_model = "oza75/bambara-tts"
+tts = BambaraTTS(tts_model)
 
 
 # Function to translate text to Bambara
@@ -29,11 +34,30 @@ def translate_to_bambara(text, src_lang):
 
 # Function to convert text to speech
 @spaces.GPU
-def text_to_speech(bambara_text):
-    …
-    …
-    …
-    …
+def text_to_speech(bambara_text, reference_audio: Optional[Tuple] = None):
+    if reference_audio is not None:
+        ref_sr, ref_audio = reference_audio
+        ref_audio = torch.from_numpy(ref_audio)
+
+        # Add a channel dimension if the audio is 1D
+        if ref_audio.ndim == 1:
+            ref_audio = ref_audio.unsqueeze(0)
+
+        # Save the reference audio to a temporary file if it's not None
+        with tempfile.NamedTemporaryFile(delete=False, suffix='.wav') as tmp:
+            torchaudio.save(tmp.name, ref_audio, ref_sr)
+            tmp_path = tmp.name
+
+        # Use the temporary file as the speaker reference
+        sr, audio = tts.text_to_speech(bambara_text, speaker_reference_wav_path=tmp_path)
+
+        # Clean up the temporary file
+        os.unlink(tmp_path)
+    else:
+        # If no reference audio provided, proceed with the default
+        sr, audio = tts.text_to_speech(bambara_text)
+
+    audio = audio.mean(dim=0)
     return audio, sr
 
 
@@ -64,14 +88,25 @@ def enhance_speech(audio_array, sampling_rate, solver, nfe, tau, denoise_before_
 
 
 # Define the Gradio interface
-def _fn(…):
+def _fn(
+        src_lang,
+        text,
+        reference_audio=None,
+        solver="Midpoint",
+        nfe=64,
+        prior_temp=0.5,
+        denoise_before_enhancement=False
+):
     source_lang = flores_codes[src_lang]
 
     # Step 1: Translate the text to Bambara
     bambara_text = translate_to_bambara(text, source_lang)
 
-    # Step 2: Convert the translated text to speech
-    …
+    # Step 2: Convert the translated text to speech with reference audio
+    if reference_audio is not None:
+        audio_array, sampling_rate = text_to_speech(bambara_text, reference_audio)
+    else:
+        audio_array, sampling_rate = text_to_speech(bambara_text)
 
     # Step 3: Enhance the audio
     denoised_audio, enhanced_audio = enhance_speech(
@@ -95,13 +130,14 @@ def main():
         fn=_fn,
         inputs=[
             gr.Dropdown(label="Source Language", choices=lang_codes, value='French'),
-            gr.Textbox(label="Text to Translate"),
+            gr.Textbox(label="Text to Translate", lines=3),
+            gr.Audio(label="Clone your voice (optional)", type="numpy", format="wav"),
             gr.Dropdown(
                 choices=["Midpoint", "RK4", "Euler"], value="Midpoint",
                 label="ODE Solver (Midpoint is recommended)"
            ),
            gr.Slider(minimum=1, maximum=128, value=64, step=1, label="Number of Function Evaluations"),
-            gr.Slider(minimum=0, maximum=1, value=0.5, step=0.01, label="Prior Temperature"),
+            gr.Slider(minimum=0.1, maximum=1, value=0.5, step=0.01, label="Prior Temperature"),
             gr.Checkbox(value=False, label="Denoise Before Enhancement")
         ],
         outputs=[
@@ -114,7 +150,7 @@ def main():
         description="Translate text to Bambara and convert it to speech with options to enhance audio quality."
     )
 
-    app.launch()
+    app.launch(share=False)
 
 
 if __name__ == "__main__":
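Review note: a minimal sketch (not part of the commit) of the numpy round trip the new text_to_speech() depends on. gr.Audio(type="numpy") hands the callback a (sample_rate, np.ndarray) tuple, while BambaraTTS.text_to_speech() expects a wav path, hence the temporary file. The silent 16 kHz buffer below is a stand-in for a real microphone capture; Gradio often delivers int16 samples, which torchaudio.save also accepts.

    import os
    import tempfile

    import numpy as np
    import torch
    import torchaudio

    ref_sr, ref_np = 16_000, np.zeros(16_000, dtype=np.float32)  # stand-in for a mic capture
    ref_audio = torch.from_numpy(ref_np)
    if ref_audio.ndim == 1:
        ref_audio = ref_audio.unsqueeze(0)  # mono -> (channels, samples), as torchaudio expects
    with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp:
        torchaudio.save(tmp.name, ref_audio, ref_sr)
        tmp_path = tmp.name
    # sr, audio = tts.text_to_speech("i ni ce", speaker_reference_wav_path=tmp_path)
    os.unlink(tmp_path)  # the temp file is only needed for the duration of the call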
requirements.txt CHANGED
@@ -1,6 +1,9 @@
-transformers
-gradio
-torch
-torchaudio
-spaces
+transformers>=4.33.0
+gradio~=4.8.0
+torch~=2.1.1
+torchaudio~=2.1.1
+spaces~=0.26.1
+deepspeed~=0.12.1
+requests~=2.31.0
 resemble-enhance==0.0.2.dev240104122303
+git+https://github.com/oza75/coqui-TTS.git@prod
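Review note: tts.py subclasses Xtts and VoiceBpeTokenizer, so the TTS package must resolve to the pinned oza75/coqui-TTS fork (branch prod) rather than upstream coqui-ai/TTS. A quick sanity check, an assumption on my part and not part of the commit:

    import TTS
    from TTS.tts.models.xtts import Xtts  # import used by tts.py; fails fast if the package is missing

    # The install path shows which TTS checkout is actually active in the environment.
    print(getattr(TTS, "__version__", "unknown"), TTS.__file__)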
tts.py
ADDED
@@ -0,0 +1,395 @@
+import os
+import re
+import time
+
+import numpy as np
+import requests
+import torch
+from typing import Optional, Tuple
+
+from TTS.tts.configs.xtts_config import XttsConfig
+from TTS.tts.models.xtts import Xtts
+from TTS.tts.layers.xtts.tokenizer import VoiceBpeTokenizer, basic_cleaners
+from coqpit import Coqpit
+from huggingface_hub import hf_hub_download, hf_hub_url
+from tqdm import tqdm
+
+
+def download_file_with_progress(url: str, destination: str):
+    """
+    Downloads a file from a web URL with a progress bar.
+    """
+    # Streaming GET request
+    response = requests.get(url, stream=True)
+
+    # Total size in bytes, set to zero if missing
+    total_size = int(response.headers.get('content-length', 0))
+
+    # Using tqdm to display progress
+    with open(destination, 'wb') as file, tqdm(desc=destination, total=total_size, unit='B', unit_scale=True,
+                                               unit_divisor=1024) as bar:
+        for data in response.iter_content(chunk_size=1024):
+            size = file.write(data)
+            bar.update(size)
+
+
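Review note: download_file_with_progress() writes whatever bytes the server returns, so a failed request (e.g. a 404) would be saved as a model file. A hardened variant is sketched below as a suggestion, not part of the commit; the timeout value and chunk size are arbitrary choices:

    import requests
    from tqdm import tqdm

    def download_file_with_progress(url: str, destination: str):
        response = requests.get(url, stream=True, timeout=60)
        response.raise_for_status()  # abort on 4xx/5xx instead of saving an error page
        total_size = int(response.headers.get('content-length', 0))
        with open(destination, 'wb') as file, tqdm(desc=destination, total=total_size,
                                                   unit='B', unit_scale=True, unit_divisor=1024) as bar:
            for data in response.iter_content(chunk_size=1024 * 1024):  # 1 MiB chunks
                bar.update(file.write(data))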
+class VoiceBambaraTextPreprocessor:
+    def preprocess_batch(self, texts):
+        return [self.preprocess(text) for text in texts]
+
+    def preprocess(self, text: str) -> str:
+        text = text.lower()
+        text = self.expand_number(text)
+        text = self.transliterate_bambara(text)
+
+        return text
+
+    def transliterate_bambara(self, text):
+        """
+        Transliterate Bambara text using a specified mapping of special characters.
+
+        Parameters:
+        - text (str): The original Bambara text.
+
+        Returns:
+        - str: The transliterated text.
+        """
+        bambara_transliteration = {
+            'ɲ': 'ny',
+            'ɛ': 'è',
+            'ɔ': 'o',
+            'ŋ': 'ng',
+            'ɟ': 'j',
+            'ʔ': "'",
+            'ɣ': 'gh',
+            'ʃ': 'sh',
+            'ߒ': 'n',
+            'ߎ': "u",
+        }
+
+        # Perform the transliteration
+        transliterated_text = "".join(bambara_transliteration.get(char, char) for char in text)
+
+        return transliterated_text
+
+    def expand_number(self, text):
+        """
+        Normalize Bambara text for TTS by replacing numerical figures with their word equivalents.
+
+        Args:
+            text (str): The text to be normalized.
+
+        Returns:
+            str: The normalized Bambara text.
+        """
+
+        # A regex pattern to match all numbers
+        number_pattern = re.compile(r'\b\d+\b')
+
+        # Function to replace each number with its Bambara text
+        def replace_number_with_text(match):
+            number = int(match.group())
+            return self.number_to_bambara(number)
+
+        # Replace each number in the text with its Bambara word equivalent
+        normalized_text = number_pattern.sub(replace_number_with_text, text)
+
+        return normalized_text
+
+    def number_to_bambara(self, n):
+
+        """
+        Convert a number into its textual representation in Bambara using recursion.
+        Args:
+            n (int): The number to be converted.
+        Returns:
+            str: The number expressed in Bambara text.
+        Examples:
+            >>> number_to_bambara(123)
+            'kɛmɛ ni mugan ni saba'
+        Notes:
+            This function assumes that 'n' is a non-negative integer.
+        """
+
+        # Bambara numbering rules
+        units = ["", "kɛlɛn", "fila", "saba", "naani", "duuru", "wɔrɔ", "wòlonwula", "sɛɛgin", "kɔnɔntɔn"]
+        tens = ["", "tan", "mugan", "bisaba", "binaani", "biduuru", "biwɔrɔ", "biwòlonfila", "bisɛɛgin", "bikɔnɔntɔn"]
+        hundreds = ["", "kɛmɛ"]
+        thousands = ["", "waga"]
+        millions = ["", "milyɔn"]
+
+        # Handle zero explicitly
+        if n == 0:
+            return ""  # bambara does not support zero
+
+        if n < 10:
+            return units[n]
+        elif n < 100:
+            return tens[n // 10] + (" ni " + self.number_to_bambara(n % 10) if n % 10 > 0 else "")
+        elif n < 1000:
+            return hundreds[1] + (" " + self.number_to_bambara(n // 100) if n >= 200 else "") + (
+                " ni " + self.number_to_bambara(n % 100) if n % 100 > 0 else "")
+        elif n < 1_000_000:
+            return thousands[1] + " " + self.number_to_bambara(n // 1000) + (
+                " ni " + self.number_to_bambara(n % 1000) if n % 1000 > 0 else "")
+        else:
+            return millions[1] + " " + self.number_to_bambara(n // 1_000_000) + (
+                " ni " + self.number_to_bambara(n % 1_000_000) if n % 1_000_000 > 0 else "")
+
+
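Review note: a quick behavioural check of the preprocessor. The expected output for 123 comes from the docstring above; the full preprocess() pass additionally lowercases and applies the transliteration map (for example ɛ -> è, ɔ -> o):

    pre = VoiceBambaraTextPreprocessor()
    print(pre.number_to_bambara(123))      # 'kɛmɛ ni mugan ni saba', per the docstring example
    print(pre.preprocess("A ye 10 sɔrɔ"))  # '10' expanded to 'tan' first, then transliterated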
+class BambaraTokenizer(VoiceBpeTokenizer):
+    """
+    A tokenizer for the Bambara language that extends the VoiceBpeTokenizer.
+
+    Attributes:
+        preprocessor: An instance of VoiceBambaraTextPreprocessor for text preprocessing.
+        char_limits: A dictionary to hold character limits for languages.
+    """
+
+    def __init__(self, vocab_file: Optional[str] = None):
+        """
+        Initializes the BambaraTokenizer with a given vocabulary file.
+
+        Args:
+            vocab_file: The path to the vocabulary file, defaults to None.
+        """
+        super().__init__(vocab_file)
+        self.preprocessor = VoiceBambaraTextPreprocessor()
+        self.char_limits['bm'] = 200  # Set character limit for Bambara language
+
+    def preprocess_text(self, txt: str, lang: str) -> str:
+        """
+        Preprocesses the input text based on the language.
+
+        Args:
+            txt: The text to preprocess.
+            lang: The language code of the text.
+
+        Returns:
+            The preprocessed text.
+        """
+        # Delegate preprocessing to the parent class for non-Bambara languages
+        if lang != "bm":
+            return super().preprocess_text(txt, lang)
+
+        # Apply Bambara-specific preprocessing
+        txt = self.preprocessor.preprocess(txt)
+        txt = basic_cleaners(txt)
+        return txt
+
+
+class BambaraXtts(Xtts):
+    """
+    A class for the Bambara language that extends the Xtts class.
+
+    Attributes:
+        tokenizer: An instance of BambaraTokenizer.
+    """
+
+    def __init__(self, config: Coqpit):
+        """
+        Initializes the BambaraXtts with the provided configuration.
+
+        Args:
+            config: An instance of Coqpit containing configuration settings.
+        """
+        super().__init__(config)
+        self.tokenizer = BambaraTokenizer()  # Initialize tokenizer for Bambara
+        self.init_models()
+
+    @classmethod
+    def init_from_config(cls, config: "XttsConfig", **kwargs) -> "BambaraXtts":
+        """
+        Class method to create an instance of BambaraXtts from a configuration object.
+
+        Args:
+            config: An instance of XttsConfig containing configuration settings.
+            **kwargs: Additional keyword arguments.
+
+        Returns:
+            An instance of BambaraXtts.
+        """
+        return cls(config)
+
+
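Review note: the language routing in BambaraTokenizer.preprocess_text() only takes the custom path for lang == "bm"; every other language falls through to the upstream VoiceBpeTokenizer behaviour. A usage sketch, assuming a local vocab.json such as the one downloaded with the checkpoint below:

    tok = BambaraTokenizer(vocab_file="vocab.json")
    print(tok.preprocess_text("Den bɛ 25 dɔn", lang="bm"))  # numbers expanded, ɛ/ɔ transliterated
    print(tok.preprocess_text("Hello world", lang="en"))    # delegated to VoiceBpeTokenizer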
+class BambaraTTS:
+    """
+    Bambara Text-to-Speech (TTS) class that initializes and uses a TTS model for the Bambara language.
+
+    Attributes:
+        language_code (str): The ISO language code for Bambara.
+        checkpoint_repo_or_dir (str): URL or local path to the model checkpoint directory.
+        local_dir (str): The directory to store downloaded checkpoints.
+        paths (dict): A dictionary of paths to model components.
+        config (XttsConfig): Configuration object for the TTS model.
+        model (BambaraXtts): The TTS model instance.
+    """
+
+    def __init__(self, checkpoint_repo_or_dir: str, local_dir: Optional[str] = None):
+        """
+        Initialize the BambaraTTS instance.
+
+        Args:
+            checkpoint_repo_or_dir: A string that represents either a Hugging Face hub repository
+                                    or a local directory where the TTS model checkpoint is located.
+            local_dir: An optional string representing a local directory path where model checkpoints
+                       will be downloaded. If not specified, a default local directory is used based
+                       on `checkpoint_repo_or_dir`.
+
+        The initialization process involves setting up local directories for model components,
+        ensuring the model checkpoint is available, and loading the model configuration and tokenizer.
+        """
+
+        # Set the language code for Bambara
+        self.language_code = 'bm'
+
+        # Store the checkpoint location and local directory path
+        self.checkpoint_repo_or_dir = checkpoint_repo_or_dir
+        # If no local directory is provided, use the default based on the checkpoint
+        self.local_dir = local_dir if local_dir else self.default_local_dir(checkpoint_repo_or_dir)
+
+        # Initialize the paths for model components
+        self.paths = self.init_paths(self.local_dir)
+
+        # Ensure the model checkpoint is available locally
+        self.ensure_checkpoint_is_downloaded()
+
+        # Load the model configuration from a JSON file
+        self.config = XttsConfig()
+        self.config.load_json(self.paths['config.json'])
+
+        # Initialize the TTS model with the loaded configuration
+        self.model = BambaraXtts(self.config)
+
+        # Set up the tokenizer for the model, using the vocabulary file path
+        self.model.tokenizer = BambaraTokenizer(vocab_file=self.paths['vocab.json'])
+
+        # Load the model checkpoint into the initialized model
+        self.model.load_checkpoint(
+            self.config,
+            vocab_path="fake_vocab.json",
+            # The 'fake_vocab.json' is specified because the base model class might
+            # attempt to override our tokenizer if a vocab file is present
+            checkpoint_dir=self.local_dir,
+            use_deepspeed=torch.cuda.is_available()  # Utilize DeepSpeed if CUDA is available
+        )
+
+        # Move the model to GPU if CUDA is available
+        if torch.cuda.is_available():
+            self.model.cuda()
+
+        self.log_tokenizer()
+
+    def ensure_checkpoint_is_downloaded(self):
+        """
+        Ensures that the model checkpoint is downloaded and available locally.
+        """
+        if os.path.exists(self.checkpoint_repo_or_dir):
+            return
+
+        os.makedirs(self.local_dir, exist_ok=True)
+        self.log("Downloading checkpoint from the hub...")
+
+        for filename, filepath in self.paths.items():
+            if os.path.exists(filepath):
+                self.log(f"File {filepath} already exists. Skipping...")
+                continue
+
+            file_url = hf_hub_url(repo_id=self.checkpoint_repo_or_dir, filename=filename)
+            self.log(f"Downloading {filename} from {file_url}")
+            download_file_with_progress(file_url, filepath)
+
+        self.log("Checkpoint downloaded successfully!")
+
+    def default_local_dir(self, checkpoint_repo_or_dir: str) -> str:
+        """
+        Generates a default local directory path for storing the model checkpoint.
+
+        Args:
+            checkpoint_repo_or_dir: The original checkpoint repository or directory path.
+
+        Returns:
+            The default local directory path.
+        """
+        if os.path.exists(checkpoint_repo_or_dir):
+            return checkpoint_repo_or_dir
+
+        model_path = f"models--{checkpoint_repo_or_dir.replace('/', '--')}"
+        local_dir = os.path.join(os.path.expanduser('~'), ".cache", "huggingface", "hub", model_path)
+        return local_dir.lower()
+
+    @staticmethod
+    def init_paths(local_dir: str) -> dict:
+        """
+        Initializes paths to various model components based on the local directory.
+
+        Args:
+            local_dir: The local directory where model components are stored.
+
+        Returns:
+            A dictionary with keys as component names and values as file paths.
+        """
+        components = ['model.pth', 'config.json', 'vocab.json', 'dvae.pth', 'mel_stats.pth']
+        return {name: os.path.join(local_dir, name) for name in components}
+
+    def text_to_speech(
+            self,
+            text: str,
+            speaker_reference_wav_path: Optional[str] = None,
+            temperature: Optional[float] = 0.1,
+            enable_text_splitting: bool = False
+    ) -> Tuple[int, torch.Tensor]:
+        """
+        Converts text into speech audio.
+
+        Args:
+            text: The input text to be converted into speech.
+            speaker_reference_wav_path: A path to a reference WAV file for the speaker.
+            temperature: The temperature parameter for sampling.
+            enable_text_splitting: Flag to enable or disable text splitting.
+
+        Returns:
+            A tuple containing the sampling rate and the generated audio tensor.
+        """
+        if speaker_reference_wav_path is None:
+            speaker_reference_wav_path = "reference_audios/male_2.wav"
+            self.log("Using default speaker reference reference_audios/male_2.wav.")
+
+        self.log("Computing speaker latents...")
+        gpt_cond_latent, speaker_embedding = self.model.get_conditioning_latents(
+            audio_path=[speaker_reference_wav_path]
+        )
+
+        self.log("Starting inference...")
+        start_time = time.time()
+        out = self.model.inference(
+            text,
+            self.language_code,
+            gpt_cond_latent,
+            speaker_embedding,
+            temperature=temperature,
+            enable_text_splitting=enable_text_splitting
+        )
+        end_time = time.time()
+
+        audio = torch.tensor(out["wav"]).unsqueeze(0)
+        sampling_rate = self.config.model_args.output_sample_rate
+
+        self.log(f"Speech generated in {end_time - start_time:.2f} seconds.")
+
+        return sampling_rate, audio
+
+    def log(self, message: str):
+        """
+        Logs a message to the console with a uniform format.
+
+        Args:
+            message: The message to be logged.
+        """
+        print(f"[BambaraTTS] {message}")
+
+    def log_tokenizer(self):
+        """
+        Logs the tokenizer information.
+        """
+        self.log(f"Tokenizer: {self.model.tokenizer}")