Xingyu Bian commited on
Commit
6cd5da8
·
1 Parent(s): 9437579

updated diarization pipeline and UI changes

Browse files
Files changed (2) hide show
  1. app.py +30 -10
  2. sample1.wav +0 -0
app.py CHANGED
@@ -33,25 +33,26 @@ pipe = pipeline(
33
  device=device,
34
  )
35
 
36
- # diarization pipeline (renamed to avoid conflict)
37
  diarization_pipeline = Pipeline.from_pretrained(
38
- "pyannote/speaker-diarization-3.0", use_auth_token=os.getenv("HF_KEY")
39
  )
40
 
41
 
 
42
  def diarization_info(res):
43
  starts = []
44
  ends = []
45
  speakers = []
46
 
47
- for segment, track, _ in res.itertracks(yield_label=True):
48
  starts.append(segment.start)
49
  ends.append(segment.end)
50
- speakers.append(track)
51
 
52
  return starts, ends, speakers
53
 
54
 
 
55
  def plot_diarization(starts, ends, speakers):
56
  fig = go.Figure()
57
 
@@ -83,13 +84,23 @@ def plot_diarization(starts, ends, speakers):
83
  return fig
84
 
85
 
 
 
 
 
 
 
 
 
 
86
  def transcribe_diarize(audio):
87
  sr, data = audio
88
  processed_data = np.array(data).astype(np.float32) / 32767.0
89
  waveform_tensor = torch.tensor(processed_data[np.newaxis, :])
90
 
91
- # results from the pipeline
92
- transcription_res = pipe({"sampling_rate": sr, "raw": processed_data})["text"]
 
93
  diarization_res = diarization_pipeline(
94
  {"waveform": waveform_tensor, "sample_rate": sr}
95
  )
@@ -97,10 +108,18 @@ def transcribe_diarize(audio):
97
  # Get diarization information
98
  starts, ends, speakers = diarization_info(diarization_res)
99
 
 
 
 
 
 
 
 
 
100
  # Plot diarization
101
  diarization_plot = plot_diarization(starts, ends, speakers)
102
 
103
- return transcription_res, diarization_res, diarization_plot
104
 
105
 
106
  # creating the gradio interface
@@ -109,11 +128,12 @@ demo = gr.Interface(
109
  inputs=gr.Audio(sources=["upload", "microphone"]),
110
  outputs=[
111
  gr.Textbox(lines=3, label="Text Transcription"),
112
- gr.Textbox(label="Speaker Diarization"),
113
- gr.Plot(),
114
  ],
 
115
  title="Automatic Speech Recognition with Diarization 🗣️",
116
- description="Transcribe your speech to text with distilled whisper and diarization with pyannote. Get started by recording from your mic or uploading an audio file 🎙️",
117
  )
118
 
119
 
 
33
  device=device,
34
  )
35
 
 
36
  diarization_pipeline = Pipeline.from_pretrained(
37
+ "pyannote/speaker-diarization-3.1", use_auth_token=os.getenv("HF_KEY")
38
  )
39
 
40
 
41
+ # returns diarization info such as segment start and end times, and speaker id
42
  def diarization_info(res):
43
  starts = []
44
  ends = []
45
  speakers = []
46
 
47
+ for segment, _, speaker in res.itertracks(yield_label=True):
48
  starts.append(segment.start)
49
  ends.append(segment.end)
50
+ speakers.append(speaker)
51
 
52
  return starts, ends, speakers
53
 
54
 
55
+ # plot diarization results on a graph
56
  def plot_diarization(starts, ends, speakers):
57
  fig = go.Figure()
58
 
 
84
  return fig
85
 
86
 
87
+ def transcribe(sr, data):
88
+ processed_data = np.array(data).astype(np.float32) / 32767.0
89
+
90
+ # results from the pipeline
91
+ transcription_res = pipe({"sampling_rate": sr, "raw": processed_data})["text"]
92
+
93
+ return transcription_res
94
+
95
+
96
  def transcribe_diarize(audio):
97
  sr, data = audio
98
  processed_data = np.array(data).astype(np.float32) / 32767.0
99
  waveform_tensor = torch.tensor(processed_data[np.newaxis, :])
100
 
101
+ transcription_res = transcribe(sr, data)
102
+
103
+ # results from the diarization pipeline
104
  diarization_res = diarization_pipeline(
105
  {"waveform": waveform_tensor, "sample_rate": sr}
106
  )
 
108
  # Get diarization information
109
  starts, ends, speakers = diarization_info(diarization_res)
110
 
111
+ # results from the transcription pipeline
112
+ diarized_transcription = ""
113
+
114
+ # Get transcription results for each speaker segment
115
+ for start_time, end_time, speaker_id in zip(starts, ends, speakers):
116
+ segment = data[int(start_time * sr) : int(end_time * sr)]
117
+ diarized_transcription += f"{speaker_id} {round(start_time, 2)}:{round(end_time, 2)} \t {transcribe(sr, segment)}\n"
118
+
119
  # Plot diarization
120
  diarization_plot = plot_diarization(starts, ends, speakers)
121
 
122
+ return transcription_res, diarized_transcription, diarization_plot
123
 
124
 
125
  # creating the gradio interface
 
128
  inputs=gr.Audio(sources=["upload", "microphone"]),
129
  outputs=[
130
  gr.Textbox(lines=3, label="Text Transcription"),
131
+ gr.Textbox(label="Diarized Transcription"),
132
+ gr.Plot(label="Visualization"),
133
  ],
134
+ examples=["sample1.wav"],
135
  title="Automatic Speech Recognition with Diarization 🗣️",
136
+ description="Transcribe your speech to text with distilled whisper and diarization with pyannote. Get started by recording from your mic or uploading an audio file (.wav) 🎙️",
137
  )
138
 
139
 
sample1.wav ADDED
Binary file (438 kB). View file