Plachta committed on
Commit
1c7cc12
·
verified ·
1 Parent(s): c6d0958

Update modules/audio.py

Browse files
Files changed (1) hide show
  1. modules/audio.py +82 -82
modules/audio.py CHANGED
@@ -1,82 +1,82 @@
1
- import numpy as np
2
- import torch
3
- import torch.utils.data
4
- from librosa.filters import mel as librosa_mel_fn
5
- from scipy.io.wavfile import read
6
-
7
- MAX_WAV_VALUE = 32768.0
8
-
9
-
10
- def load_wav(full_path):
11
- sampling_rate, data = read(full_path)
12
- return data, sampling_rate
13
-
14
-
15
- def dynamic_range_compression(x, C=1, clip_val=1e-5):
16
- return np.log(np.clip(x, a_min=clip_val, a_max=None) * C)
17
-
18
-
19
- def dynamic_range_decompression(x, C=1):
20
- return np.exp(x) / C
21
-
22
-
23
- def dynamic_range_compression_torch(x, C=1, clip_val=1e-5):
24
- return torch.log(torch.clamp(x, min=clip_val) * C)
25
-
26
-
27
- def dynamic_range_decompression_torch(x, C=1):
28
- return torch.exp(x) / C
29
-
30
-
31
- def spectral_normalize_torch(magnitudes):
32
- output = dynamic_range_compression_torch(magnitudes)
33
- return output
34
-
35
-
36
- def spectral_de_normalize_torch(magnitudes):
37
- output = dynamic_range_decompression_torch(magnitudes)
38
- return output
39
-
40
-
41
- mel_basis = {}
42
- hann_window = {}
43
-
44
-
45
- def mel_spectrogram(y, n_fft, num_mels, sampling_rate, hop_size, win_size, fmin, fmax, center=False):
46
- if torch.min(y) < -1.0:
47
- print("min value is ", torch.min(y))
48
- if torch.max(y) > 1.0:
49
- print("max value is ", torch.max(y))
50
-
51
- global mel_basis, hann_window # pylint: disable=global-statement
52
- if f"{str(fmax)}_{str(y.device)}" not in mel_basis:
53
- mel = librosa_mel_fn(sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax)
54
- mel_basis[str(fmax) + "_" + str(y.device)] = torch.from_numpy(mel).float().to(y.device)
55
- hann_window[str(y.device)] = torch.hann_window(win_size).to(y.device)
56
-
57
- y = torch.nn.functional.pad(
58
- y.unsqueeze(1), (int((n_fft - hop_size) / 2), int((n_fft - hop_size) / 2)), mode="reflect"
59
- )
60
- y = y.squeeze(1)
61
-
62
- spec = torch.view_as_real(
63
- torch.stft(
64
- y,
65
- n_fft,
66
- hop_length=hop_size,
67
- win_length=win_size,
68
- window=hann_window[str(y.device)],
69
- center=center,
70
- pad_mode="reflect",
71
- normalized=False,
72
- onesided=True,
73
- return_complex=True,
74
- )
75
- )
76
-
77
- spec = torch.sqrt(spec.pow(2).sum(-1) + (1e-9))
78
-
79
- spec = torch.matmul(mel_basis[str(fmax) + "_" + str(y.device)], spec)
80
- spec = spectral_normalize_torch(spec)
81
-
82
- return spec
 
1
+ import numpy as np
2
+ import torch
3
+ import torch.utils.data
4
+ from librosa.filters import mel as librosa_mel_fn
5
+ from scipy.io.wavfile import read
6
+
7
+ MAX_WAV_VALUE = 32768.0
8
+
9
+
10
def load_wav(full_path):
    """Read a WAV file and return its samples together with its sample rate.

    Args:
        full_path: Path of the WAV file, passed straight to
            ``scipy.io.wavfile.read``.

    Returns:
        A ``(data, sampling_rate)`` tuple. Note the order is the reverse of
        what ``scipy.io.wavfile.read`` itself returns.
    """
    rate, samples = read(full_path)
    return samples, rate
13
+
14
+
15
def dynamic_range_compression(x, C=1, clip_val=1e-5):
    """Log-compress values with NumPy after flooring them at ``clip_val``.

    Args:
        x: Array-like input values.
        C: Multiplier applied to the floored values before the logarithm.
        clip_val: Lower bound applied to ``x`` so the logarithm never sees
            values below it.

    Returns:
        ``log(max(x, clip_val) * C)`` evaluated element-wise.
    """
    floored = np.clip(x, a_min=clip_val, a_max=None)
    scaled = floored * C
    return np.log(scaled)
17
+
18
+
19
def dynamic_range_decompression(x, C=1):
    """Undo the NumPy log compression: exponentiate, then divide by ``C``.

    Args:
        x: Array-like log-domain values.
        C: The multiplier that was applied before the logarithm.

    Returns:
        ``exp(x) / C`` evaluated element-wise.
    """
    expanded = np.exp(x)
    return expanded / C
21
+
22
+
23
def dynamic_range_compression_torch(x, C=1, clip_val=1e-5):
    """Log-compress a tensor after clamping it from below at ``clip_val``.

    Args:
        x: Input tensor.
        C: Multiplier applied to the clamped tensor before the logarithm.
        clip_val: Minimum value passed to ``clamp``.

    Returns:
        Tensor holding ``log(clamp(x, min=clip_val) * C)``.
    """
    floored = x.clamp(min=clip_val)
    return (floored * C).log()
25
+
26
+
27
def dynamic_range_decompression_torch(x, C=1):
    """Undo the tensor log compression: exponentiate, then divide by ``C``.

    Args:
        x: Log-domain tensor.
        C: The multiplier that was applied before the logarithm.

    Returns:
        Tensor holding ``exp(x) / C``.
    """
    expanded = x.exp()
    return expanded / C
29
+
30
+
31
def spectral_normalize_torch(magnitudes):
    """Apply ``dynamic_range_compression_torch`` with its default settings.

    Args:
        magnitudes: Tensor of spectral magnitudes.

    Returns:
        The compressed tensor returned by ``dynamic_range_compression_torch``.
    """
    return dynamic_range_compression_torch(magnitudes)
34
+
35
+
36
def spectral_de_normalize_torch(magnitudes):
    """Apply ``dynamic_range_decompression_torch`` with its default settings.

    Args:
        magnitudes: Tensor of log-compressed spectral values.

    Returns:
        The tensor returned by ``dynamic_range_decompression_torch``.
    """
    return dynamic_range_decompression_torch(magnitudes)
39
+
40
+
41
# Cache of mel filterbank tensors, keyed by every argument the filterbank
# depends on: (n_fft, num_mels, sampling_rate, fmin, fmax, str(device)).
mel_basis = {}
# Cache of Hann window tensors, keyed by (win_size, str(device)).
hann_window = {}


def mel_spectrogram(y, n_fft, num_mels, sampling_rate, hop_size, win_size, fmin, fmax, center=False):
    """Compute a log-compressed mel spectrogram with ``torch.stft``.

    The mel filterbank and the Hann window are built once per distinct
    configuration and reused from the module-level ``mel_basis`` and
    ``hann_window`` caches. Each cache key contains every argument its value
    depends on, so calls with different ``n_fft``, ``num_mels``, ``fmin`` or
    ``win_size`` never pick up a filterbank or window built for another
    configuration.

    Args:
        y: Waveform tensor. The unsqueeze/pad/squeeze sequence below assumes
            a 2-D ``(batch, samples)`` layout -- TODO confirm with callers.
            Values outside [-1, 1] only trigger a printed diagnostic.
        n_fft: FFT size, used for both the STFT and the mel filterbank.
        num_mels: Number of mel bands in the filterbank.
        sampling_rate: Sample rate passed to the mel filterbank builder.
        hop_size: STFT hop length in samples.
        win_size: STFT window length in samples.
        fmin: Lowest frequency of the mel filterbank.
        fmax: Highest frequency of the mel filterbank.
        center: Forwarded to ``torch.stft``. The waveform is already
            reflect-padded here before the STFT is taken.

    Returns:
        The mel-projected magnitude spectrogram after
        ``spectral_normalize_torch`` has been applied.
    """
    # Diagnostics only: out-of-range samples are reported, not clipped.
    y_min = torch.min(y)
    if y_min < -1.0:
        print("min value is ", y_min)
    y_max = torch.max(y)
    if y_max > 1.0:
        print("max value is ", y_max)

    device = y.device

    # The filterbank is a function of all five mel parameters, so all of them
    # belong in the key; the device is included because the tensor lives there.
    mel_key = (n_fft, num_mels, sampling_rate, fmin, fmax, str(device))
    if mel_key not in mel_basis:
        mel = librosa_mel_fn(sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax)
        mel_basis[mel_key] = torch.from_numpy(mel).float().to(device)

    # The window depends only on its length, and is cached independently of
    # the filterbank so a mel-cache hit can never leave the window stale.
    window_key = (win_size, str(device))
    if window_key not in hann_window:
        hann_window[window_key] = torch.hann_window(win_size).to(device)

    # Reflect-pad both ends by the same amount before framing.
    pad = int((n_fft - hop_size) / 2)
    y = torch.nn.functional.pad(y.unsqueeze(1), (pad, pad), mode="reflect")
    y = y.squeeze(1)

    spec = torch.view_as_real(
        torch.stft(
            y,
            n_fft,
            hop_length=hop_size,
            win_length=win_size,
            window=hann_window[window_key],
            center=center,
            pad_mode="reflect",
            normalized=False,
            onesided=True,
            return_complex=True,
        )
    )

    # Magnitude from the (real, imag) pairs; the small epsilon keeps the
    # argument of sqrt strictly positive.
    spec = torch.sqrt(spec.pow(2).sum(-1) + (1e-9))

    spec = torch.matmul(mel_basis[mel_key], spec)
    spec = spectral_normalize_torch(spec)

    return spec