Plachta commited on
Commit
a1e9282
·
verified ·
1 Parent(s): c7222eb

Update modules/commons.py

Browse files
Files changed (1) hide show
  1. modules/commons.py +490 -490
modules/commons.py CHANGED
@@ -1,490 +1,490 @@
1
- import math
2
- import numpy as np
3
- import torch
4
- from torch import nn
5
- from torch.nn import functional as F
6
- from munch import Munch
7
- import json
8
-
9
-
10
- class AttrDict(dict):
11
- def __init__(self, *args, **kwargs):
12
- super(AttrDict, self).__init__(*args, **kwargs)
13
- self.__dict__ = self
14
-
15
-
16
- def init_weights(m, mean=0.0, std=0.01):
17
- classname = m.__class__.__name__
18
- if classname.find("Conv") != -1:
19
- m.weight.data.normal_(mean, std)
20
-
21
-
22
- def get_padding(kernel_size, dilation=1):
23
- return int((kernel_size * dilation - dilation) / 2)
24
-
25
-
26
- def convert_pad_shape(pad_shape):
27
- l = pad_shape[::-1]
28
- pad_shape = [item for sublist in l for item in sublist]
29
- return pad_shape
30
-
31
-
32
- def intersperse(lst, item):
33
- result = [item] * (len(lst) * 2 + 1)
34
- result[1::2] = lst
35
- return result
36
-
37
-
38
- def kl_divergence(m_p, logs_p, m_q, logs_q):
39
- """KL(P||Q)"""
40
- kl = (logs_q - logs_p) - 0.5
41
- kl += (
42
- 0.5 * (torch.exp(2.0 * logs_p) + ((m_p - m_q) ** 2)) * torch.exp(-2.0 * logs_q)
43
- )
44
- return kl
45
-
46
-
47
- def rand_gumbel(shape):
48
- """Sample from the Gumbel distribution, protect from overflows."""
49
- uniform_samples = torch.rand(shape) * 0.99998 + 0.00001
50
- return -torch.log(-torch.log(uniform_samples))
51
-
52
-
53
- def rand_gumbel_like(x):
54
- g = rand_gumbel(x.size()).to(dtype=x.dtype, device=x.device)
55
- return g
56
-
57
-
58
- def slice_segments(x, ids_str, segment_size=4):
59
- ret = torch.zeros_like(x[:, :, :segment_size])
60
- for i in range(x.size(0)):
61
- idx_str = ids_str[i]
62
- idx_end = idx_str + segment_size
63
- ret[i] = x[i, :, idx_str:idx_end]
64
- return ret
65
-
66
-
67
- def slice_segments_audio(x, ids_str, segment_size=4):
68
- ret = torch.zeros_like(x[:, :segment_size])
69
- for i in range(x.size(0)):
70
- idx_str = ids_str[i]
71
- idx_end = idx_str + segment_size
72
- ret[i] = x[i, idx_str:idx_end]
73
- return ret
74
-
75
-
76
- def rand_slice_segments(x, x_lengths=None, segment_size=4):
77
- b, d, t = x.size()
78
- if x_lengths is None:
79
- x_lengths = t
80
- ids_str_max = x_lengths - segment_size + 1
81
- ids_str = ((torch.rand([b]).to(device=x.device) * ids_str_max).clip(0)).to(
82
- dtype=torch.long
83
- )
84
- ret = slice_segments(x, ids_str, segment_size)
85
- return ret, ids_str
86
-
87
-
88
- def get_timing_signal_1d(length, channels, min_timescale=1.0, max_timescale=1.0e4):
89
- position = torch.arange(length, dtype=torch.float)
90
- num_timescales = channels // 2
91
- log_timescale_increment = math.log(float(max_timescale) / float(min_timescale)) / (
92
- num_timescales - 1
93
- )
94
- inv_timescales = min_timescale * torch.exp(
95
- torch.arange(num_timescales, dtype=torch.float) * -log_timescale_increment
96
- )
97
- scaled_time = position.unsqueeze(0) * inv_timescales.unsqueeze(1)
98
- signal = torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], 0)
99
- signal = F.pad(signal, [0, 0, 0, channels % 2])
100
- signal = signal.view(1, channels, length)
101
- return signal
102
-
103
-
104
- def add_timing_signal_1d(x, min_timescale=1.0, max_timescale=1.0e4):
105
- b, channels, length = x.size()
106
- signal = get_timing_signal_1d(length, channels, min_timescale, max_timescale)
107
- return x + signal.to(dtype=x.dtype, device=x.device)
108
-
109
-
110
- def cat_timing_signal_1d(x, min_timescale=1.0, max_timescale=1.0e4, axis=1):
111
- b, channels, length = x.size()
112
- signal = get_timing_signal_1d(length, channels, min_timescale, max_timescale)
113
- return torch.cat([x, signal.to(dtype=x.dtype, device=x.device)], axis)
114
-
115
-
116
- def subsequent_mask(length):
117
- mask = torch.tril(torch.ones(length, length)).unsqueeze(0).unsqueeze(0)
118
- return mask
119
-
120
-
121
- @torch.jit.script
122
- def fused_add_tanh_sigmoid_multiply(input_a, input_b, n_channels):
123
- n_channels_int = n_channels[0]
124
- in_act = input_a + input_b
125
- t_act = torch.tanh(in_act[:, :n_channels_int, :])
126
- s_act = torch.sigmoid(in_act[:, n_channels_int:, :])
127
- acts = t_act * s_act
128
- return acts
129
-
130
-
131
- def convert_pad_shape(pad_shape):
132
- l = pad_shape[::-1]
133
- pad_shape = [item for sublist in l for item in sublist]
134
- return pad_shape
135
-
136
-
137
- def shift_1d(x):
138
- x = F.pad(x, convert_pad_shape([[0, 0], [0, 0], [1, 0]]))[:, :, :-1]
139
- return x
140
-
141
-
142
- def sequence_mask(length, max_length=None):
143
- if max_length is None:
144
- max_length = length.max()
145
- x = torch.arange(max_length, dtype=length.dtype, device=length.device)
146
- return x.unsqueeze(0) < length.unsqueeze(1)
147
-
148
-
149
- def avg_with_mask(x, mask):
150
- assert mask.dtype == torch.float, "Mask should be float"
151
-
152
- if mask.ndim == 2:
153
- mask = mask.unsqueeze(1)
154
-
155
- if mask.shape[1] == 1:
156
- mask = mask.expand_as(x)
157
-
158
- return (x * mask).sum() / mask.sum()
159
-
160
-
161
- def generate_path(duration, mask):
162
- """
163
- duration: [b, 1, t_x]
164
- mask: [b, 1, t_y, t_x]
165
- """
166
- device = duration.device
167
-
168
- b, _, t_y, t_x = mask.shape
169
- cum_duration = torch.cumsum(duration, -1)
170
-
171
- cum_duration_flat = cum_duration.view(b * t_x)
172
- path = sequence_mask(cum_duration_flat, t_y).to(mask.dtype)
173
- path = path.view(b, t_x, t_y)
174
- path = path - F.pad(path, convert_pad_shape([[0, 0], [1, 0], [0, 0]]))[:, :-1]
175
- path = path.unsqueeze(1).transpose(2, 3) * mask
176
- return path
177
-
178
-
179
- def clip_grad_value_(parameters, clip_value, norm_type=2):
180
- if isinstance(parameters, torch.Tensor):
181
- parameters = [parameters]
182
- parameters = list(filter(lambda p: p.grad is not None, parameters))
183
- norm_type = float(norm_type)
184
- if clip_value is not None:
185
- clip_value = float(clip_value)
186
-
187
- total_norm = 0
188
- for p in parameters:
189
- param_norm = p.grad.data.norm(norm_type)
190
- total_norm += param_norm.item() ** norm_type
191
- if clip_value is not None:
192
- p.grad.data.clamp_(min=-clip_value, max=clip_value)
193
- total_norm = total_norm ** (1.0 / norm_type)
194
- return total_norm
195
-
196
-
197
- def log_norm(x, mean=-4, std=4, dim=2):
198
- """
199
- normalized log mel -> mel -> norm -> log(norm)
200
- """
201
- x = torch.log(torch.exp(x * std + mean).norm(dim=dim))
202
- return x
203
-
204
-
205
- def load_F0_models(path):
206
- # load F0 model
207
- from .JDC.model import JDCNet
208
-
209
- F0_model = JDCNet(num_class=1, seq_len=192)
210
- params = torch.load(path, map_location="cpu")["net"]
211
- F0_model.load_state_dict(params)
212
- _ = F0_model.train()
213
-
214
- return F0_model
215
-
216
-
217
- def modify_w2v_forward(self, output_layer=15):
218
- """
219
- change forward method of w2v encoder to get its intermediate layer output
220
- :param self:
221
- :param layer:
222
- :return:
223
- """
224
- from transformers.modeling_outputs import BaseModelOutput
225
-
226
- def forward(
227
- hidden_states,
228
- attention_mask=None,
229
- output_attentions=False,
230
- output_hidden_states=False,
231
- return_dict=True,
232
- ):
233
- all_hidden_states = () if output_hidden_states else None
234
- all_self_attentions = () if output_attentions else None
235
-
236
- conv_attention_mask = attention_mask
237
- if attention_mask is not None:
238
- # make sure padded tokens output 0
239
- hidden_states = hidden_states.masked_fill(
240
- ~attention_mask.bool().unsqueeze(-1), 0.0
241
- )
242
-
243
- # extend attention_mask
244
- attention_mask = 1.0 - attention_mask[:, None, None, :].to(
245
- dtype=hidden_states.dtype
246
- )
247
- attention_mask = attention_mask * torch.finfo(hidden_states.dtype).min
248
- attention_mask = attention_mask.expand(
249
- attention_mask.shape[0],
250
- 1,
251
- attention_mask.shape[-1],
252
- attention_mask.shape[-1],
253
- )
254
-
255
- hidden_states = self.dropout(hidden_states)
256
-
257
- if self.embed_positions is not None:
258
- relative_position_embeddings = self.embed_positions(hidden_states)
259
- else:
260
- relative_position_embeddings = None
261
-
262
- deepspeed_zero3_is_enabled = False
263
-
264
- for i, layer in enumerate(self.layers):
265
- if output_hidden_states:
266
- all_hidden_states = all_hidden_states + (hidden_states,)
267
-
268
- # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
269
- dropout_probability = torch.rand([])
270
-
271
- skip_the_layer = (
272
- True
273
- if self.training and (dropout_probability < self.config.layerdrop)
274
- else False
275
- )
276
- if not skip_the_layer or deepspeed_zero3_is_enabled:
277
- # under deepspeed zero3 all gpus must run in sync
278
- if self.gradient_checkpointing and self.training:
279
- layer_outputs = self._gradient_checkpointing_func(
280
- layer.__call__,
281
- hidden_states,
282
- attention_mask,
283
- relative_position_embeddings,
284
- output_attentions,
285
- conv_attention_mask,
286
- )
287
- else:
288
- layer_outputs = layer(
289
- hidden_states,
290
- attention_mask=attention_mask,
291
- relative_position_embeddings=relative_position_embeddings,
292
- output_attentions=output_attentions,
293
- conv_attention_mask=conv_attention_mask,
294
- )
295
- hidden_states = layer_outputs[0]
296
-
297
- if skip_the_layer:
298
- layer_outputs = (None, None)
299
-
300
- if output_attentions:
301
- all_self_attentions = all_self_attentions + (layer_outputs[1],)
302
-
303
- if i == output_layer - 1:
304
- break
305
-
306
- if output_hidden_states:
307
- all_hidden_states = all_hidden_states + (hidden_states,)
308
-
309
- if not return_dict:
310
- return tuple(
311
- v
312
- for v in [hidden_states, all_hidden_states, all_self_attentions]
313
- if v is not None
314
- )
315
- return BaseModelOutput(
316
- last_hidden_state=hidden_states,
317
- hidden_states=all_hidden_states,
318
- attentions=all_self_attentions,
319
- )
320
-
321
- return forward
322
-
323
-
324
- MATPLOTLIB_FLAG = False
325
-
326
-
327
- def plot_spectrogram_to_numpy(spectrogram):
328
- global MATPLOTLIB_FLAG
329
- if not MATPLOTLIB_FLAG:
330
- import matplotlib
331
- import logging
332
-
333
- matplotlib.use("Agg")
334
- MATPLOTLIB_FLAG = True
335
- mpl_logger = logging.getLogger("matplotlib")
336
- mpl_logger.setLevel(logging.WARNING)
337
- import matplotlib.pylab as plt
338
- import numpy as np
339
-
340
- fig, ax = plt.subplots(figsize=(10, 2))
341
- im = ax.imshow(spectrogram, aspect="auto", origin="lower", interpolation="none")
342
- plt.colorbar(im, ax=ax)
343
- plt.xlabel("Frames")
344
- plt.ylabel("Channels")
345
- plt.tight_layout()
346
-
347
- fig.canvas.draw()
348
- data = np.fromstring(fig.canvas.tostring_rgb(), dtype=np.uint8, sep="")
349
- data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,))
350
- plt.close()
351
- return data
352
-
353
-
354
- def normalize_f0(f0_sequence):
355
- # Remove unvoiced frames (replace with -1)
356
- voiced_indices = np.where(f0_sequence > 0)[0]
357
- f0_voiced = f0_sequence[voiced_indices]
358
-
359
- # Convert to log scale
360
- log_f0 = np.log2(f0_voiced)
361
-
362
- # Calculate mean and standard deviation
363
- mean_f0 = np.mean(log_f0)
364
- std_f0 = np.std(log_f0)
365
-
366
- # Normalize the F0 sequence
367
- normalized_f0 = (log_f0 - mean_f0) / std_f0
368
-
369
- # Create the normalized F0 sequence with unvoiced frames
370
- normalized_sequence = np.zeros_like(f0_sequence)
371
- normalized_sequence[voiced_indices] = normalized_f0
372
- normalized_sequence[f0_sequence <= 0] = -1 # Assign -1 to unvoiced frames
373
-
374
- return normalized_sequence
375
-
376
-
377
- def build_model(args, stage="DiT"):
378
- if stage == "DiT":
379
- from modules.flow_matching import CFM
380
- from modules.length_regulator import InterpolateRegulator
381
-
382
- length_regulator = InterpolateRegulator(
383
- channels=args.length_regulator.channels,
384
- sampling_ratios=args.length_regulator.sampling_ratios,
385
- is_discrete=args.length_regulator.is_discrete,
386
- codebook_size=args.length_regulator.content_codebook_size,
387
- token_dropout_prob=args.length_regulator.token_dropout_prob if hasattr(args.length_regulator, "token_dropout_prob") else 0.0,
388
- token_dropout_range=args.length_regulator.token_dropout_range if hasattr(args.length_regulator, "token_dropout_range") else 0.0,
389
- n_codebooks=args.length_regulator.n_codebooks if hasattr(args.length_regulator, "n_codebooks") else 1,
390
- quantizer_dropout=args.length_regulator.quantizer_dropout if hasattr(args.length_regulator, "quantizer_dropout") else 0.0,
391
- f0_condition=args.length_regulator.f0_condition if hasattr(args.length_regulator, "f0_condition") else False,
392
- n_f0_bins=args.length_regulator.n_f0_bins if hasattr(args.length_regulator, "n_f0_bins") else 512,
393
- )
394
- cfm = CFM(args)
395
- nets = Munch(
396
- cfm=cfm,
397
- length_regulator=length_regulator,
398
- )
399
- elif stage == 'codec':
400
- from dac.model.dac import Encoder
401
- from modules.quantize import (
402
- FAquantizer,
403
- )
404
-
405
- encoder = Encoder(
406
- d_model=args.DAC.encoder_dim,
407
- strides=args.DAC.encoder_rates,
408
- d_latent=1024,
409
- causal=args.causal,
410
- lstm=args.lstm,
411
- )
412
-
413
- quantizer = FAquantizer(
414
- in_dim=1024,
415
- n_p_codebooks=1,
416
- n_c_codebooks=args.n_c_codebooks,
417
- n_t_codebooks=2,
418
- n_r_codebooks=3,
419
- codebook_size=1024,
420
- codebook_dim=8,
421
- quantizer_dropout=0.5,
422
- causal=args.causal,
423
- separate_prosody_encoder=args.separate_prosody_encoder,
424
- timbre_norm=args.timbre_norm,
425
- )
426
-
427
- nets = Munch(
428
- encoder=encoder,
429
- quantizer=quantizer,
430
- )
431
- else:
432
- raise ValueError(f"Unknown stage: {stage}")
433
-
434
- return nets
435
-
436
-
437
- def load_checkpoint(
438
- model,
439
- optimizer,
440
- path,
441
- load_only_params=True,
442
- ignore_modules=[],
443
- is_distributed=False,
444
- ):
445
- state = torch.load(path, map_location="cpu")
446
- params = state["net"]
447
- for key in model:
448
- if key in params and key not in ignore_modules:
449
- if not is_distributed:
450
- # strip prefix of DDP (module.), create a new OrderedDict that does not contain the prefix
451
- for k in list(params[key].keys()):
452
- if k.startswith("module."):
453
- params[key][k[len("module.") :]] = params[key][k]
454
- del params[key][k]
455
- model_state_dict = model[key].state_dict()
456
- # 过滤出形状匹配的键值对
457
- filtered_state_dict = {
458
- k: v
459
- for k, v in params[key].items()
460
- if k in model_state_dict and v.shape == model_state_dict[k].shape
461
- }
462
- skipped_keys = set(params[key].keys()) - set(filtered_state_dict.keys())
463
- if skipped_keys:
464
- print(
465
- f"Warning: Skipped loading some keys due to shape mismatch: {skipped_keys}"
466
- )
467
- print("%s loaded" % key)
468
- model[key].load_state_dict(filtered_state_dict, strict=False)
469
- _ = [model[key].eval() for key in model]
470
-
471
- if not load_only_params:
472
- epoch = state["epoch"] + 1
473
- iters = state["iters"]
474
- optimizer.load_state_dict(state["optimizer"])
475
- optimizer.load_scheduler_state_dict(state["scheduler"])
476
-
477
- else:
478
- epoch = 0
479
- iters = 0
480
-
481
- return model, optimizer, epoch, iters
482
-
483
-
484
- def recursive_munch(d):
485
- if isinstance(d, dict):
486
- return Munch((k, recursive_munch(v)) for k, v in d.items())
487
- elif isinstance(d, list):
488
- return [recursive_munch(v) for v in d]
489
- else:
490
- return d
 
1
+ import math
2
+ import numpy as np
3
+ import torch
4
+ from torch import nn
5
+ from torch.nn import functional as F
6
+ from munch import Munch
7
+ import json
8
+
9
+
10
+ class AttrDict(dict):
11
+ def __init__(self, *args, **kwargs):
12
+ super(AttrDict, self).__init__(*args, **kwargs)
13
+ self.__dict__ = self
14
+
15
+
16
+ def init_weights(m, mean=0.0, std=0.01):
17
+ classname = m.__class__.__name__
18
+ if classname.find("Conv") != -1:
19
+ m.weight.data.normal_(mean, std)
20
+
21
+
22
+ def get_padding(kernel_size, dilation=1):
23
+ return int((kernel_size * dilation - dilation) / 2)
24
+
25
+
26
+ def convert_pad_shape(pad_shape):
27
+ l = pad_shape[::-1]
28
+ pad_shape = [item for sublist in l for item in sublist]
29
+ return pad_shape
30
+
31
+
32
+ def intersperse(lst, item):
33
+ result = [item] * (len(lst) * 2 + 1)
34
+ result[1::2] = lst
35
+ return result
36
+
37
+
38
+ def kl_divergence(m_p, logs_p, m_q, logs_q):
39
+ """KL(P||Q)"""
40
+ kl = (logs_q - logs_p) - 0.5
41
+ kl += (
42
+ 0.5 * (torch.exp(2.0 * logs_p) + ((m_p - m_q) ** 2)) * torch.exp(-2.0 * logs_q)
43
+ )
44
+ return kl
45
+
46
+
47
+ def rand_gumbel(shape):
48
+ """Sample from the Gumbel distribution, protect from overflows."""
49
+ uniform_samples = torch.rand(shape) * 0.99998 + 0.00001
50
+ return -torch.log(-torch.log(uniform_samples))
51
+
52
+
53
+ def rand_gumbel_like(x):
54
+ g = rand_gumbel(x.size()).to(dtype=x.dtype, device=x.device)
55
+ return g
56
+
57
+
58
+ def slice_segments(x, ids_str, segment_size=4):
59
+ ret = torch.zeros_like(x[:, :, :segment_size])
60
+ for i in range(x.size(0)):
61
+ idx_str = ids_str[i]
62
+ idx_end = idx_str + segment_size
63
+ ret[i] = x[i, :, idx_str:idx_end]
64
+ return ret
65
+
66
+
67
+ def slice_segments_audio(x, ids_str, segment_size=4):
68
+ ret = torch.zeros_like(x[:, :segment_size])
69
+ for i in range(x.size(0)):
70
+ idx_str = ids_str[i]
71
+ idx_end = idx_str + segment_size
72
+ ret[i] = x[i, idx_str:idx_end]
73
+ return ret
74
+
75
+
76
+ def rand_slice_segments(x, x_lengths=None, segment_size=4):
77
+ b, d, t = x.size()
78
+ if x_lengths is None:
79
+ x_lengths = t
80
+ ids_str_max = x_lengths - segment_size + 1
81
+ ids_str = ((torch.rand([b]).to(device=x.device) * ids_str_max).clip(0)).to(
82
+ dtype=torch.long
83
+ )
84
+ ret = slice_segments(x, ids_str, segment_size)
85
+ return ret, ids_str
86
+
87
+
88
+ def get_timing_signal_1d(length, channels, min_timescale=1.0, max_timescale=1.0e4):
89
+ position = torch.arange(length, dtype=torch.float)
90
+ num_timescales = channels // 2
91
+ log_timescale_increment = math.log(float(max_timescale) / float(min_timescale)) / (
92
+ num_timescales - 1
93
+ )
94
+ inv_timescales = min_timescale * torch.exp(
95
+ torch.arange(num_timescales, dtype=torch.float) * -log_timescale_increment
96
+ )
97
+ scaled_time = position.unsqueeze(0) * inv_timescales.unsqueeze(1)
98
+ signal = torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], 0)
99
+ signal = F.pad(signal, [0, 0, 0, channels % 2])
100
+ signal = signal.view(1, channels, length)
101
+ return signal
102
+
103
+
104
+ def add_timing_signal_1d(x, min_timescale=1.0, max_timescale=1.0e4):
105
+ b, channels, length = x.size()
106
+ signal = get_timing_signal_1d(length, channels, min_timescale, max_timescale)
107
+ return x + signal.to(dtype=x.dtype, device=x.device)
108
+
109
+
110
+ def cat_timing_signal_1d(x, min_timescale=1.0, max_timescale=1.0e4, axis=1):
111
+ b, channels, length = x.size()
112
+ signal = get_timing_signal_1d(length, channels, min_timescale, max_timescale)
113
+ return torch.cat([x, signal.to(dtype=x.dtype, device=x.device)], axis)
114
+
115
+
116
+ def subsequent_mask(length):
117
+ mask = torch.tril(torch.ones(length, length)).unsqueeze(0).unsqueeze(0)
118
+ return mask
119
+
120
+
121
+ @torch.jit.script
122
+ def fused_add_tanh_sigmoid_multiply(input_a, input_b, n_channels):
123
+ n_channels_int = n_channels[0]
124
+ in_act = input_a + input_b
125
+ t_act = torch.tanh(in_act[:, :n_channels_int, :])
126
+ s_act = torch.sigmoid(in_act[:, n_channels_int:, :])
127
+ acts = t_act * s_act
128
+ return acts
129
+
130
+
131
+ def convert_pad_shape(pad_shape):
132
+ l = pad_shape[::-1]
133
+ pad_shape = [item for sublist in l for item in sublist]
134
+ return pad_shape
135
+
136
+
137
+ def shift_1d(x):
138
+ x = F.pad(x, convert_pad_shape([[0, 0], [0, 0], [1, 0]]))[:, :, :-1]
139
+ return x
140
+
141
+
142
+ def sequence_mask(length, max_length=None):
143
+ if max_length is None:
144
+ max_length = length.max()
145
+ x = torch.arange(max_length, dtype=length.dtype, device=length.device)
146
+ return x.unsqueeze(0) < length.unsqueeze(1)
147
+
148
+
149
+ def avg_with_mask(x, mask):
150
+ assert mask.dtype == torch.float, "Mask should be float"
151
+
152
+ if mask.ndim == 2:
153
+ mask = mask.unsqueeze(1)
154
+
155
+ if mask.shape[1] == 1:
156
+ mask = mask.expand_as(x)
157
+
158
+ return (x * mask).sum() / mask.sum()
159
+
160
+
161
+ def generate_path(duration, mask):
162
+ """
163
+ duration: [b, 1, t_x]
164
+ mask: [b, 1, t_y, t_x]
165
+ """
166
+ device = duration.device
167
+
168
+ b, _, t_y, t_x = mask.shape
169
+ cum_duration = torch.cumsum(duration, -1)
170
+
171
+ cum_duration_flat = cum_duration.view(b * t_x)
172
+ path = sequence_mask(cum_duration_flat, t_y).to(mask.dtype)
173
+ path = path.view(b, t_x, t_y)
174
+ path = path - F.pad(path, convert_pad_shape([[0, 0], [1, 0], [0, 0]]))[:, :-1]
175
+ path = path.unsqueeze(1).transpose(2, 3) * mask
176
+ return path
177
+
178
+
179
+ def clip_grad_value_(parameters, clip_value, norm_type=2):
180
+ if isinstance(parameters, torch.Tensor):
181
+ parameters = [parameters]
182
+ parameters = list(filter(lambda p: p.grad is not None, parameters))
183
+ norm_type = float(norm_type)
184
+ if clip_value is not None:
185
+ clip_value = float(clip_value)
186
+
187
+ total_norm = 0
188
+ for p in parameters:
189
+ param_norm = p.grad.data.norm(norm_type)
190
+ total_norm += param_norm.item() ** norm_type
191
+ if clip_value is not None:
192
+ p.grad.data.clamp_(min=-clip_value, max=clip_value)
193
+ total_norm = total_norm ** (1.0 / norm_type)
194
+ return total_norm
195
+
196
+
197
+ def log_norm(x, mean=-4, std=4, dim=2):
198
+ """
199
+ normalized log mel -> mel -> norm -> log(norm)
200
+ """
201
+ x = torch.log(torch.exp(x * std + mean).norm(dim=dim))
202
+ return x
203
+
204
+
205
+ def load_F0_models(path):
206
+ # load F0 model
207
+ from .JDC.model import JDCNet
208
+
209
+ F0_model = JDCNet(num_class=1, seq_len=192)
210
+ params = torch.load(path, map_location="cpu")["net"]
211
+ F0_model.load_state_dict(params)
212
+ _ = F0_model.train()
213
+
214
+ return F0_model
215
+
216
+
217
+ def modify_w2v_forward(self, output_layer=15):
218
+ """
219
+ change forward method of w2v encoder to get its intermediate layer output
220
+ :param self:
221
+ :param layer:
222
+ :return:
223
+ """
224
+ from transformers.modeling_outputs import BaseModelOutput
225
+
226
+ def forward(
227
+ hidden_states,
228
+ attention_mask=None,
229
+ output_attentions=False,
230
+ output_hidden_states=False,
231
+ return_dict=True,
232
+ ):
233
+ all_hidden_states = () if output_hidden_states else None
234
+ all_self_attentions = () if output_attentions else None
235
+
236
+ conv_attention_mask = attention_mask
237
+ if attention_mask is not None:
238
+ # make sure padded tokens output 0
239
+ hidden_states = hidden_states.masked_fill(
240
+ ~attention_mask.bool().unsqueeze(-1), 0.0
241
+ )
242
+
243
+ # extend attention_mask
244
+ attention_mask = 1.0 - attention_mask[:, None, None, :].to(
245
+ dtype=hidden_states.dtype
246
+ )
247
+ attention_mask = attention_mask * torch.finfo(hidden_states.dtype).min
248
+ attention_mask = attention_mask.expand(
249
+ attention_mask.shape[0],
250
+ 1,
251
+ attention_mask.shape[-1],
252
+ attention_mask.shape[-1],
253
+ )
254
+
255
+ hidden_states = self.dropout(hidden_states)
256
+
257
+ if self.embed_positions is not None:
258
+ relative_position_embeddings = self.embed_positions(hidden_states)
259
+ else:
260
+ relative_position_embeddings = None
261
+
262
+ deepspeed_zero3_is_enabled = False
263
+
264
+ for i, layer in enumerate(self.layers):
265
+ if output_hidden_states:
266
+ all_hidden_states = all_hidden_states + (hidden_states,)
267
+
268
+ # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
269
+ dropout_probability = torch.rand([])
270
+
271
+ skip_the_layer = (
272
+ True
273
+ if self.training and (dropout_probability < self.config.layerdrop)
274
+ else False
275
+ )
276
+ if not skip_the_layer or deepspeed_zero3_is_enabled:
277
+ # under deepspeed zero3 all gpus must run in sync
278
+ if self.gradient_checkpointing and self.training:
279
+ layer_outputs = self._gradient_checkpointing_func(
280
+ layer.__call__,
281
+ hidden_states,
282
+ attention_mask,
283
+ relative_position_embeddings,
284
+ output_attentions,
285
+ conv_attention_mask,
286
+ )
287
+ else:
288
+ layer_outputs = layer(
289
+ hidden_states,
290
+ attention_mask=attention_mask,
291
+ relative_position_embeddings=relative_position_embeddings,
292
+ output_attentions=output_attentions,
293
+ conv_attention_mask=conv_attention_mask,
294
+ )
295
+ hidden_states = layer_outputs[0]
296
+
297
+ if skip_the_layer:
298
+ layer_outputs = (None, None)
299
+
300
+ if output_attentions:
301
+ all_self_attentions = all_self_attentions + (layer_outputs[1],)
302
+
303
+ if i == output_layer - 1:
304
+ break
305
+
306
+ if output_hidden_states:
307
+ all_hidden_states = all_hidden_states + (hidden_states,)
308
+
309
+ if not return_dict:
310
+ return tuple(
311
+ v
312
+ for v in [hidden_states, all_hidden_states, all_self_attentions]
313
+ if v is not None
314
+ )
315
+ return BaseModelOutput(
316
+ last_hidden_state=hidden_states,
317
+ hidden_states=all_hidden_states,
318
+ attentions=all_self_attentions,
319
+ )
320
+
321
+ return forward
322
+
323
+
324
+ MATPLOTLIB_FLAG = False
325
+
326
+
327
+ def plot_spectrogram_to_numpy(spectrogram):
328
+ global MATPLOTLIB_FLAG
329
+ if not MATPLOTLIB_FLAG:
330
+ import matplotlib
331
+ import logging
332
+
333
+ matplotlib.use("Agg")
334
+ MATPLOTLIB_FLAG = True
335
+ mpl_logger = logging.getLogger("matplotlib")
336
+ mpl_logger.setLevel(logging.WARNING)
337
+ import matplotlib.pylab as plt
338
+ import numpy as np
339
+
340
+ fig, ax = plt.subplots(figsize=(10, 2))
341
+ im = ax.imshow(spectrogram, aspect="auto", origin="lower", interpolation="none")
342
+ plt.colorbar(im, ax=ax)
343
+ plt.xlabel("Frames")
344
+ plt.ylabel("Channels")
345
+ plt.tight_layout()
346
+
347
+ fig.canvas.draw()
348
+ data = np.fromstring(fig.canvas.tostring_rgb(), dtype=np.uint8, sep="")
349
+ data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,))
350
+ plt.close()
351
+ return data
352
+
353
+
354
+ def normalize_f0(f0_sequence):
355
+ # Remove unvoiced frames (replace with -1)
356
+ voiced_indices = np.where(f0_sequence > 0)[0]
357
+ f0_voiced = f0_sequence[voiced_indices]
358
+
359
+ # Convert to log scale
360
+ log_f0 = np.log2(f0_voiced)
361
+
362
+ # Calculate mean and standard deviation
363
+ mean_f0 = np.mean(log_f0)
364
+ std_f0 = np.std(log_f0)
365
+
366
+ # Normalize the F0 sequence
367
+ normalized_f0 = (log_f0 - mean_f0) / std_f0
368
+
369
+ # Create the normalized F0 sequence with unvoiced frames
370
+ normalized_sequence = np.zeros_like(f0_sequence)
371
+ normalized_sequence[voiced_indices] = normalized_f0
372
+ normalized_sequence[f0_sequence <= 0] = -1 # Assign -1 to unvoiced frames
373
+
374
+ return normalized_sequence
375
+
376
+
377
+ def build_model(args, stage="DiT"):
378
+ if stage == "DiT":
379
+ from modules.flow_matching import CFM
380
+ from modules.length_regulator import InterpolateRegulator
381
+
382
+ length_regulator = InterpolateRegulator(
383
+ channels=args.length_regulator.channels,
384
+ sampling_ratios=args.length_regulator.sampling_ratios,
385
+ is_discrete=args.length_regulator.is_discrete,
386
+ in_channels=args.length_regulator.in_channels if hasattr(args.length_regulator, "in_channels") else None,
387
+ vector_quantize=args.length_regulator.vector_quantize if hasattr(args.length_regulator, "vector_quantize") else False,
388
+ codebook_size=args.length_regulator.content_codebook_size,
389
+ n_codebooks=args.length_regulator.n_codebooks if hasattr(args.length_regulator, "n_codebooks") else 1,
390
+ quantizer_dropout=args.length_regulator.quantizer_dropout if hasattr(args.length_regulator, "quantizer_dropout") else 0.0,
391
+ f0_condition=args.length_regulator.f0_condition if hasattr(args.length_regulator, "f0_condition") else False,
392
+ n_f0_bins=args.length_regulator.n_f0_bins if hasattr(args.length_regulator, "n_f0_bins") else 512,
393
+ )
394
+ cfm = CFM(args)
395
+ nets = Munch(
396
+ cfm=cfm,
397
+ length_regulator=length_regulator,
398
+ )
399
+ elif stage == 'codec':
400
+ from dac.model.dac import Encoder
401
+ from modules.quantize import (
402
+ FAquantizer,
403
+ )
404
+
405
+ encoder = Encoder(
406
+ d_model=args.DAC.encoder_dim,
407
+ strides=args.DAC.encoder_rates,
408
+ d_latent=1024,
409
+ causal=args.causal,
410
+ lstm=args.lstm,
411
+ )
412
+
413
+ quantizer = FAquantizer(
414
+ in_dim=1024,
415
+ n_p_codebooks=1,
416
+ n_c_codebooks=args.n_c_codebooks,
417
+ n_t_codebooks=2,
418
+ n_r_codebooks=3,
419
+ codebook_size=1024,
420
+ codebook_dim=8,
421
+ quantizer_dropout=0.5,
422
+ causal=args.causal,
423
+ separate_prosody_encoder=args.separate_prosody_encoder,
424
+ timbre_norm=args.timbre_norm,
425
+ )
426
+
427
+ nets = Munch(
428
+ encoder=encoder,
429
+ quantizer=quantizer,
430
+ )
431
+ else:
432
+ raise ValueError(f"Unknown stage: {stage}")
433
+
434
+ return nets
435
+
436
+
437
+ def load_checkpoint(
438
+ model,
439
+ optimizer,
440
+ path,
441
+ load_only_params=True,
442
+ ignore_modules=[],
443
+ is_distributed=False,
444
+ ):
445
+ state = torch.load(path, map_location="cpu")
446
+ params = state["net"]
447
+ for key in model:
448
+ if key in params and key not in ignore_modules:
449
+ if not is_distributed:
450
+ # strip prefix of DDP (module.), create a new OrderedDict that does not contain the prefix
451
+ for k in list(params[key].keys()):
452
+ if k.startswith("module."):
453
+ params[key][k[len("module.") :]] = params[key][k]
454
+ del params[key][k]
455
+ model_state_dict = model[key].state_dict()
456
+ # 过滤出形状匹配的键值对
457
+ filtered_state_dict = {
458
+ k: v
459
+ for k, v in params[key].items()
460
+ if k in model_state_dict and v.shape == model_state_dict[k].shape
461
+ }
462
+ skipped_keys = set(params[key].keys()) - set(filtered_state_dict.keys())
463
+ if skipped_keys:
464
+ print(
465
+ f"Warning: Skipped loading some keys due to shape mismatch: {skipped_keys}"
466
+ )
467
+ print("%s loaded" % key)
468
+ model[key].load_state_dict(filtered_state_dict, strict=False)
469
+ _ = [model[key].eval() for key in model]
470
+
471
+ if not load_only_params:
472
+ epoch = state["epoch"] + 1
473
+ iters = state["iters"]
474
+ optimizer.load_state_dict(state["optimizer"])
475
+ optimizer.load_scheduler_state_dict(state["scheduler"])
476
+
477
+ else:
478
+ epoch = 0
479
+ iters = 0
480
+
481
+ return model, optimizer, epoch, iters
482
+
483
+
484
+ def recursive_munch(d):
485
+ if isinstance(d, dict):
486
+ return Munch((k, recursive_munch(v)) for k, v in d.items())
487
+ elif isinstance(d, list):
488
+ return [recursive_munch(v) for v in d]
489
+ else:
490
+ return d