bachvudinh commited on
Commit
0fecc29
·
1 Parent(s): a41273c

init commits

Browse files
app.py ADDED
@@ -0,0 +1,211 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ import spaces
4
+ import torchaudio
5
+ from whisperspeech.vq_stoks import RQBottleneckTransformer
6
+ from encodec.utils import convert_audio
7
+ from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, pipeline
8
+ from transformers import StoppingCriteria, StoppingCriteriaList, TextIteratorStreamer
9
+ from threading import Thread
10
+ import logging
11
+ import os
12
+ from generate_audio import (
13
+ TTSProcessor,
14
+ )
15
+ import uuid
16
+
17
+
18
+ device = "cuda" if torch.cuda.is_available() else "cpu"
19
+ vq_model = RQBottleneckTransformer.load_model(
20
+ "whisper-vq-stoks-v3-7lang-fixed.model"
21
+ ).to(device)
22
+ # tts = TTSProcessor('cpu')
23
+ use_8bit = False
24
+ llm_path = "homebrewltd/Ichigo-llama3.1-s-instruct-v0.3-phase-3"
25
+ tokenizer = AutoTokenizer.from_pretrained(llm_path)
26
+ model_kwargs = {}
27
+ if use_8bit:
28
+ model_kwargs["quantization_config"] = BitsAndBytesConfig(
29
+ load_in_8bit=True,
30
+ llm_int8_enable_fp32_cpu_offload=False,
31
+ llm_int8_has_fp16_weight=False,
32
+ )
33
+ else:
34
+ model_kwargs["torch_dtype"] = torch.bfloat16
35
+ model = AutoModelForCausalLM.from_pretrained(llm_path, **model_kwargs).to(device)
36
+
37
+ @spaces.GPU
38
+ def audio_to_sound_tokens_whisperspeech(audio_path):
39
+ vq_model.ensure_whisper('cuda')
40
+ wav, sr = torchaudio.load(audio_path)
41
+ if sr != 16000:
42
+ wav = torchaudio.functional.resample(wav, sr, 16000)
43
+ with torch.no_grad():
44
+ codes = vq_model.encode_audio(wav.to(device))
45
+ codes = codes[0].cpu().tolist()
46
+
47
+ result = ''.join(f'<|sound_{num:04d}|>' for num in codes)
48
+ return f'<|sound_start|>{result}<|sound_end|>'
49
+
50
+ @spaces.GPU
51
+ def audio_to_sound_tokens_whisperspeech_transcribe(audio_path):
52
+ vq_model.ensure_whisper('cuda')
53
+ wav, sr = torchaudio.load(audio_path)
54
+ if sr != 16000:
55
+ wav = torchaudio.functional.resample(wav, sr, 16000)
56
+ with torch.no_grad():
57
+ codes = vq_model.encode_audio(wav.to(device))
58
+ codes = codes[0].cpu().tolist()
59
+
60
+ result = ''.join(f'<|sound_{num:04d}|>' for num in codes)
61
+ return f'Transcribe the speech in this audio sample:<|sound_start|>{result}<|sound_end|>'
62
+ # print(tokenizer.encode("<|sound_0001|>", add_special_tokens=False))# return the audio tensor
63
+ # print(tokenizer.eos_token)
64
+
65
+ @spaces.GPU
66
+ def text_to_audio_file(text):
67
+ # gen a random id for the audio file
68
+ id = str(uuid.uuid4())
69
+ temp_file = f"./user_audio/{id}_temp_audio.wav"
70
+ text = text
71
+ text_split = "_".join(text.lower().split(" "))
72
+ # remove the last character if it is a period
73
+ if text_split[-1] == ".":
74
+ text_split = text_split[:-1]
75
+ tts = TTSProcessor("cuda")
76
+ tts.convert_text_to_audio_file(text, temp_file)
77
+ # logging.info(f"Saving audio to {temp_file}")
78
+ # torchaudio.save(temp_file, audio.cpu(), sample_rate=24000)
79
+ print(f"Saved audio to {temp_file}")
80
+ return temp_file
81
+
82
+
83
+ @spaces.GPU
84
+ def process_input(audio_file=None):
85
+
86
+ for partial_message in process_audio(audio_file):
87
+ yield partial_message
88
+
89
+
90
+ @spaces.GPU
91
+ def process_transcribe_input(audio_file=None):
92
+
93
+ for partial_message in process_audio(audio_file, transcript=True):
94
+ yield partial_message
95
+
96
+ class StopOnTokens(StoppingCriteria):
97
+ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
98
+ # encode </s> token
99
+ stop_ids = [tokenizer.eos_token_id, 128009] # Adjust this based on your model's tokenizer
100
+ for stop_id in stop_ids:
101
+ if input_ids[0][-1] == stop_id:
102
+ return True
103
+ return False
104
+
105
+ @spaces.GPU
106
+ def process_audio(audio_file, transcript=False):
107
+ if audio_file is None:
108
+ raise ValueError("No audio file provided")
109
+
110
+ logging.info(f"Audio file received: {audio_file}")
111
+ logging.info(f"Audio file type: {type(audio_file)}")
112
+
113
+ sound_tokens = audio_to_sound_tokens_whisperspeech_transcribe(audio_file) if transcript else audio_to_sound_tokens_whisperspeech(audio_file)
114
+ logging.info("Sound tokens generated successfully")
115
+ # logging.info(f"audio_file: {audio_file.name}")
116
+ messages = [
117
+ {"role": "user", "content": sound_tokens},
118
+ ]
119
+
120
+ stop = StopOnTokens()
121
+ input_str = tokenizer.apply_chat_template(messages, tokenize=False)
122
+ input_ids = tokenizer.encode(input_str, return_tensors="pt")
123
+ input_ids = input_ids.to(model.device)
124
+
125
+ streamer = TextIteratorStreamer(tokenizer, timeout=10., skip_prompt=True, skip_special_tokens=True)
126
+ generation_kwargs = dict(
127
+ input_ids=input_ids,
128
+ streamer=streamer,
129
+ max_new_tokens=1024,
130
+ do_sample=False,
131
+ stopping_criteria=StoppingCriteriaList([stop])
132
+ )
133
+
134
+ thread = Thread(target=model.generate, kwargs=generation_kwargs)
135
+ thread.start()
136
+
137
+ partial_message = ""
138
+ for new_token in streamer:
139
+ partial_message += new_token
140
+ if tokenizer.eos_token in partial_message:
141
+ break
142
+ partial_message = partial_message.replace("assistant\n\n", "")
143
+ yield partial_message
144
+ # def stop_generation():
145
+ # # This is a placeholder. Implement actual stopping logic here if needed.
146
+ # return "Generation stopped.", gr.Button.update(interactive=False)
147
+ # take all the examples from the examples folder
148
+ good_examples = []
149
+ for file in os.listdir("./examples"):
150
+ if file.endswith(".wav"):
151
+ good_examples.append([f"./examples/{file}"])
152
+ bad_examples = []
153
+ for file in os.listdir("./bad_examples"):
154
+ if file.endswith(".wav"):
155
+ bad_examples.append([f"./bad_examples/{file}"])
156
+ examples = []
157
+ examples.extend(good_examples)
158
+ examples.extend(bad_examples)
159
+ with gr.Blocks() as iface:
160
+ gr.Markdown("# Ichigo-llama3-s: Llama3.1 with listening capabilities")
161
+ gr.Markdown("Record your voice or upload audio and send it to the model.")
162
+ gr.Markdown("Powered by [Homebrew Ltd](https://homebrew.ltd/) | [Read our blog post](https://homebrew.ltd/blog/llama3-just-got-ears)")
163
+
164
+ with gr.Row():
165
+ input_type = gr.Radio(["text", "audio"], label="Input Type", value="audio")
166
+ text_input = gr.Textbox(label="Send", visible=False)
167
+ audio_input = gr.Audio(label="Audio", type="filepath", visible=True)
168
+ # audio_output = gr.Audio(label="Converted Audio", type="filepath", visible=False)
169
+
170
+ convert_button = gr.Button("Convert to Audio", visible=False)
171
+ submit_button = gr.Button("Send")
172
+ transcrip_button = gr.Button("Make Model Transcribe the audio")
173
+
174
+ text_output = gr.Textbox(label="Generated Text")
175
+
176
+ def update_visibility(input_type):
177
+ return (gr.update(visible=input_type == "text"),
178
+ gr.update(visible=input_type == "text"))
179
+ def convert_and_display(text):
180
+ audio_file = text_to_audio_file(text)
181
+ return audio_file
182
+ def process_example(file_path):
183
+ return update_visibility("audio")
184
+ input_type.change(
185
+ update_visibility,
186
+ inputs=[input_type],
187
+ outputs=[text_input, convert_button]
188
+ )
189
+
190
+ convert_button.click(
191
+ convert_and_display,
192
+ inputs=[text_input],
193
+ outputs=[audio_input]
194
+ )
195
+
196
+ submit_button.click(
197
+ process_input,
198
+ inputs=[audio_input],
199
+ outputs=[text_output]
200
+ )
201
+ transcrip_button.click(
202
+ process_transcribe_input,
203
+ inputs=[audio_input],
204
+ outputs=[text_output]
205
+ )
206
+
207
+ gr.Examples(examples, inputs=[audio_input])
208
+ iface.queue()
209
+ iface.launch()
210
+ # launch locally
211
+ # iface.launch(server_name="0.0.0.0")
bad_examples/bad-What-is-Love.wav ADDED
Binary file (41.7 kB). View file
 
bad_examples/bad-who-bears-Obama.wav ADDED
Binary file (64.7 kB). View file
 
examples/Can-you-write-a-registration-letter.wav ADDED
Binary file (109 kB). View file
 
examples/Hello.wav ADDED
Binary file (18.6 kB). View file
 
examples/Who-is-Harry-Potter.wav ADDED
Binary file (62.8 kB). View file
 
examples/Write-an-email.wav ADDED
Binary file (45.5 kB). View file
 
examples/codeapythonscript.wav ADDED
Binary file (61 kB). View file
 
examples/generate_3_questions_you_can_ask_an_interviewer.wav ADDED
Binary file (302 kB). View file
 
examples/story.wav ADDED
Binary file (41.5 kB). View file
 
examples/what-is-the-color-of-the-elephant.wav ADDED
Binary file (107 kB). View file
 
examples/what-is-the-color-of-the-ocean.wav ADDED
Binary file (97.4 kB). View file
 
generate_audio.py ADDED
@@ -0,0 +1,87 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torchaudio
2
+
3
+ from whisperspeech.pipeline import Pipeline
4
+ import argparse
5
+
6
+ def parse_args():
7
+ parser = argparse.ArgumentParser(description="Convert text to audio.")
8
+ parser.add_argument(
9
+ "--text",
10
+ type=str,
11
+ required=True,
12
+ help="The text to convert to audio.",
13
+ )
14
+ return parser.parse_args()
15
+
16
+ def convert_text_to_audio(pipe: Pipeline, text: str):
17
+ """Convert text to audio.
18
+
19
+ Args:
20
+ pipe (Pipeline): The pipeline to use for text-to-speech.
21
+ text (str): The text to convert to audio.
22
+
23
+ Returns:
24
+ torch.Tensor: The generated audio.
25
+ """
26
+ return pipe.generate(text)
27
+
28
+
29
+ def convert_text_to_audio_file(pipe: Pipeline, text: str, output_path: str):
30
+ """Convert text to audio and save it to a file.
31
+
32
+ Args:
33
+ pipe (Pipeline): The pipeline to use for text-to-speech.
34
+ text (str): The text to convert to audio.
35
+ output_path (str): The path to save the audio file.
36
+ """
37
+ pipe.generate_to_file(output_path, text)
38
+
39
+
40
+ class TTSProcessor:
41
+ def __init__(self, device: str):
42
+ """Initialize the TTS Processor with a specified device."""
43
+ self.pipe = Pipeline(
44
+ s2a_ref="collabora/whisperspeech:s2a-q4-tiny-en+pl.model", device=device
45
+ )
46
+
47
+ def get_reference_voice_embedding(self, path: str):
48
+ """Get the reference voice embedding from the given audio file.
49
+
50
+ Args:
51
+ path (str): The path to the audio file.
52
+ Returns:
53
+ torch.Tensor: The reference voice embedding."""
54
+ return self.pipe.extract_spk_emb(path).cpu()
55
+
56
+ def convert_text_to_audio(self, text: str, speaker=None):
57
+ """Convert text to audio.
58
+
59
+ Args:
60
+ text (str): The text to convert to audio.
61
+
62
+ Returns:
63
+ torch.Tensor: The generated audio.
64
+ """
65
+ return self.pipe.generate(text, speaker=speaker)
66
+
67
+ def convert_text_to_audio_file(self, text: str, output_path: str, speaker=None):
68
+ """Convert text to audio and save it to a file.
69
+
70
+ Args:
71
+ text (str): The text to convert to audio.
72
+ output_path (str): The path to save the audio file.
73
+ """
74
+ self.pipe.generate_to_file(output_path, text, speaker=speaker)
75
+ if __name__ == "__main__":
76
+ args = parse_args()
77
+ processor = TTSProcessor("cuda")
78
+ text = args.text
79
+ text = text.lower()
80
+ text_split = "_".join(text.lower().split(" "))
81
+ # remove the last character if it is a period
82
+ if text_split[-1] == ".":
83
+ text_split = text_split[:-1]
84
+ print(text_split)
85
+ path = f"./examples/{text_split}.wav"
86
+ processor.convert_text_to_audio_file(text, path)
87
+
requirements.txt ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ openai-whisper==20231117
2
+ IPython
3
+ peft
4
+ huggingface_hub
5
+ matplotlib
6
+ pyarrow
7
+ datasets
8
+ encodec
9
+ soundfile
10
+ gradio==4.39.0
11
+ transformers
12
+ bitsandbytes
13
+ torchvision
14
+ vector_quantize_pytorch
15
+ webdataset
16
+ whisperspeech
17
+ --extra-index-url https://download.pytorch.org/whl/cu121
18
+ torch==2.2.0
19
+ torchaudio==2.2.0
20
+ fsspec==2024.6.1
21
+ anyio==4.4.0
user_audio/0bf62a35-94bb-43f0-9a5f-9691c1691859_temp_audio.wav ADDED
Binary file (147 kB). View file