candenizkocak committed on
Commit 18b1b5d · verified · 1 Parent(s): 38c63c0

Create app.py

Files changed (1):
  1. app.py +50 -0
app.py ADDED
@@ -0,0 +1,50 @@
+ import gradio as gr
+ import torchaudio
+ from transformers import pipeline
+
+ # Load the gender-classification model from the Hugging Face Hub
+ classifier = pipeline("audio-classification", model="candenizkocak/wav2vec2-base_turkish_gender_classification")
+
+ # Resample audio to the 16 kHz rate the wav2vec2 model expects
+ def resample_audio(audio_file, target_sampling_rate=16000):
+     waveform, original_sample_rate = torchaudio.load(audio_file)
+     waveform = waveform.mean(dim=0)  # downmix to mono so the classifier receives a 1-D array
+     if original_sample_rate != target_sampling_rate:
+         resampler = torchaudio.transforms.Resample(orig_freq=original_sample_rate, new_freq=target_sampling_rate)
+         waveform = resampler(waveform)
+     return waveform.numpy(), target_sampling_rate
+
+ # Define the prediction function
+ def classify_audio(audio_file):
+     # Resample the audio to 16 kHz
+     resampled_audio, _ = resample_audio(audio_file)
+
+     # Classify the audio
+     prediction = classifier(resampled_audio)
+
+     # Return predictions as a {label: score} dictionary
+     return {entry['label']: entry['score'] for entry in prediction}
+
+ # Define the Gradio interface
+ def demo():
+     with gr.Blocks(theme=gr.themes.Soft()) as demo_block:
+         gr.Markdown("## Turkish Gender Audio Classification")
+
+         # Input audio
+         with gr.Row():
+             audio_input = gr.Audio(type="filepath", label="Input Audio")
+
+         # Output labels
+         with gr.Row():
+             label_output = gr.Label(label="Prediction")
+
+         # Predict button
+         classify_btn = gr.Button("Classify")
+
+         # Wire the button to the prediction function
+         classify_btn.click(fn=classify_audio, inputs=audio_input, outputs=label_output)
+
+     return demo_block
+
+ # Launch the demo
+ demo().launch()
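
For quick local testing without the Gradio UI, a minimal sketch of calling the same model directly; "sample.wav" is a hypothetical local file, and the transformers audio-classification pipeline also accepts file paths directly, decoding and resampling them with ffmpeg when it is installed:

from transformers import pipeline

# Same model as in app.py
classifier = pipeline("audio-classification", model="candenizkocak/wav2vec2-base_turkish_gender_classification")

# "sample.wav" is a hypothetical placeholder path
predictions = classifier("sample.wav")
print(predictions)  # a list like [{'label': ..., 'score': ...}, ...]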