import gradio as gr
import torchaudio
from transformers import pipeline

# Load the pretrained Turkish gender-classification model from the Hugging Face Hub.
classifier = pipeline(
    "audio-classification",
    model="candenizkocak/wav2vec2-base_turkish_gender_classification",
)
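# Note: the weights are downloaded on first run; passing device=0 to
# pipeline() would run inference on a GPU instead of the CPU default.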


def resample_audio(audio_file, target_sampling_rate=16000):
    """Load an audio file, downmix it to mono, and resample it to the target rate."""
    waveform, original_sample_rate = torchaudio.load(audio_file)
    # torchaudio.load returns a (channels, frames) tensor; average the
    # channels so stereo recordings become a single mono signal.
    if waveform.shape[0] > 1:
        waveform = waveform.mean(dim=0, keepdim=True)
    if original_sample_rate != target_sampling_rate:
        resampler = torchaudio.transforms.Resample(
            orig_freq=original_sample_rate, new_freq=target_sampling_rate
        )
        waveform = resampler(waveform)
    return waveform.squeeze().numpy(), target_sampling_rate


def classify_audio(audio_file):
    # wav2vec2 models expect 16 kHz mono input, so resample before inference.
    resampled_audio, _ = resample_audio(audio_file)
    prediction = classifier(resampled_audio)
    # Map the pipeline output onto the {label: score} dict that gr.Label expects.
    return {entry["label"]: entry["score"] for entry in prediction}
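
# Example call (file name and labels are illustrative assumptions, not taken
# from the model card):
#   classify_audio("sample.wav") -> {"male": 0.98, "female": 0.02}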


def demo():
    # Build the Blocks UI; the local `demo` intentionally shadows the function
    # name and holds the Blocks instance that gets returned.
    with gr.Blocks(theme=gr.themes.Soft()) as demo:
        gr.Markdown("## Turkish Gender Audio Classification")

        with gr.Row():
            audio_input = gr.Audio(type="filepath", label="Input Audio")

        with gr.Row():
            label_output = gr.Label(label="Prediction")

        classify_btn = gr.Button("Classify")

        # Run classification on click and display the per-label scores.
        classify_btn.click(fn=classify_audio, inputs=audio_input, outputs=label_output)

    return demo


if __name__ == "__main__":
    demo().launch()
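    # launch() also accepts share=True to expose a temporary public URL.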