haoheliu commited on
Commit
f3ecefe
·
1 Parent(s): 9d3ac98
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. app.py +2 -2
  2. audioldm2/__init__.py +0 -2
  3. audioldm2/audiomae_gen/__init__.py +0 -1
  4. audioldm2/audiomae_gen/sequence_input.py +0 -429
  5. audioldm2/audiomae_gen/utils.py +0 -27
  6. audioldm2/clap/__init__.py +0 -0
  7. audioldm2/clap/open_clip/__init__.py +0 -25
  8. audioldm2/clap/open_clip/bpe_simple_vocab_16e6.txt.gz +0 -3
  9. audioldm2/clap/open_clip/factory.py +0 -276
  10. audioldm2/clap/open_clip/feature_fusion.py +0 -192
  11. audioldm2/clap/open_clip/htsat.py +0 -1304
  12. audioldm2/clap/open_clip/loss.py +0 -397
  13. audioldm2/clap/open_clip/model.py +0 -931
  14. audioldm2/clap/open_clip/model_configs/HTSAT-base.json +0 -23
  15. audioldm2/clap/open_clip/model_configs/HTSAT-large.json +0 -23
  16. audioldm2/clap/open_clip/model_configs/HTSAT-tiny-win-1536.json +0 -23
  17. audioldm2/clap/open_clip/model_configs/HTSAT-tiny.json +0 -23
  18. audioldm2/clap/open_clip/model_configs/PANN-10.json +0 -23
  19. audioldm2/clap/open_clip/model_configs/PANN-14-fmax-18k.json +0 -23
  20. audioldm2/clap/open_clip/model_configs/PANN-14-fmax-8k-20s.json +0 -23
  21. audioldm2/clap/open_clip/model_configs/PANN-14-tiny-transformer.json +0 -23
  22. audioldm2/clap/open_clip/model_configs/PANN-14-win-1536.json +0 -23
  23. audioldm2/clap/open_clip/model_configs/PANN-14.json +0 -23
  24. audioldm2/clap/open_clip/model_configs/PANN-6.json +0 -23
  25. audioldm2/clap/open_clip/model_configs/RN101-quickgelu.json +0 -22
  26. audioldm2/clap/open_clip/model_configs/RN101.json +0 -21
  27. audioldm2/clap/open_clip/model_configs/RN50-quickgelu.json +0 -22
  28. audioldm2/clap/open_clip/model_configs/RN50.json +0 -21
  29. audioldm2/clap/open_clip/model_configs/RN50x16.json +0 -21
  30. audioldm2/clap/open_clip/model_configs/RN50x4.json +0 -21
  31. audioldm2/clap/open_clip/model_configs/ViT-B-16.json +0 -16
  32. audioldm2/clap/open_clip/model_configs/ViT-B-32-quickgelu.json +0 -17
  33. audioldm2/clap/open_clip/model_configs/ViT-B-32.json +0 -16
  34. audioldm2/clap/open_clip/model_configs/ViT-L-14.json +0 -16
  35. audioldm2/clap/open_clip/openai.py +0 -156
  36. audioldm2/clap/open_clip/pann_model.py +0 -697
  37. audioldm2/clap/open_clip/pretrained.py +0 -167
  38. audioldm2/clap/open_clip/timm_model.py +0 -112
  39. audioldm2/clap/open_clip/tokenizer.py +0 -197
  40. audioldm2/clap/open_clip/transform.py +0 -45
  41. audioldm2/clap/open_clip/utils.py +0 -356
  42. audioldm2/clap/training/__init__.py +0 -0
  43. audioldm2/clap/training/audioset_textmap.npy +0 -3
  44. audioldm2/clap/training/bpe_simple_vocab_16e6.txt.gz +0 -3
  45. audioldm2/clap/training/data.py +0 -865
  46. audioldm2/clap/training/params.py +0 -563
  47. audioldm2/hifigan/LICENSE +0 -21
  48. audioldm2/hifigan/__init__.py +0 -8
  49. audioldm2/hifigan/models.py +0 -174
  50. audioldm2/hifigan/models_v2.py +0 -395
app.py CHANGED
@@ -2,12 +2,12 @@ from huggingface_hub import hf_hub_download
2
  import torch
3
  import os
4
 
5
- os.environ["TOKENIZERS_PARALLELISM"] = "true"
6
-
7
  import gradio as gr
8
  from audioldm2 import text_to_audio, build_model
9
  from share_btn import community_icon_html, loading_icon_html, share_js
10
 
 
 
11
  model_id = "haoheliu/audioldm2-full"
12
  hf_hub_download(repo_id="haoheliu/audioldm2-full", filename="audioldm2-full.pth")
13
 
 
2
  import torch
3
  import os
4
 
 
 
5
  import gradio as gr
6
  from audioldm2 import text_to_audio, build_model
7
  from share_btn import community_icon_html, loading_icon_html, share_js
8
 
9
+ os.environ["TOKENIZERS_PARALLELISM"] = "true"
10
+
11
  model_id = "haoheliu/audioldm2-full"
12
  hf_hub_download(repo_id="haoheliu/audioldm2-full", filename="audioldm2-full.pth")
13
 
audioldm2/__init__.py DELETED
@@ -1,2 +0,0 @@
1
- from .utils import seed_everything, save_wave, get_time, get_duration, read_list
2
- from .pipeline import *
 
 
 
audioldm2/audiomae_gen/__init__.py DELETED
@@ -1 +0,0 @@
1
- from .sequence_input import Sequence2AudioMAE
 
 
audioldm2/audiomae_gen/sequence_input.py DELETED
@@ -1,429 +0,0 @@
1
- import torch
2
- import torch.nn as nn
3
- from audioldm2.latent_diffusion.util import (
4
- instantiate_from_config,
5
- )
6
-
7
- # from latent_diffusion.modules.encoders.modules import CLAPAudioEmbeddingClassifierFreev2
8
- from transformers import GPT2Config, GPT2Model
9
- import torch.optim.lr_scheduler as lr_scheduler
10
-
11
- class Sequence2AudioMAE(nn.Module):
12
- def __init__(
13
- self,
14
- base_learning_rate,
15
- sequence_gen_length,
16
- sequence_input_key,
17
- sequence_input_embed_dim,
18
- cond_stage_config,
19
- optimizer_type="AdamW",
20
- use_warmup=True,
21
- use_ar_gen_loss=False,
22
- use_audiomae_linear=False,
23
- target_tokens_mask_ratio=0.0,
24
- random_mask_ratio=False,
25
- **kwargs
26
- ):
27
- super().__init__()
28
- assert use_audiomae_linear == False
29
- self.random_mask_ratio = random_mask_ratio
30
- self.learning_rate = base_learning_rate
31
- self.cond_stage_config = cond_stage_config
32
- self.use_audiomae_linear = use_audiomae_linear
33
- self.optimizer_type = optimizer_type
34
- self.use_warmup = use_warmup
35
- self.use_ar_gen_loss = use_ar_gen_loss
36
- # Even though the LDM can be conditioned on mutliple pooling rate
37
- # Our model always predict the higest pooling rate
38
-
39
- # self.time_pool = max(self.cond_stage_config["crossattn_audiomae_pooled"]["params"]["time_pooling_factors"])
40
- # self.freq_pool = max(self.cond_stage_config["crossattn_audiomae_pooled"]["params"]["freq_pooling_factors"])
41
- # self.mae_token_num = int(512/(self.time_pool*self.freq_pool))
42
-
43
- self.mae_token_num = sequence_gen_length
44
- self.sequence_input_key = sequence_input_key
45
- self.sequence_input_embed_dim = sequence_input_embed_dim
46
- self.target_tokens_mask_ratio = target_tokens_mask_ratio
47
-
48
- self.start_of_sequence_tokens = nn.Embedding(32, 768)
49
- self.end_of_sequence_tokens = nn.Embedding(32, 768)
50
-
51
- self.input_sequence_embed_linear = nn.ModuleList([])
52
- self.initial_learning_rate = None
53
-
54
- for dim in self.sequence_input_embed_dim:
55
- self.input_sequence_embed_linear.append(nn.Linear(dim, 768))
56
-
57
- self.cond_stage_models = nn.ModuleList([])
58
- self.instantiate_cond_stage(cond_stage_config)
59
- self.initialize_param_check_toolkit()
60
-
61
- # configuration = GPT2Config(n_layer=1) # TODO
62
- # self.model=GPT2Model(configuration)
63
- ###################
64
- # self.model=nn.Linear(768,768, bias=False) # TODO change the model
65
- # with torch.no_grad():
66
- # self.model.weight.copy_(torch.eye(768))
67
- ###################
68
- self.model = GPT2Model(GPT2Config.from_pretrained("gpt2"))
69
- ###################
70
- # self.model = nn.LSTM(input_size=768, hidden_size=768, num_layers=1,bias=False) # TODO
71
-
72
- # self.loss_fn = nn.MSELoss()
73
- self.loss_fn = nn.L1Loss()
74
-
75
- self.logger_save_dir = None
76
- self.logger_exp_name = None
77
- self.logger_exp_group_name = None
78
- self.logger_version = None
79
-
80
- def set_log_dir(self, save_dir, exp_group_name, exp_name):
81
- self.logger_save_dir = save_dir
82
- self.logger_exp_group_name = exp_group_name
83
- self.logger_exp_name = exp_name
84
-
85
- def cfg_uncond(self, batch_size):
86
- unconditional_conditioning = {}
87
- for key in self.cond_stage_model_metadata:
88
- model_idx = self.cond_stage_model_metadata[key]["model_idx"]
89
- unconditional_conditioning[key] = self.cond_stage_models[
90
- model_idx
91
- ].get_unconditional_condition(batch_size)
92
- assert (
93
- "crossattn_audiomae_pooled" in unconditional_conditioning.keys()
94
- ), "The module is not initialized with AudioMAE"
95
- unconditional_conditioning[
96
- "crossattn_clap_to_audiomae_feature"
97
- ] = unconditional_conditioning["crossattn_audiomae_pooled"]
98
- return unconditional_conditioning
99
-
100
- def configure_optimizers(self):
101
- lr = float(self.learning_rate)
102
- # params = list(self.model.parameters()) + list(self.input_sequence_embed_linear.parameters())
103
- params = list(self.parameters())
104
-
105
- # opt = torch.optim.Adam(params, lr=lr, betas=(0.9, 0.98), eps=1e-9)
106
- opt = eval(self.optimizer_type)(params, lr=lr)
107
- scheduler = lr_scheduler.StepLR(opt, step_size=10, gamma=0.8)
108
- return [opt], [scheduler]
109
-
110
- def add_sos_eos_tokens(self, _id, sequence, attn_mask):
111
- batchsize = sequence.size(0)
112
-
113
- new_attn_mask_step = torch.ones((batchsize, 1)).to(sequence.device)
114
- key_id = torch.tensor([_id]).to(sequence.device)
115
-
116
- # Add two more steps to attn mask
117
- new_attn_mask = torch.cat(
118
- [new_attn_mask_step, attn_mask, new_attn_mask_step], dim=1
119
- )
120
-
121
- # Add two more tokens in the sequence
122
- sos_token = self.start_of_sequence_tokens(key_id).expand(batchsize, 1, -1)
123
- eos_token = self.end_of_sequence_tokens(key_id).expand(batchsize, 1, -1)
124
- new_sequence = torch.cat([sos_token, sequence, eos_token], dim=1)
125
- return new_sequence, new_attn_mask
126
-
127
- def truncate_sequence_and_mask(self, sequence, mask, max_len=512):
128
- if sequence.size(1) > max_len:
129
- print(
130
- "The input sequence length to GPT-2 model is too long:",
131
- sequence.size(1),
132
- )
133
- return sequence[:, :max_len], mask[:, :max_len]
134
- else:
135
- return sequence, mask
136
-
137
- def get_input_sequence_and_mask(self, cond_dict):
138
- input_embeds = None
139
- input_embeds_attn_mask = None
140
- for _id, sequence_key in enumerate(self.sequence_input_key):
141
- assert sequence_key in cond_dict.keys(), (
142
- "Invalid sequence key %s" % sequence_key
143
- )
144
- cond_embed = cond_dict[sequence_key]
145
- if isinstance(cond_embed, list):
146
- assert (
147
- len(cond_embed) == 2
148
- ), "The crossattn returned list should have length 2, including embed and attn_mask"
149
- item_input_embeds, item_attn_mask = cond_embed
150
-
151
- item_input_embeds = self.input_sequence_embed_linear[_id](
152
- item_input_embeds
153
- )
154
-
155
- item_input_embeds, item_attn_mask = self.add_sos_eos_tokens(
156
- _id, item_input_embeds, item_attn_mask
157
- )
158
-
159
- if input_embeds is None and input_embeds_attn_mask is None:
160
- input_embeds, input_embeds_attn_mask = (
161
- item_input_embeds,
162
- item_attn_mask,
163
- )
164
- else:
165
- input_embeds = torch.cat(
166
- [input_embeds, item_input_embeds], dim=1
167
- ) # The 1-st dimension is time steps
168
- input_embeds_attn_mask = torch.cat(
169
- [input_embeds_attn_mask, item_attn_mask], dim=1
170
- ) # The 1-st dimension is time steps
171
- else:
172
- assert isinstance(cond_embed, torch.Tensor)
173
- cond_embed = self.input_sequence_embed_linear[_id](cond_embed)
174
- attn_mask = torch.ones((cond_embed.size(0), cond_embed.size(1))).to(
175
- cond_embed.device
176
- )
177
-
178
- item_input_embeds, item_attn_mask = self.add_sos_eos_tokens(
179
- _id, cond_embed, attn_mask
180
- )
181
-
182
- if input_embeds is None and input_embeds_attn_mask is None:
183
- input_embeds, input_embeds_attn_mask = (
184
- item_input_embeds,
185
- item_attn_mask,
186
- )
187
- else:
188
- input_embeds, input_embeds_attn_mask = torch.cat(
189
- [input_embeds, item_input_embeds], dim=1
190
- ), torch.cat([input_embeds_attn_mask, item_attn_mask], dim=1)
191
-
192
- assert input_embeds is not None and input_embeds_attn_mask is not None
193
-
194
- input_embeds, input_embeds_attn_mask = self.truncate_sequence_and_mask(
195
- input_embeds, input_embeds_attn_mask, int(1024 - self.mae_token_num)
196
- )
197
- cond_sequence_end_time_idx = input_embeds.size(
198
- 1
199
- ) # The index that we start to collect the output embeds
200
-
201
- return input_embeds, input_embeds_attn_mask, cond_sequence_end_time_idx
202
-
203
- def warmup_step(self):
204
- if self.initial_learning_rate is None:
205
- self.initial_learning_rate = float(self.learning_rate)
206
-
207
- # Only the first parameter group
208
- if self.global_step <= 1000:
209
- if self.global_step == 0:
210
- print(
211
- "Warming up learning rate start with %s"
212
- % self.initial_learning_rate
213
- )
214
- self.trainer.optimizers[0].param_groups[0]["lr"] = (
215
- self.global_step / 1000
216
- ) * self.initial_learning_rate
217
- else:
218
- # TODO set learning rate here
219
- self.trainer.optimizers[0].param_groups[0][
220
- "lr"
221
- ] = self.initial_learning_rate
222
-
223
- def mask_target_sequence(self, target_embeds, target_embeds_attn_mask):
224
- time_seq_mask = None
225
- if self.target_tokens_mask_ratio > 1e-4:
226
- batchsize, time_seq_len, embed_dim = target_embeds.size()
227
- _, time_seq_len = target_embeds_attn_mask.size()
228
- # Generate random mask
229
- if self.random_mask_ratio:
230
- mask_ratio = torch.rand(1).item() * self.target_tokens_mask_ratio
231
- else:
232
- mask_ratio = self.target_tokens_mask_ratio
233
-
234
- time_seq_mask = (torch.rand((batchsize, time_seq_len)) > mask_ratio).to(
235
- target_embeds.device
236
- )
237
- # Mask the target embedding
238
- target_embeds = target_embeds * time_seq_mask.unsqueeze(-1)
239
- target_embeds_attn_mask = target_embeds_attn_mask * time_seq_mask
240
- return target_embeds, target_embeds_attn_mask, time_seq_mask
241
-
242
- def generate_partial(self, batch, cond_dict=None, no_grad=False):
243
- if cond_dict is None:
244
- cond_dict = self.get_input(batch)
245
-
246
- print("Generate partially prompted audio with in-context learning")
247
- # self.model.train()
248
- # assert self.model.training==True
249
-
250
- target_embeds, target_embeds_attn_mask = (
251
- cond_dict["crossattn_audiomae_pooled"][0],
252
- cond_dict["crossattn_audiomae_pooled"][1],
253
- )
254
-
255
- target_time_steps = target_embeds.size(1)
256
-
257
- (
258
- input_embeds,
259
- input_embeds_attn_mask,
260
- cond_sequence_end_time_idx,
261
- ) = self.get_input_sequence_and_mask(cond_dict)
262
-
263
- model_input = torch.cat(
264
- [input_embeds, target_embeds[:, : target_time_steps // 4, :]], dim=1
265
- )
266
- model_input_mask = torch.cat(
267
- [
268
- input_embeds_attn_mask,
269
- target_embeds_attn_mask[:, : target_time_steps // 4],
270
- ],
271
- dim=1,
272
- )
273
-
274
- steps = self.mae_token_num
275
-
276
- for _ in range(3 * steps // 4):
277
- output = self.model(
278
- inputs_embeds=model_input, attention_mask=model_input_mask
279
- )["last_hidden_state"]
280
- # Update the model input
281
- model_input = torch.cat([model_input, output[:, -1:, :]], dim=1)
282
- # Update the attention mask
283
- attention_mask_new_step = torch.ones((model_input_mask.size(0), 1)).to(
284
- model_input.device
285
- )
286
- model_input_mask = torch.cat(
287
- [model_input_mask, attention_mask_new_step], dim=1
288
- )
289
-
290
- output = model_input[:, cond_sequence_end_time_idx:]
291
-
292
- return output, cond_dict
293
-
294
- def generate(self, batch, cond_dict=None, no_grad=False):
295
- if cond_dict is None:
296
- cond_dict = self.get_input(batch)
297
-
298
- # self.model.train()
299
- # print("!!!!!!!!!!!!!train")
300
-
301
- (
302
- input_embeds,
303
- input_embeds_attn_mask,
304
- cond_sequence_end_time_idx,
305
- ) = self.get_input_sequence_and_mask(cond_dict)
306
- model_input = input_embeds
307
- model_input_mask = input_embeds_attn_mask
308
-
309
- steps = self.mae_token_num
310
-
311
- for _ in range(steps):
312
- output = self.model(
313
- inputs_embeds=model_input, attention_mask=model_input_mask
314
- )["last_hidden_state"]
315
- # Update the model input
316
- model_input = torch.cat([model_input, output[:, -1:, :]], dim=1)
317
- # Update the attention mask
318
- attention_mask_new_step = torch.ones((model_input_mask.size(0), 1)).to(
319
- model_input.device
320
- )
321
- model_input_mask = torch.cat(
322
- [model_input_mask, attention_mask_new_step], dim=1
323
- )
324
-
325
- return model_input[:, cond_sequence_end_time_idx:], cond_dict
326
-
327
- def get_input_item(self, batch, k):
328
- fname, text, waveform, stft, fbank = (
329
- batch["fname"],
330
- batch["text"],
331
- batch["waveform"],
332
- batch["stft"],
333
- batch["log_mel_spec"],
334
- )
335
- ret = {}
336
-
337
- ret["fbank"] = (
338
- fbank.unsqueeze(1).to(memory_format=torch.contiguous_format).float()
339
- )
340
- ret["stft"] = stft.to(memory_format=torch.contiguous_format).float()
341
- # ret["clip_label"] = clip_label.to(memory_format=torch.contiguous_format).float()
342
- ret["waveform"] = waveform.to(memory_format=torch.contiguous_format).float()
343
- ret["text"] = list(text)
344
- ret["fname"] = fname
345
-
346
- for key in batch.keys():
347
- if key not in ret.keys():
348
- ret[key] = batch[key]
349
-
350
- return ret[k]
351
-
352
- def get_input(self, batch):
353
- cond_dict = {}
354
- if len(self.cond_stage_model_metadata.keys()) > 0:
355
- unconditional_cfg = False
356
-
357
- for cond_model_key in self.cond_stage_model_metadata.keys():
358
- cond_stage_key = self.cond_stage_model_metadata[cond_model_key][
359
- "cond_stage_key"
360
- ]
361
-
362
- # if(not self.training):
363
- # if(isinstance(self.cond_stage_models[self.cond_stage_model_metadata[cond_model_key]["model_idx"]], CLAPAudioEmbeddingClassifierFreev2)):
364
- # assert cond_stage_key == "text" # CLAP model should use text for evaluation
365
-
366
- # The original data for conditioning
367
- xc = self.get_input_item(batch, cond_stage_key)
368
- if type(xc) == torch.Tensor:
369
- xc = xc.to(self.device)
370
-
371
- c = self.get_learned_conditioning(
372
- xc, key=cond_model_key, unconditional_cfg=unconditional_cfg
373
- )
374
- cond_dict[cond_model_key] = c
375
-
376
- return cond_dict
377
-
378
- def instantiate_cond_stage(self, config):
379
- self.cond_stage_model_metadata = {}
380
-
381
- for i, cond_model_key in enumerate(config.keys()):
382
- model = instantiate_from_config(config[cond_model_key])
383
- self.cond_stage_models.append(model)
384
- self.cond_stage_model_metadata[cond_model_key] = {
385
- "model_idx": i,
386
- "cond_stage_key": config[cond_model_key]["cond_stage_key"],
387
- "conditioning_key": config[cond_model_key]["conditioning_key"],
388
- }
389
-
390
- def get_learned_conditioning(self, c, key, unconditional_cfg):
391
- assert key in self.cond_stage_model_metadata.keys()
392
-
393
- # Classifier-free guidance
394
- if not unconditional_cfg:
395
- c = self.cond_stage_models[
396
- self.cond_stage_model_metadata[key]["model_idx"]
397
- ](c)
398
- else:
399
- if isinstance(c, torch.Tensor):
400
- batchsize = c.size(0)
401
- elif isinstance(c, list):
402
- batchsize = len(c)
403
- else:
404
- raise NotImplementedError()
405
- c = self.cond_stage_models[
406
- self.cond_stage_model_metadata[key]["model_idx"]
407
- ].get_unconditional_condition(batchsize)
408
-
409
- return c
410
-
411
- def initialize_param_check_toolkit(self):
412
- self.tracked_steps = 0
413
- self.param_dict = {}
414
-
415
- def statistic_require_grad_tensor_number(self, module, name=None):
416
- requires_grad_num = 0
417
- total_num = 0
418
- require_grad_tensor = None
419
- for p in module.parameters():
420
- if p.requires_grad:
421
- requires_grad_num += 1
422
- if require_grad_tensor is None:
423
- require_grad_tensor = p
424
- total_num += 1
425
- print(
426
- "Module: [%s] have %s trainable parameters out of %s total parameters (%.2f)"
427
- % (name, requires_grad_num, total_num, requires_grad_num / total_num)
428
- )
429
- return require_grad_tensor
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
audioldm2/audiomae_gen/utils.py DELETED
@@ -1,27 +0,0 @@
1
- import torch.nn as nn
2
-
3
-
4
- class Prenet(nn.Module):
5
- def __init__(self, in_dim, sizes=[256, 128], dropout_rate=0.5):
6
- super(Prenet, self).__init__()
7
- in_sizes = [in_dim] + sizes[:-1]
8
- self.layers = nn.ModuleList(
9
- [
10
- nn.Linear(in_size, out_size)
11
- for (in_size, out_size) in zip(in_sizes, sizes)
12
- ]
13
- )
14
- self.relu = nn.ReLU()
15
- self.dropout = nn.Dropout(dropout_rate)
16
-
17
- def forward(self, inputs):
18
- for linear in self.layers:
19
- inputs = self.dropout(self.relu(linear(inputs)))
20
- return inputs
21
-
22
-
23
- if __name__ == "__main__":
24
- model = Prenet(in_dim=128, sizes=[256, 256, 128])
25
- import ipdb
26
-
27
- ipdb.set_trace()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
audioldm2/clap/__init__.py DELETED
File without changes
audioldm2/clap/open_clip/__init__.py DELETED
@@ -1,25 +0,0 @@
1
- from .factory import (
2
- list_models,
3
- create_model,
4
- create_model_and_transforms,
5
- add_model_config,
6
- )
7
- from .loss import ClipLoss, gather_features, LPLoss, lp_gather_features, LPMetrics
8
- from .model import (
9
- CLAP,
10
- CLAPTextCfg,
11
- CLAPVisionCfg,
12
- CLAPAudioCfp,
13
- convert_weights_to_fp16,
14
- trace_model,
15
- )
16
- from .openai import load_openai_model, list_openai_models
17
- from .pretrained import (
18
- list_pretrained,
19
- list_pretrained_tag_models,
20
- list_pretrained_model_tags,
21
- get_pretrained_url,
22
- download_pretrained,
23
- )
24
- from .tokenizer import SimpleTokenizer, tokenize
25
- from .transform import image_transform
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
audioldm2/clap/open_clip/bpe_simple_vocab_16e6.txt.gz DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:924691ac288e54409236115652ad4aa250f48203de50a9e4722a6ecd48d6804a
3
- size 1356917
 
 
 
 
audioldm2/clap/open_clip/factory.py DELETED
@@ -1,276 +0,0 @@
1
- import json
2
- import logging
3
- import os
4
- import re
5
- from copy import deepcopy
6
- from pathlib import Path
7
-
8
- import torch
9
-
10
- from .model import CLAP, convert_weights_to_fp16
11
- from .openai import load_openai_model
12
- from .pretrained import get_pretrained_url, download_pretrained
13
- from .transform import image_transform
14
-
15
- _MODEL_CONFIG_PATHS = [Path(__file__).parent / f"model_configs/"]
16
- _MODEL_CONFIGS = {} # directory (model_name: config) of model architecture configs
17
-
18
-
19
- def _natural_key(string_):
20
- return [int(s) if s.isdigit() else s for s in re.split(r"(\d+)", string_.lower())]
21
-
22
-
23
- def _rescan_model_configs():
24
- global _MODEL_CONFIGS
25
-
26
- config_ext = (".json",)
27
- config_files = []
28
- for config_path in _MODEL_CONFIG_PATHS:
29
- if config_path.is_file() and config_path.suffix in config_ext:
30
- config_files.append(config_path)
31
- elif config_path.is_dir():
32
- for ext in config_ext:
33
- config_files.extend(config_path.glob(f"*{ext}"))
34
-
35
- for cf in config_files:
36
- if os.path.basename(cf)[0] == ".":
37
- continue # Ignore hidden files
38
-
39
- with open(cf, "r") as f:
40
- model_cfg = json.load(f)
41
- if all(a in model_cfg for a in ("embed_dim", "audio_cfg", "text_cfg")):
42
- _MODEL_CONFIGS[cf.stem] = model_cfg
43
-
44
- _MODEL_CONFIGS = {
45
- k: v
46
- for k, v in sorted(_MODEL_CONFIGS.items(), key=lambda x: _natural_key(x[0]))
47
- }
48
-
49
-
50
- _rescan_model_configs() # initial populate of model config registry
51
-
52
-
53
- def load_state_dict(checkpoint_path: str, map_location="cpu", skip_params=True):
54
- checkpoint = torch.load(checkpoint_path, map_location=map_location)
55
- if isinstance(checkpoint, dict) and "state_dict" in checkpoint:
56
- state_dict = checkpoint["state_dict"]
57
- else:
58
- state_dict = checkpoint
59
- if skip_params:
60
- if next(iter(state_dict.items()))[0].startswith("module"):
61
- state_dict = {k[7:]: v for k, v in state_dict.items()}
62
- # for k in state_dict:
63
- # if k.startswith('transformer'):
64
- # v = state_dict.pop(k)
65
- # state_dict['text_branch.' + k[12:]] = v
66
- return state_dict
67
-
68
-
69
- def create_model(
70
- amodel_name: str,
71
- tmodel_name: str,
72
- pretrained: str = "",
73
- precision: str = "fp32",
74
- device: torch.device = torch.device("cpu"),
75
- jit: bool = False,
76
- force_quick_gelu: bool = False,
77
- openai_model_cache_dir: str = os.path.expanduser("~/.cache/clip"),
78
- skip_params=True,
79
- pretrained_audio: str = "",
80
- pretrained_text: str = "",
81
- enable_fusion: bool = False,
82
- fusion_type: str = "None"
83
- # pretrained_image: bool = False,
84
- ):
85
- amodel_name = amodel_name.replace(
86
- "/", "-"
87
- ) # for callers using old naming with / in ViT names
88
- pretrained_orig = pretrained
89
- pretrained = pretrained.lower()
90
- if pretrained == "openai":
91
- if amodel_name in _MODEL_CONFIGS:
92
- logging.info(f"Loading {amodel_name} model config.")
93
- model_cfg = deepcopy(_MODEL_CONFIGS[amodel_name])
94
- else:
95
- logging.error(
96
- f"Model config for {amodel_name} not found; available models {list_models()}."
97
- )
98
- raise RuntimeError(f"Model config for {amodel_name} not found.")
99
-
100
- logging.info(f"Loading pretrained ViT-B-16 text encoder from OpenAI.")
101
- # Hard Code in model name
102
- model_cfg["text_cfg"]["model_type"] = tmodel_name
103
- model = load_openai_model(
104
- "ViT-B-16",
105
- model_cfg,
106
- device=device,
107
- jit=jit,
108
- cache_dir=openai_model_cache_dir,
109
- enable_fusion=enable_fusion,
110
- fusion_type=fusion_type,
111
- )
112
- # See https://discuss.pytorch.org/t/valueerror-attemting-to-unscale-fp16-gradients/81372
113
- if precision == "amp" or precision == "fp32":
114
- model = model.float()
115
- else:
116
- if amodel_name in _MODEL_CONFIGS:
117
- logging.info(f"Loading {amodel_name} model config.")
118
- model_cfg = deepcopy(_MODEL_CONFIGS[amodel_name])
119
- else:
120
- logging.error(
121
- f"Model config for {amodel_name} not found; available models {list_models()}."
122
- )
123
- raise RuntimeError(f"Model config for {amodel_name} not found.")
124
-
125
- if force_quick_gelu:
126
- # override for use of QuickGELU on non-OpenAI transformer models
127
- model_cfg["quick_gelu"] = True
128
-
129
- # if pretrained_image:
130
- # if 'timm_amodel_name' in model_cfg.get('vision_cfg', {}):
131
- # # pretrained weight loading for timm models set via vision_cfg
132
- # model_cfg['vision_cfg']['timm_model_pretrained'] = True
133
- # else:
134
- # assert False, 'pretrained image towers currently only supported for timm models'
135
- model_cfg["text_cfg"]["model_type"] = tmodel_name
136
- model_cfg["enable_fusion"] = enable_fusion
137
- model_cfg["fusion_type"] = fusion_type
138
- model = CLAP(**model_cfg)
139
-
140
- if pretrained:
141
- checkpoint_path = ""
142
- url = get_pretrained_url(amodel_name, pretrained)
143
- if url:
144
- checkpoint_path = download_pretrained(url, root=openai_model_cache_dir)
145
- elif os.path.exists(pretrained_orig):
146
- checkpoint_path = pretrained_orig
147
- if checkpoint_path:
148
- logging.info(
149
- f"Loading pretrained {amodel_name}-{tmodel_name} weights ({pretrained})."
150
- )
151
- ckpt = load_state_dict(checkpoint_path, skip_params=True)
152
- model.load_state_dict(ckpt)
153
- param_names = [n for n, p in model.named_parameters()]
154
- # for n in param_names:
155
- # print(n, "\t", "Loaded" if n in ckpt else "Unloaded")
156
- else:
157
- logging.warning(
158
- f"Pretrained weights ({pretrained}) not found for model {amodel_name}."
159
- )
160
- raise RuntimeError(
161
- f"Pretrained weights ({pretrained}) not found for model {amodel_name}."
162
- )
163
-
164
- if pretrained_audio:
165
- if amodel_name.startswith("PANN"):
166
- if "Cnn14_mAP" in pretrained_audio: # official checkpoint
167
- audio_ckpt = torch.load(pretrained_audio, map_location="cpu")
168
- audio_ckpt = audio_ckpt["model"]
169
- keys = list(audio_ckpt.keys())
170
- for key in keys:
171
- if (
172
- "spectrogram_extractor" not in key
173
- and "logmel_extractor" not in key
174
- ):
175
- v = audio_ckpt.pop(key)
176
- audio_ckpt["audio_branch." + key] = v
177
- elif os.path.basename(pretrained_audio).startswith(
178
- "PANN"
179
- ): # checkpoint trained via HTSAT codebase
180
- audio_ckpt = torch.load(pretrained_audio, map_location="cpu")
181
- audio_ckpt = audio_ckpt["state_dict"]
182
- keys = list(audio_ckpt.keys())
183
- for key in keys:
184
- if key.startswith("sed_model"):
185
- v = audio_ckpt.pop(key)
186
- audio_ckpt["audio_branch." + key[10:]] = v
187
- elif os.path.basename(pretrained_audio).startswith(
188
- "finetuned"
189
- ): # checkpoint trained via linear probe codebase
190
- audio_ckpt = torch.load(pretrained_audio, map_location="cpu")
191
- else:
192
- raise ValueError("Unknown audio checkpoint")
193
- elif amodel_name.startswith("HTSAT"):
194
- if "HTSAT_AudioSet_Saved" in pretrained_audio: # official checkpoint
195
- audio_ckpt = torch.load(pretrained_audio, map_location="cpu")
196
- audio_ckpt = audio_ckpt["state_dict"]
197
- keys = list(audio_ckpt.keys())
198
- for key in keys:
199
- if key.startswith("sed_model") and (
200
- "spectrogram_extractor" not in key
201
- and "logmel_extractor" not in key
202
- ):
203
- v = audio_ckpt.pop(key)
204
- audio_ckpt["audio_branch." + key[10:]] = v
205
- elif os.path.basename(pretrained_audio).startswith(
206
- "HTSAT"
207
- ): # checkpoint trained via HTSAT codebase
208
- audio_ckpt = torch.load(pretrained_audio, map_location="cpu")
209
- audio_ckpt = audio_ckpt["state_dict"]
210
- keys = list(audio_ckpt.keys())
211
- for key in keys:
212
- if key.startswith("sed_model"):
213
- v = audio_ckpt.pop(key)
214
- audio_ckpt["audio_branch." + key[10:]] = v
215
- elif os.path.basename(pretrained_audio).startswith(
216
- "finetuned"
217
- ): # checkpoint trained via linear probe codebase
218
- audio_ckpt = torch.load(pretrained_audio, map_location="cpu")
219
- else:
220
- raise ValueError("Unknown audio checkpoint")
221
- else:
222
- raise f"this audio encoder pretrained checkpoint is not support"
223
-
224
- model.load_state_dict(audio_ckpt, strict=False)
225
- logging.info(
226
- f"Loading pretrained {amodel_name} weights ({pretrained_audio})."
227
- )
228
- param_names = [n for n, p in model.named_parameters()]
229
- for n in param_names:
230
- print(n, "\t", "Loaded" if n in audio_ckpt else "Unloaded")
231
-
232
- model.to(device=device)
233
- if precision == "fp16":
234
- assert device.type != "cpu"
235
- convert_weights_to_fp16(model)
236
-
237
- if jit:
238
- model = torch.jit.script(model)
239
-
240
- return model, model_cfg
241
-
242
-
243
- def create_model_and_transforms(
244
- model_name: str,
245
- pretrained: str = "",
246
- precision: str = "fp32",
247
- device: torch.device = torch.device("cpu"),
248
- jit: bool = False,
249
- force_quick_gelu: bool = False,
250
- # pretrained_image: bool = False,
251
- ):
252
- model = create_model(
253
- model_name,
254
- pretrained,
255
- precision,
256
- device,
257
- jit,
258
- force_quick_gelu=force_quick_gelu,
259
- # pretrained_image=pretrained_image
260
- )
261
- preprocess_train = image_transform(model.visual.image_size, is_train=True)
262
- preprocess_val = image_transform(model.visual.image_size, is_train=False)
263
- return model, preprocess_train, preprocess_val
264
-
265
-
266
- def list_models():
267
- """enumerate available model architectures based on config files"""
268
- return list(_MODEL_CONFIGS.keys())
269
-
270
-
271
- def add_model_config(path):
272
- """add model config path or file and update registry"""
273
- if not isinstance(path, Path):
274
- path = Path(path)
275
- _MODEL_CONFIG_PATHS.append(path)
276
- _rescan_model_configs()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
audioldm2/clap/open_clip/feature_fusion.py DELETED
@@ -1,192 +0,0 @@
1
- """
2
- Feature Fusion for Variable-Length Data Processing
3
- AFF/iAFF is referred and modified from https://github.com/YimianDai/open-aff/blob/master/aff_pytorch/aff_net/fusion.py
4
- According to the paper: Yimian Dai et al, Attentional Feature Fusion, IEEE Winter Conference on Applications of Computer Vision, WACV 2021
5
- """
6
-
7
- import torch
8
- import torch.nn as nn
9
-
10
-
11
class DAF(nn.Module):
    """Direct-add fusion (DirectAddFuse): fuse two inputs by summing them."""

    def __init__(self):
        super().__init__()

    def forward(self, x, residual):
        """Return the element-wise sum ``x + residual``."""
        fused = x + residual
        return fused
21
-
22
-
23
class iAFF(nn.Module):
    """Iterative attentional feature fusion (iAFF) of two feature maps.

    Two rounds of attention are applied.  Each round adds a "local" branch
    (1x1 conv bottleneck) to a "global" branch (the same bottleneck after
    adaptive average pooling to size 1) and squashes the sum with a sigmoid
    to obtain the weights that blend ``x`` and ``residual``.

    Args:
        channels: Number of channels of both inputs.
        r: Channel reduction ratio of the bottleneck (``channels // r``).
        type: ``"1D"`` for Conv1d/BatchNorm1d layers, ``"2D"`` for the 2-D
            variants.

    Raises:
        ValueError: If ``type`` is neither ``"1D"`` nor ``"2D"``.
    """

    def __init__(self, channels=64, r=4, type="2D"):
        super(iAFF, self).__init__()
        inter_channels = int(channels // r)

        if type == "1D":
            conv, norm, pool = nn.Conv1d, nn.BatchNorm1d, nn.AdaptiveAvgPool1d
        elif type == "2D":
            conv, norm, pool = nn.Conv2d, nn.BatchNorm2d, nn.AdaptiveAvgPool2d
        else:
            # The original raised a plain string, which is itself a TypeError.
            raise ValueError(f"the type is not supported: {type!r}")

        def attention_branch(global_context):
            # Layer order (and therefore state_dict keys) is kept identical to
            # the hand-written Sequentials: [pool,] conv, bn, relu, conv, bn.
            layers = [pool(1)] if global_context else []
            layers += [
                conv(channels, inter_channels, kernel_size=1, stride=1, padding=0),
                norm(inter_channels),
                nn.ReLU(inplace=True),
                conv(inter_channels, channels, kernel_size=1, stride=1, padding=0),
                norm(channels),
            ]
            return nn.Sequential(*layers)

        # First-round local / global attention.
        self.local_att = attention_branch(global_context=False)
        self.global_att = attention_branch(global_context=True)
        # Second-round local / global attention.
        self.local_att2 = attention_branch(global_context=False)
        self.global_att2 = attention_branch(global_context=True)

        self.sigmoid = nn.Sigmoid()

    def forward(self, x, residual):
        """Fuse ``x`` and ``residual``; the result has the batch size of ``x``."""
        flag = False
        xa = x + residual
        if xa.size(0) == 1:
            # A single-sample batch is duplicated before the attention
            # branches and the first sample is sliced back out at the end.
            # NOTE(review): presumably needed because BatchNorm cannot
            # normalise a single pooled value per channel in training mode.
            xa = torch.cat([xa, xa], dim=0)
            flag = True
        xl = self.local_att(xa)
        xg = self.global_att(xa)
        wei = self.sigmoid(xl + xg)
        xi = x * wei + residual * (1 - wei)

        xl2 = self.local_att2(xi)
        # NOTE(review): the second round reuses ``global_att`` rather than
        # ``global_att2``.  That looks unintended, but switching would change
        # the output of weights trained with this code, so it is preserved.
        xg2 = self.global_att(xi)
        wei2 = self.sigmoid(xl2 + xg2)
        xo = x * wei2 + residual * (1 - wei2)
        if flag:
            xo = xo[0].unsqueeze(0)
        return xo
131
-
132
-
133
class AFF(nn.Module):
    """Attentional feature fusion (AFF) of two feature maps.

    A "local" branch (1x1 conv bottleneck) and a "global" branch (the same
    bottleneck after adaptive average pooling to size 1) are summed and
    passed through a sigmoid; the resulting weights blend ``x`` and
    ``residual``.

    Args:
        channels: Number of channels of both inputs.
        r: Channel reduction ratio of the bottleneck (``channels // r``).
        type: ``"1D"`` for Conv1d/BatchNorm1d layers, ``"2D"`` for the 2-D
            variants.

    Raises:
        ValueError: If ``type`` is neither ``"1D"`` nor ``"2D"``.
    """

    def __init__(self, channels=64, r=4, type="2D"):
        super(AFF, self).__init__()
        inter_channels = int(channels // r)

        if type == "1D":
            conv, norm, pool = nn.Conv1d, nn.BatchNorm1d, nn.AdaptiveAvgPool1d
        elif type == "2D":
            conv, norm, pool = nn.Conv2d, nn.BatchNorm2d, nn.AdaptiveAvgPool2d
        else:
            # The original raised a plain string, which is itself a TypeError.
            raise ValueError(f"the type is not supported: {type!r}")

        def attention_branch(global_context):
            # Layer order (and therefore state_dict keys) is kept identical to
            # the hand-written Sequentials: [pool,] conv, bn, relu, conv, bn.
            layers = [pool(1)] if global_context else []
            layers += [
                conv(channels, inter_channels, kernel_size=1, stride=1, padding=0),
                norm(inter_channels),
                nn.ReLU(inplace=True),
                conv(inter_channels, channels, kernel_size=1, stride=1, padding=0),
                norm(channels),
            ]
            return nn.Sequential(*layers)

        self.local_att = attention_branch(global_context=False)
        self.global_att = attention_branch(global_context=True)

        self.sigmoid = nn.Sigmoid()

    def forward(self, x, residual):
        """Return ``2 * x * w + 2 * residual * (1 - w)`` with learned weights ``w``."""
        flag = False
        xa = x + residual
        if xa.size(0) == 1:
            # A single-sample batch is duplicated before the attention
            # branches and the first sample is sliced back out at the end.
            # NOTE(review): presumably needed because BatchNorm cannot
            # normalise a single pooled value per channel in training mode.
            xa = torch.cat([xa, xa], dim=0)
            flag = True
        xl = self.local_att(xa)
        xg = self.global_att(xa)
        wei = self.sigmoid(xl + xg)
        xo = 2 * x * wei + 2 * residual * (1 - wei)
        if flag:
            xo = xo[0].unsqueeze(0)
        return xo
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
audioldm2/clap/open_clip/htsat.py DELETED
@@ -1,1304 +0,0 @@
1
- # Ke Chen
2
3
- # HTS-AT: A HIERARCHICAL TOKEN-SEMANTIC AUDIO TRANSFORMER FOR SOUND CLASSIFICATION AND DETECTION
4
- # Some layers designed on the model
5
- # below codes are based and referred from https://github.com/microsoft/Swin-Transformer
6
- # Swin Transformer for Computer Vision: https://arxiv.org/pdf/2103.14030.pdf
7
-
8
- import torch
9
- import torch.nn as nn
10
- from itertools import repeat
11
- import collections.abc
12
- import math
13
- import warnings
14
-
15
- from torch.nn.init import _calculate_fan_in_and_fan_out
16
- import torch.utils.checkpoint as checkpoint
17
-
18
- import random
19
-
20
- from torchlibrosa.stft import Spectrogram, LogmelFilterBank
21
- from torchlibrosa.augmentation import SpecAugmentation
22
-
23
- from itertools import repeat
24
- from .utils import do_mixup, interpolate
25
-
26
- from .feature_fusion import iAFF, AFF, DAF
27
-
28
-
29
- # from PyTorch internals
30
def _ntuple(n):
    """Build a converter that expands a scalar into an ``n``-tuple.

    The returned function hands back any iterable argument unchanged
    (strings included) and repeats every other value ``n`` times.
    """

    def parse(x):
        already_iterable = isinstance(x, collections.abc.Iterable)
        return x if already_iterable else tuple(repeat(x, n))

    return parse


# Ready-made converters for the common arities.
to_1tuple = _ntuple(1)
to_2tuple = _ntuple(2)
to_3tuple = _ntuple(3)
to_4tuple = _ntuple(4)
to_ntuple = _ntuple
44
-
45
-
46
def drop_path(x, drop_prob: float = 0.0, training: bool = False):
    """Apply stochastic depth ("drop path") per sample.

    With probability ``drop_prob`` a whole sample (index along dim 0) is
    zeroed; surviving samples are divided by ``1 - drop_prob``.  The input is
    returned untouched when ``drop_prob`` is 0 or ``training`` is false.

    The name "drop path" is used instead of "DropConnect" because the latter
    denotes a different form of dropout from a separate paper, see
    https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956.
    """
    if drop_prob == 0.0 or not training:
        return x
    keep_prob = 1 - drop_prob
    # One random draw per sample, broadcast over all remaining dimensions, so
    # this works for tensors of any rank.
    mask_shape = (x.shape[0],) + (1,) * (x.ndim - 1)
    keep_mask = keep_prob + torch.rand(mask_shape, dtype=x.dtype, device=x.device)
    keep_mask.floor_()  # binarize: 1 keeps the sample, 0 drops it
    return x.div(keep_prob) * keep_mask
64
-
65
-
66
class DropPath(nn.Module):
    """Module wrapper around ``drop_path`` (per-sample stochastic depth).

    Args:
        drop_prob: Drop probability handed to ``drop_path`` on every call.
    """

    def __init__(self, drop_prob=None):
        super().__init__()
        self.drop_prob = drop_prob

    def forward(self, x):
        """Apply ``drop_path`` using the module's current training flag."""
        is_training = self.training
        return drop_path(x, self.drop_prob, is_training)
75
-
76
-
77
class PatchEmbed(nn.Module):
    """2D image to patch embedding, with optional global/local feature fusion.

    Args:
        img_size: Input height/width (int or pair).
        patch_size: Convolution kernel size (int or pair).
        in_chans: Number of input channels.
        embed_dim: Number of output channels of the projection.
        norm_layer: Optional callable building a norm layer from ``embed_dim``;
            ``nn.Identity`` is used when it is falsy.
        flatten: If true, the output is flattened from BCHW to BNC.
        patch_stride: Convolution stride (int or pair).
        enable_fusion: Enables the fusion paths selected by ``fusion_type``.
        fusion_type: ``"channel_map"`` widens the projection input to
            ``in_chans * 4``; ``"daf_2d"``, ``"aff_2d"`` and ``"iaff_2d"`` add
            a second convolution plus a fusion module.
    """

    def __init__(
        self,
        img_size=224,
        patch_size=16,
        in_chans=3,
        embed_dim=768,
        norm_layer=None,
        flatten=True,
        patch_stride=16,
        enable_fusion=False,
        fusion_type="None",
    ):
        super().__init__()
        img_size = to_2tuple(img_size)
        patch_size = to_2tuple(patch_size)
        patch_stride = to_2tuple(patch_stride)
        self.img_size = img_size
        self.patch_size = patch_size
        self.patch_stride = patch_stride
        self.grid_size = (
            img_size[0] // patch_stride[0],
            img_size[1] // patch_stride[1],
        )
        self.num_patches = self.grid_size[0] * self.grid_size[1]
        self.flatten = flatten
        self.in_chans = in_chans
        self.embed_dim = embed_dim

        self.enable_fusion = enable_fusion
        self.fusion_type = fusion_type

        padding = (
            (patch_size[0] - patch_stride[0]) // 2,
            (patch_size[1] - patch_stride[1]) // 2,
        )

        if (self.enable_fusion) and (self.fusion_type == "channel_map"):
            # Channel-map fusion feeds four times as many input channels.
            self.proj = nn.Conv2d(
                in_chans * 4,
                embed_dim,
                kernel_size=patch_size,
                stride=patch_stride,
                padding=padding,
            )
        else:
            self.proj = nn.Conv2d(
                in_chans,
                embed_dim,
                kernel_size=patch_size,
                stride=patch_stride,
                padding=padding,
            )
        self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()

        if (self.enable_fusion) and (
            self.fusion_type in ["daf_2d", "aff_2d", "iaff_2d"]
        ):
            # Second projection for the "local" channels; kernel and stride
            # are three times wider along the last axis than ``proj``.
            self.mel_conv2d = nn.Conv2d(
                in_chans,
                embed_dim,
                kernel_size=(patch_size[0], patch_size[1] * 3),
                stride=(patch_stride[0], patch_stride[1] * 3),
                padding=padding,
            )
            if self.fusion_type == "daf_2d":
                self.fusion_model = DAF()
            elif self.fusion_type == "aff_2d":
                self.fusion_model = AFF(channels=embed_dim, type="2D")
            elif self.fusion_type == "iaff_2d":
                self.fusion_model = iAFF(channels=embed_dim, type="2D")

    def forward(self, x, longer_idx=None):
        """Embed ``x`` into patches.

        Args:
            x: 4-D input (B, C, H, W).  On the 2-D fusion path channel 0 is
                the "global" view and channels 1.. are the "local" views.
            longer_idx: Batch indices whose local views should be fused into
                the global embedding.  ``None`` or an empty sequence skips the
                local processing.

        Returns:
            The normalised embedding: BNC when ``flatten`` is true, otherwise
            the BCHW projection output.
        """
        if (self.enable_fusion) and (
            self.fusion_type in ["daf_2d", "aff_2d", "iaff_2d"]
        ):
            global_x = x[:, 0:1, :, :]

            # global processing
            B, C, H, W = global_x.shape
            assert (
                H == self.img_size[0] and W == self.img_size[1]
            ), f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
            global_x = self.proj(global_x)
            TW = global_x.size(-1)
            # ``longer_idx`` defaults to None; treat that like "no long
            # samples" instead of failing on ``len(None)``.
            if longer_idx is not None and len(longer_idx) > 0:
                # local processing
                local_x = x[longer_idx, 1:, :, :].contiguous()
                B, C, H, W = local_x.shape
                local_x = local_x.view(B * C, 1, H, W)
                local_x = self.mel_conv2d(local_x)
                local_x = local_x.view(
                    B, C, local_x.size(1), local_x.size(2), local_x.size(3)
                )
                # Concatenate the per-view outputs along the last axis.
                local_x = local_x.permute((0, 2, 3, 1, 4)).contiguous().flatten(3)
                TB, TC, TH, _ = local_x.size()
                if local_x.size(-1) < TW:
                    # Right-pad to the global width.  ``new_zeros`` inherits
                    # dtype and device from ``local_x``, so half-precision
                    # features are not silently promoted to float32.
                    pad = local_x.new_zeros((TB, TC, TH, TW - local_x.size(-1)))
                    local_x = torch.cat([local_x, pad], dim=-1)
                else:
                    local_x = local_x[:, :, :, :TW]

                global_x[longer_idx] = self.fusion_model(global_x[longer_idx], local_x)
            x = global_x
        else:
            B, C, H, W = x.shape
            assert (
                H == self.img_size[0] and W == self.img_size[1]
            ), f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
            x = self.proj(x)

        if self.flatten:
            x = x.flatten(2).transpose(1, 2)  # BCHW -> BNC
        x = self.norm(x)
        return x
202
-
203
-
204
class Mlp(nn.Module):
    """Two-layer perceptron: Linear -> activation -> dropout -> Linear -> dropout.

    Args:
        in_features: Size of the input features.
        hidden_features: Hidden width; falls back to ``in_features`` when falsy.
        out_features: Output width; falls back to ``in_features`` when falsy.
        act_layer: Zero-argument callable building the activation module.
        drop: Dropout probability used after both linear layers.
    """

    def __init__(
        self,
        in_features,
        hidden_features=None,
        out_features=None,
        act_layer=nn.GELU,
        drop=0.0,
    ):
        super().__init__()
        out_width = out_features if out_features else in_features
        hidden_width = hidden_features if hidden_features else in_features
        self.fc1 = nn.Linear(in_features, hidden_width)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_width, out_width)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        """Run the input through both layers; the same dropout module is used twice."""
        hidden = self.drop(self.act(self.fc1(x)))
        return self.drop(self.fc2(hidden))
230
-
231
-
232
def _no_grad_trunc_normal_(tensor, mean, std, a, b):
    """Fill ``tensor`` in place from a normal distribution truncated to [a, b].

    Cut & paste from PyTorch official master until it's in a few official
    releases - RW.  Method based on
    https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
    """

    def norm_cdf(x):
        # Standard normal cumulative distribution function.
        return (1.0 + math.erf(x / math.sqrt(2.0))) / 2.0

    mean_far_outside = (mean < a - 2 * std) or (mean > b + 2 * std)
    if mean_far_outside:
        warnings.warn(
            "mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
            "The distribution of values may be incorrect.",
            stacklevel=2,
        )

    with torch.no_grad():
        # Sample uniformly between the CDF values of the two cut-offs and map
        # the samples back through the inverse normal CDF.
        cdf_low = norm_cdf((a - mean) / std)
        cdf_high = norm_cdf((b - mean) / std)

        # Uniform in [2l-1, 2u-1], the domain erfinv expects.
        tensor.uniform_(2 * cdf_low - 1, 2 * cdf_high - 1)

        # Inverse CDF transform -> truncated standard normal.
        tensor.erfinv_()

        # Scale and shift to the requested std and mean.
        tensor.mul_(std * math.sqrt(2.0))
        tensor.add_(mean)

        # Clamp so rounding cannot push values outside [a, b].
        tensor.clamp_(min=a, max=b)
        return tensor


def trunc_normal_(tensor, mean=0.0, std=1.0, a=-2.0, b=2.0):
    # type: (Tensor, float, float, float, float) -> Tensor
    r"""Fill the input Tensor in place with truncated-normal values.

    The values are effectively drawn from the normal distribution
    :math:`\mathcal{N}(\text{mean}, \text{std}^2)` with values outside
    :math:`[a, b]` redrawn until they are within the bounds.  The method used
    for generating the random values works best when
    :math:`a \leq \text{mean} \leq b`.

    Args:
        tensor: an n-dimensional `torch.Tensor`
        mean: the mean of the normal distribution
        std: the standard deviation of the normal distribution
        a: the minimum cutoff value
        b: the maximum cutoff value

    Returns:
        The same ``tensor`` object, filled in place.

    Examples:
        >>> w = torch.empty(3, 5)
        >>> nn.init.trunc_normal_(w)
    """
    filled = _no_grad_trunc_normal_(tensor, mean, std, a, b)
    return filled
289
-
290
-
291
def variance_scaling_(tensor, scale=1.0, mode="fan_in", distribution="normal"):
    """Initialise ``tensor`` in place with variance ``scale / fan``.

    Args:
        tensor: Tensor (or ``nn.Parameter``) to fill; must have at least two
            dimensions so fan-in/fan-out can be computed.
        scale: Numerator of the target variance.
        mode: ``"fan_in"``, ``"fan_out"`` or ``"fan_avg"``; selects the
            denominator of the target variance.
        distribution: ``"truncated_normal"``, ``"normal"`` or ``"uniform"``.

    Raises:
        ValueError: If ``mode`` or ``distribution`` is not one of the
            supported values.
    """
    fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor)
    if mode == "fan_in":
        denom = fan_in
    elif mode == "fan_out":
        denom = fan_out
    elif mode == "fan_avg":
        denom = (fan_in + fan_out) / 2
    else:
        # Previously an unknown mode fell through and crashed on the unbound
        # ``denom`` below; fail with the same error style as ``distribution``.
        raise ValueError(f"invalid mode {mode}")

    variance = scale / denom

    if distribution == "truncated_normal":
        # constant is stddev of standard normal truncated to (-2, 2)
        trunc_normal_(tensor, std=math.sqrt(variance) / 0.87962566103423978)
    elif distribution == "normal":
        # In-place init of a leaf tensor that requires grad (e.g. an
        # nn.Parameter) is only legal with autograd disabled.
        with torch.no_grad():
            tensor.normal_(std=math.sqrt(variance))
    elif distribution == "uniform":
        bound = math.sqrt(3 * variance)
        with torch.no_grad():
            tensor.uniform_(-bound, bound)
    else:
        raise ValueError(f"invalid distribution {distribution}")
312
-
313
-
314
def lecun_normal_(tensor):
    """Initialise ``tensor`` in place: fan-in scaling with a truncated normal."""
    variance_scaling_(
        tensor,
        mode="fan_in",
        distribution="truncated_normal",
    )
316
-
317
-
318
def window_partition(x, window_size):
    """Split a feature map into non-overlapping square windows.

    Args:
        x: (B, H, W, C)
        window_size (int): window size

    Returns:
        windows: (num_windows*B, window_size, window_size, C)
    """
    batch, height, width, channels = x.shape
    rows = height // window_size
    cols = width // window_size
    tiled = x.view(batch, rows, window_size, cols, window_size, channels)
    # Bring the two window-grid axes next to each other before flattening.
    grouped = tiled.permute(0, 1, 3, 2, 4, 5).contiguous()
    return grouped.view(-1, window_size, window_size, channels)
332
-
333
-
334
def window_reverse(windows, window_size, H, W):
    """Reassemble windows into a full feature map (inverse of the window split).

    Args:
        windows: (num_windows*B, window_size, window_size, C)
        window_size (int): Window size
        H (int): Height of image
        W (int): Width of image

    Returns:
        x: (B, H, W, C)
    """
    windows_per_image = H * W / window_size / window_size
    batch = int(windows.shape[0] / windows_per_image)
    grid = windows.view(
        batch, H // window_size, W // window_size, window_size, window_size, -1
    )
    # Interleave grid rows with in-window rows (same for columns), then merge.
    merged = grid.permute(0, 1, 3, 2, 4, 5).contiguous()
    return merged.view(batch, H, W, -1)
350
-
351
-
352
class WindowAttention(nn.Module):
    r"""Window based multi-head self attention (W-MSA) with relative position bias.

    Works for both shifted and non-shifted windows; for shifted windows the
    caller passes the attention mask to ``forward``.

    Args:
        dim (int): Number of input channels.
        window_size (tuple[int]): The height and width of the window.
        num_heads (int): Number of attention heads.
        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set
        attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
        proj_drop (float, optional): Dropout ratio of output. Default: 0.0
    """

    def __init__(
        self,
        dim,
        window_size,
        num_heads,
        qkv_bias=True,
        qk_scale=None,
        attn_drop=0.0,
        proj_drop=0.0,
    ):
        super().__init__()
        self.dim = dim
        self.window_size = window_size  # Wh, Ww
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = qk_scale or head_dim**-0.5

        # Learnable bias table with one row per possible relative offset.
        num_offsets = (2 * window_size[0] - 1) * (2 * window_size[1] - 1)
        self.relative_position_bias_table = nn.Parameter(
            torch.zeros(num_offsets, num_heads)
        )  # 2*Wh-1 * 2*Ww-1, nH

        # Lookup from every (query, key) token pair to its bias-table row.
        self.register_buffer(
            "relative_position_index",
            self._relative_position_index(self.window_size),
        )

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

        trunc_normal_(self.relative_position_bias_table, std=0.02)
        self.softmax = nn.Softmax(dim=-1)

    @staticmethod
    def _relative_position_index(window_size):
        """Return the (Wh*Ww, Wh*Ww) index of each token pair's relative offset."""
        win_h, win_w = window_size[0], window_size[1]
        coords = torch.stack(
            torch.meshgrid([torch.arange(win_h), torch.arange(win_w)])
        )  # 2, Wh, Ww
        flat = torch.flatten(coords, 1)  # 2, Wh*Ww
        relative = flat[:, :, None] - flat[:, None, :]  # 2, Wh*Ww, Wh*Ww
        relative = relative.permute(1, 2, 0).contiguous()  # Wh*Ww, Wh*Ww, 2
        relative[:, :, 0] += win_h - 1  # shift to start from 0
        relative[:, :, 1] += win_w - 1
        # Scale the row offset so (row, col) pairs map to unique flat indices.
        relative[:, :, 0] *= 2 * win_w - 1
        return relative.sum(-1)  # Wh*Ww, Wh*Ww

    def forward(self, x, mask=None):
        """Apply windowed self attention.

        Args:
            x: input features with shape of (num_windows*B, N, C)
            mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None

        Returns:
            ``(x, attn)``: the projected output with the shape of the input,
            and the attention weights after attention dropout.
        """
        B_, N, C = x.shape
        heads = self.num_heads
        qkv = (
            self.qkv(x)
            .reshape(B_, N, 3, heads, C // heads)
            .permute(2, 0, 3, 1, 4)
        )
        # make torchscript happy (cannot use tensor as tuple)
        q, k, v = qkv[0], qkv[1], qkv[2]

        attn = (q * self.scale) @ k.transpose(-2, -1)

        tokens = self.window_size[0] * self.window_size[1]
        bias = self.relative_position_bias_table[
            self.relative_position_index.view(-1)
        ].view(tokens, tokens, -1)  # Wh*Ww,Wh*Ww,nH
        bias = bias.permute(2, 0, 1).contiguous()  # nH, Wh*Ww, Wh*Ww
        attn = attn + bias.unsqueeze(0)

        if mask is not None:
            # Add the per-window mask, broadcasting over batch and heads.
            nW = mask.shape[0]
            attn = attn.view(B_ // nW, nW, heads, N, N) + mask.unsqueeze(
                1
            ).unsqueeze(0)
            attn = attn.view(-1, heads, N, N)
        attn = self.softmax(attn)

        attn = self.attn_drop(attn)

        x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
        x = self.proj_drop(self.proj(x))
        return x, attn

    def extra_repr(self):
        return f"dim={self.dim}, window_size={self.window_size}, num_heads={self.num_heads}"
464
-
465
-
466
- # We use the model based on Swintransformer Block, therefore we can use the swin-transformer pretrained model
467
class SwinTransformerBlock(nn.Module):
    r"""Swin Transformer Block: windowed attention followed by an MLP.

    Args:
        dim (int): Number of input channels.
        input_resolution (tuple[int]): Input resolution.
        num_heads (int): Number of attention heads.
        window_size (int): Window size.
        shift_size (int): Shift size for SW-MSA.
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
        drop (float, optional): Dropout rate. Default: 0.0
        attn_drop (float, optional): Attention dropout rate. Default: 0.0
        drop_path (float, optional): Stochastic depth rate. Default: 0.0
        act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
        norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
        norm_before_mlp (str, optional): "ln" or "bn"; selects the norm applied
            before the MLP. Default: "ln"
    """

    def __init__(
        self,
        dim,
        input_resolution,
        num_heads,
        window_size=7,
        shift_size=0,
        mlp_ratio=4.0,
        qkv_bias=True,
        qk_scale=None,
        drop=0.0,
        attn_drop=0.0,
        drop_path=0.0,
        act_layer=nn.GELU,
        norm_layer=nn.LayerNorm,
        norm_before_mlp="ln",
    ):
        super().__init__()
        self.dim = dim
        self.input_resolution = input_resolution
        self.num_heads = num_heads
        self.window_size = window_size
        self.shift_size = shift_size
        self.mlp_ratio = mlp_ratio
        self.norm_before_mlp = norm_before_mlp
        smallest_side = min(self.input_resolution)
        if smallest_side <= self.window_size:
            # if window size is larger than input resolution, we don't partition windows
            self.shift_size = 0
            self.window_size = smallest_side
        assert (
            0 <= self.shift_size < self.window_size
        ), "shift_size must in 0-window_size"

        self.norm1 = norm_layer(dim)
        self.attn = WindowAttention(
            dim,
            window_size=to_2tuple(self.window_size),
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            qk_scale=qk_scale,
            attn_drop=attn_drop,
            proj_drop=drop,
        )

        self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
        if self.norm_before_mlp == "ln":
            self.norm2 = nn.LayerNorm(dim)
        elif self.norm_before_mlp == "bn":
            # NOTE(review): a new BatchNorm1d is constructed on every call, so
            # it is never registered as a sub-module and keeps no learned or
            # running statistics.  Kept as is, because registering it would
            # change state_dict keys and eval-time behaviour.
            self.norm2 = lambda x: nn.BatchNorm1d(dim)(x.transpose(1, 2)).transpose(
                1, 2
            )
        else:
            raise NotImplementedError
        self.mlp = Mlp(
            in_features=dim,
            hidden_features=int(dim * mlp_ratio),
            act_layer=act_layer,
            drop=drop,
        )

        attn_mask = self._build_attn_mask() if self.shift_size > 0 else None
        self.register_buffer("attn_mask", attn_mask)

    def _build_attn_mask(self):
        """Compute the SW-MSA mask: 0 within a region, -100 across regions."""
        H, W = self.input_resolution
        img_mask = torch.zeros((1, H, W, 1))  # 1 H W 1
        # The same three bands are used along height and width.
        bands = (
            slice(0, -self.window_size),
            slice(-self.window_size, -self.shift_size),
            slice(-self.shift_size, None),
        )
        region_id = 0
        for h_band in bands:
            for w_band in bands:
                img_mask[:, h_band, w_band, :] = region_id
                region_id += 1

        mask_windows = window_partition(
            img_mask, self.window_size
        )  # nW, window_size, window_size, 1
        mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
        # Non-zero difference <=> the two tokens belong to different regions.
        region_diff = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
        return region_diff.masked_fill(region_diff != 0, float(-100.0)).masked_fill(
            region_diff == 0, float(0.0)
        )

    def forward(self, x):
        """Run one block on ``x`` of shape (B, H*W, C); return ``(x, attn)``."""
        H, W = self.input_resolution
        B, L, C = x.shape
        # assert L == H * W, "input feature has wrong size"
        shift = self.shift_size
        win = self.window_size

        shortcut = x
        x = self.norm1(x).view(B, H, W, C)

        # cyclic shift
        shifted_x = (
            torch.roll(x, shifts=(-shift, -shift), dims=(1, 2)) if shift > 0 else x
        )

        # partition windows
        x_windows = window_partition(shifted_x, win)  # nW*B, window_size, window_size, C
        x_windows = x_windows.view(-1, win * win, C)  # nW*B, window_size*window_size, C

        # W-MSA/SW-MSA
        attn_windows, attn = self.attn(
            x_windows, mask=self.attn_mask
        )  # nW*B, window_size*window_size, C

        # merge windows
        attn_windows = attn_windows.view(-1, win, win, C)
        shifted_x = window_reverse(attn_windows, win, H, W)  # B H' W' C

        # reverse cyclic shift
        if shift > 0:
            x = torch.roll(shifted_x, shifts=(shift, shift), dims=(1, 2))
        else:
            x = shifted_x
        x = x.view(B, H * W, C)

        # residual connections around attention and the FFN
        x = shortcut + self.drop_path(x)
        x = x + self.drop_path(self.mlp(self.norm2(x)))

        return x, attn

    def extra_repr(self):
        return (
            f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, "
            f"window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}"
        )
637
-
638
-
639
class PatchMerging(nn.Module):
    r"""Patch Merging Layer: merge each 2x2 patch neighbourhood, then halve 4C to 2C.

    Args:
        input_resolution (tuple[int]): Resolution of input feature.
        dim (int): Number of input channels.
        norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
    """

    def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm):
        super().__init__()
        self.input_resolution = input_resolution
        self.dim = dim
        self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
        self.norm = norm_layer(4 * dim)

    def forward(self, x):
        """Merge patches.

        Args:
            x: B, H*W, C

        Returns:
            Tensor of shape B, H/2*W/2, 2*C.
        """
        H, W = self.input_resolution
        B, L, C = x.shape
        assert L == H * W, "input feature has wrong size"
        assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) are not even."

        grid = x.view(B, H, W, C)

        # The four members of every 2x2 neighbourhood, each B H/2 W/2 C, in
        # the order (row, col) = (0,0), (1,0), (0,1), (1,1).
        offsets = ((0, 0), (1, 0), (0, 1), (1, 1))
        neighbours = [grid[:, r::2, c::2, :] for r, c in offsets]
        merged = torch.cat(neighbours, -1)  # B H/2 W/2 4*C
        merged = merged.view(B, -1, 4 * C)  # B H/2*W/2 4*C

        return self.reduction(self.norm(merged))

    def extra_repr(self):
        return f"input_resolution={self.input_resolution}, dim={self.dim}"
679
-
680
-
681
class BasicLayer(nn.Module):
    """A basic Swin Transformer layer for one stage.

    Args:
        dim (int): Number of input channels.
        input_resolution (tuple[int]): Input resolution.
        depth (int): Number of blocks.
        num_heads (int): Number of attention heads.
        window_size (int): Local window size.
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
        drop (float, optional): Dropout rate. Default: 0.0
        attn_drop (float, optional): Attention dropout rate. Default: 0.0
        drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
        norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
        downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
        use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
        norm_before_mlp (str, optional): Forwarded to every block. Default: "ln"
    """

    def __init__(
        self,
        dim,
        input_resolution,
        depth,
        num_heads,
        window_size,
        mlp_ratio=4.0,
        qkv_bias=True,
        qk_scale=None,
        drop=0.0,
        attn_drop=0.0,
        drop_path=0.0,
        norm_layer=nn.LayerNorm,
        downsample=None,
        use_checkpoint=False,
        norm_before_mlp="ln",
    ):
        super().__init__()
        self.dim = dim
        self.input_resolution = input_resolution
        self.depth = depth
        self.use_checkpoint = use_checkpoint

        # build blocks: even blocks use regular windows, odd blocks use
        # windows shifted by half the window size
        self.blocks = nn.ModuleList(
            [
                SwinTransformerBlock(
                    dim=dim,
                    input_resolution=input_resolution,
                    num_heads=num_heads,
                    window_size=window_size,
                    shift_size=0 if (i % 2 == 0) else window_size // 2,
                    mlp_ratio=mlp_ratio,
                    qkv_bias=qkv_bias,
                    qk_scale=qk_scale,
                    drop=drop,
                    attn_drop=attn_drop,
                    drop_path=drop_path[i]
                    if isinstance(drop_path, list)
                    else drop_path,
                    norm_layer=norm_layer,
                    norm_before_mlp=norm_before_mlp,
                )
                for i in range(depth)
            ]
        )

        # patch merging layer
        if downsample is not None:
            self.downsample = downsample(
                input_resolution, dim=dim, norm_layer=norm_layer
            )
        else:
            self.downsample = None

    def forward(self, x):
        """Run all blocks, then the optional downsample layer.

        Returns:
            ``(x, attn)``.  In training mode ``attn`` is the attention of the
            last block; in eval mode it is the mean over all blocks.  It is
            ``None`` when the layer has no blocks.
        """
        attns = []
        attn = None
        for blk in self.blocks:
            if self.use_checkpoint:
                # Each block returns ``(x, attn)``; unpack it here as well so
                # the next block receives a tensor rather than a tuple.
                # ``use_reentrant=True`` spells out the historical default.
                x, attn = checkpoint.checkpoint(blk, x, use_reentrant=True)
            else:
                x, attn = blk(x)
            if not self.training:
                attns.append(attn.unsqueeze(0))
        if self.downsample is not None:
            x = self.downsample(x)
        if not self.training and attns:
            # Average the attention maps of all blocks.
            attn = torch.cat(attns, dim=0)
            attn = torch.mean(attn, dim=0)
        return x, attn

    def extra_repr(self):
        return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}"
774
-
775
-
776
- # The Core of HTSAT
777
class HTSAT_Swin_Transformer(nn.Module):
    r"""HTSAT based on the Swin Transformer.

    Args:
        spec_size (int): Side length of the square input the spectrogram is
            folded into. Default: 256
        patch_size (int | tuple(int)): Patch size. Default: 4
        patch_stride (tuple(int)): Patch stride for the frequency and time
            axes. Default: (4, 4)
        in_chans (int): Number of input image channels. Default: 1 (mono)
        num_classes (int): Number of classes for the classification head.
            Default: 527
        embed_dim (int): Patch embedding dimension. Default: 96
        depths (sequence of int): Depth of each HTSAT-Swin Transformer layer.
        num_heads (sequence of int): Number of attention heads per layer.
        window_size (int): Window size. Default: 8
        mlp_ratio (float): Ratio of MLP hidden dim to embedding dim.
            Default: 4
        qkv_bias (bool): If True, add a learnable bias to query, key, value.
            Default: True
        qk_scale (float): Override the default qk scale of head_dim ** -0.5
            if set. Default: None
        drop_rate (float): Dropout rate. Default: 0
        attn_drop_rate (float): Attention dropout rate. Default: 0
        drop_path_rate (float): Stochastic depth rate. Default: 0.1
        norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.
        ape (bool): If True, add absolute position embedding to the patch
            embedding. Default: False
        patch_norm (bool): If True, add normalization after patch embedding.
            Default: True
        use_checkpoint (bool): Whether to use checkpointing to save memory.
            Default: False
        norm_before_mlp (str): Forwarded unchanged to every layer.
            Default: "ln"
        config: Audio configuration object. ``mel_bins``, ``window_size``,
            ``hop_size``, ``sample_rate``, ``fmin`` and ``fmax`` are read
            from it. Required.
        enable_fusion (bool): If True, ``forward`` consumes the
            ``"mel_fusion"`` / ``"longer"`` inputs instead of ``"waveform"``.
        fusion_type (str): Fusion variant; the ``*_1d`` variants build an
            extra 1D convolution and fusion module.

    Raises:
        ValueError: If ``config`` is None.
    """

    def __init__(
        self,
        spec_size=256,
        patch_size=4,
        patch_stride=(4, 4),
        in_chans=1,
        num_classes=527,
        embed_dim=96,
        depths=(2, 2, 6, 2),
        num_heads=(4, 8, 16, 32),
        window_size=8,
        mlp_ratio=4.0,
        qkv_bias=True,
        qk_scale=None,
        drop_rate=0.0,
        attn_drop_rate=0.0,
        drop_path_rate=0.1,
        norm_layer=nn.LayerNorm,
        ape=False,
        patch_norm=True,
        use_checkpoint=False,
        norm_before_mlp="ln",
        config=None,
        enable_fusion=False,
        fusion_type="None",
        **kwargs,
    ):
        super().__init__()

        if config is None:
            raise ValueError(
                "HTSAT_Swin_Transformer requires an audio `config` "
                "(mel_bins, window_size, hop_size, sample_rate, fmin, fmax)."
            )

        self.config = config
        self.spec_size = spec_size
        self.patch_stride = patch_stride
        self.patch_size = patch_size
        self.window_size = window_size
        self.embed_dim = embed_dim
        # Copy so that the instance never aliases a caller-owned sequence.
        self.depths = list(depths)
        self.ape = ape
        self.in_chans = in_chans
        self.num_classes = num_classes
        self.num_heads = list(num_heads)
        self.num_layers = len(self.depths)
        self.num_features = int(self.embed_dim * 2 ** (self.num_layers - 1))

        self.drop_rate = drop_rate
        self.attn_drop_rate = attn_drop_rate
        self.drop_path_rate = drop_path_rate

        self.qkv_bias = qkv_bias
        # Honour the caller's override (previously always reset to None).
        self.qk_scale = qk_scale

        self.patch_norm = patch_norm
        # The transformer layers and the final norm always need a norm layer;
        # `patch_norm` only controls the patch-embedding norm (see below).
        self.norm_layer = norm_layer
        self.norm_before_mlp = norm_before_mlp
        self.mlp_ratio = mlp_ratio

        self.use_checkpoint = use_checkpoint

        self.enable_fusion = enable_fusion
        self.fusion_type = fusion_type

        # process mel-spec ; used only once
        self.freq_ratio = self.spec_size // self.config.mel_bins
        self.interpolate_ratio = 32  # Downsampled ratio
        # Spectrogram extractor
        self.spectrogram_extractor = Spectrogram(
            n_fft=config.window_size,
            hop_length=config.hop_size,
            win_length=config.window_size,
            window="hann",
            center=True,
            pad_mode="reflect",
            freeze_parameters=True,
        )
        # Logmel feature extractor
        self.logmel_extractor = LogmelFilterBank(
            sr=config.sample_rate,
            n_fft=config.window_size,
            n_mels=config.mel_bins,
            fmin=config.fmin,
            fmax=config.fmax,
            ref=1.0,
            amin=1e-10,
            top_db=None,
            freeze_parameters=True,
        )
        # Spec augmenter
        self.spec_augmenter = SpecAugmentation(
            time_drop_width=64,
            time_stripes_num=2,
            freq_drop_width=8,
            freq_stripes_num=2,
        )  # 2 2
        self.bn0 = nn.BatchNorm2d(self.config.mel_bins)

        # split spectrogram into non-overlapping patches
        self.patch_embed = PatchEmbed(
            img_size=self.spec_size,
            patch_size=self.patch_size,
            in_chans=self.in_chans,
            embed_dim=self.embed_dim,
            norm_layer=self.norm_layer if self.patch_norm else None,
            patch_stride=patch_stride,
            enable_fusion=self.enable_fusion,
            fusion_type=self.fusion_type,
        )

        num_patches = self.patch_embed.num_patches
        patches_resolution = self.patch_embed.grid_size
        self.patches_resolution = patches_resolution

        # absolute position embedding
        if self.ape:
            self.absolute_pos_embed = nn.Parameter(
                torch.zeros(1, num_patches, self.embed_dim)
            )
            trunc_normal_(self.absolute_pos_embed, std=0.02)

        self.pos_drop = nn.Dropout(p=self.drop_rate)

        # stochastic depth decay rule
        dpr = [
            x.item() for x in torch.linspace(0, self.drop_path_rate, sum(self.depths))
        ]

        # build layers
        self.layers = nn.ModuleList()
        for i_layer in range(self.num_layers):
            layer = BasicLayer(
                dim=int(self.embed_dim * 2**i_layer),
                input_resolution=(
                    patches_resolution[0] // (2**i_layer),
                    patches_resolution[1] // (2**i_layer),
                ),
                depth=self.depths[i_layer],
                num_heads=self.num_heads[i_layer],
                window_size=self.window_size,
                mlp_ratio=self.mlp_ratio,
                qkv_bias=self.qkv_bias,
                qk_scale=self.qk_scale,
                drop=self.drop_rate,
                attn_drop=self.attn_drop_rate,
                drop_path=dpr[
                    sum(self.depths[:i_layer]) : sum(self.depths[: i_layer + 1])
                ],
                norm_layer=self.norm_layer,
                downsample=PatchMerging if (i_layer < self.num_layers - 1) else None,
                use_checkpoint=use_checkpoint,
                norm_before_mlp=self.norm_before_mlp,
            )
            self.layers.append(layer)

        self.norm = self.norm_layer(self.num_features)
        self.avgpool = nn.AdaptiveAvgPool1d(1)
        self.maxpool = nn.AdaptiveMaxPool1d(1)

        SF = (
            self.spec_size
            // (2 ** (len(self.depths) - 1))
            // self.patch_stride[0]
            // self.freq_ratio
        )
        self.tscam_conv = nn.Conv2d(
            in_channels=self.num_features,
            out_channels=self.num_classes,
            kernel_size=(SF, 3),
            padding=(0, 1),
        )
        self.head = nn.Linear(num_classes, num_classes)

        if self.enable_fusion and self.fusion_type in ("daf_1d", "aff_1d", "iaff_1d"):
            self.mel_conv1d = nn.Sequential(
                nn.Conv1d(64, 64, kernel_size=5, stride=3, padding=2),
                nn.BatchNorm1d(64),
            )
            if self.fusion_type == "daf_1d":
                self.fusion_model = DAF()
            elif self.fusion_type == "aff_1d":
                self.fusion_model = AFF(channels=64, type="1D")
            elif self.fusion_type == "iaff_1d":
                self.fusion_model = iAFF(channels=64, type="1D")

        self.apply(self._init_weights)

    def _init_weights(self, m):
        """Initialise Linear weights (truncated normal) and LayerNorm affine terms."""
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=0.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    @torch.jit.ignore
    def no_weight_decay(self):
        """Return the parameter names excluded from weight decay."""
        return {"absolute_pos_embed"}

    @torch.jit.ignore
    def no_weight_decay_keywords(self):
        """Return the name keywords excluded from weight decay."""
        return {"relative_position_bias_table"}

    def forward_features(self, x, longer_idx=None):
        """Run patch embedding, the Swin layers and the token-semantic head.

        Args:
            x: Input already folded by ``reshape_wav2img``; ``x.shape[2]`` is
                used as the frame count.
            longer_idx: Forwarded to the patch embedding.

        Returns:
            dict with ``"framewise_output"`` and ``"clipwise_output"`` (both
            sigmoid-activated), ``"fine_grained_embedding"`` and
            ``"embedding"``.
        """
        # A deprecated optimization for using a hierarchical output from different blocks
        frames_num = x.shape[2]
        x = self.patch_embed(x, longer_idx=longer_idx)
        if self.ape:
            x = x + self.absolute_pos_embed
        x = self.pos_drop(x)
        for layer in self.layers:
            # Each layer also returns an attention map, which is not used here.
            x, _ = layer(x)

        x = self.norm(x)
        B, N, C = x.shape
        downsample = 2 ** (len(self.depths) - 1)
        SF = frames_num // downsample // self.patch_stride[0]
        ST = frames_num // downsample // self.patch_stride[1]
        x = x.permute(0, 2, 1).contiguous().reshape(B, C, SF, ST)
        B, C, freq_bins, T = x.shape
        # group 2D CNN
        c_freq_bin = freq_bins // self.freq_ratio
        x = x.reshape(B, C, freq_bins // c_freq_bin, c_freq_bin, T)
        x = x.permute(0, 1, 3, 2, 4).contiguous().reshape(B, C, c_freq_bin, -1)

        # get latent_output
        fine_grained_latent_output = torch.mean(x, dim=2)
        fine_grained_latent_output = interpolate(
            fine_grained_latent_output.permute(0, 2, 1).contiguous(),
            8 * self.patch_stride[1],
        )

        latent_output = self.avgpool(torch.flatten(x, 2))
        latent_output = torch.flatten(latent_output, 1)

        x = self.tscam_conv(x)
        x = torch.flatten(x, 2)  # B, C, T

        fpx = interpolate(
            torch.sigmoid(x).permute(0, 2, 1).contiguous(), 8 * self.patch_stride[1]
        )

        x = self.avgpool(x)
        x = torch.flatten(x, 1)

        return {
            "framewise_output": fpx,  # already sigmoided
            "clipwise_output": torch.sigmoid(x),
            "fine_grained_embedding": fine_grained_latent_output,
            "embedding": latent_output,
        }

    def crop_wav(self, x, crop_size, spe_pos=None):
        """Crop ``crop_size`` steps along dim 2 of channel 0 for every item.

        Args:
            x: 4-D tensor; dim 2 is cropped.
            crop_size: Number of steps to keep.
            spe_pos: Fixed start position; when None a random start is drawn
                independently for every batch item.

        Returns:
            Tensor shaped ``(x.shape[0], x.shape[1], crop_size, x.shape[3])``;
            only channel 0 is filled, other channels stay zero.
        """
        time_steps = x.shape[2]
        tx = torch.zeros(x.shape[0], x.shape[1], crop_size, x.shape[3]).to(x.device)
        for i in range(len(x)):
            if spe_pos is None:
                # randint is inclusive on both ends, so the last valid start
                # is time_steps - crop_size (also valid when they are equal).
                crop_pos = random.randint(0, time_steps - crop_size)
            else:
                crop_pos = spe_pos
            tx[i][0] = x[i, 0, crop_pos : crop_pos + crop_size, :]
        return tx

    def _resize_to_swin_input(self, x):
        """Bicubically upsample ``x`` (B, C, T, F) to the target T and F sizes."""
        _, _, T, F = x.shape
        target_T = int(self.spec_size * self.freq_ratio)
        target_F = self.spec_size // self.freq_ratio
        assert (
            T <= target_T and F <= target_F
        ), "the wav size should less than or equal to the swin input size"
        # to avoid bicubic zero error
        if T < target_T:
            x = nn.functional.interpolate(
                x, (target_T, x.shape[3]), mode="bicubic", align_corners=True
            )
        if F < target_F:
            x = nn.functional.interpolate(
                x, (x.shape[2], target_F), mode="bicubic", align_corners=True
            )
        return x

    # Reshape the waveform to an img size, if you want to use the pretrained swin transformer model
    def reshape_wav2img(self, x):
        """Fold a (B, C, T, F) spectrogram into a square (B, C, S, S) input.

        The time axis is split into ``freq_ratio`` consecutive segments which
        are stacked along the frequency axis.
        """
        x = self._resize_to_swin_input(x)
        x = x.permute(0, 1, 3, 2).contiguous()  # B C F T
        x = x.reshape(
            x.shape[0],
            x.shape[1],
            x.shape[2],
            self.freq_ratio,
            x.shape[3] // self.freq_ratio,
        )
        x = x.permute(0, 1, 3, 2, 4).contiguous()
        x = x.reshape(x.shape[0], x.shape[1], x.shape[2] * x.shape[3], x.shape[4])
        return x

    # Repeat the waveform to an img size, if you want to use the pretrained swin transformer model
    def repeat_wat2img(self, x, cur_pos):
        """Take a ``spec_size``-wide time window at ``cur_pos`` and tile it 4x along frequency."""
        x = self._resize_to_swin_input(x)
        x = x.permute(0, 1, 3, 2).contiguous()  # B C F T
        x = x[:, :, :, cur_pos : cur_pos + self.spec_size]
        # NOTE(review): the repeat count is hard-coded to 4; presumably it
        # should follow self.freq_ratio -- confirm before changing.
        x = x.repeat(repeats=(1, 1, 4, 1))
        return x

    def _fuse_longer_clips_1d(self, x, longer_list_idx):
        """Fuse channel 0 with the remaining channels for the ``longer`` items."""
        new_x = x[:, 0:1, :, :].clone().contiguous()
        if len(longer_list_idx) == 0:
            return new_x

        # local processing
        fusion_x_local = x[longer_list_idx, 1:, :, :].clone().contiguous()
        FB, FC, FT, FF = fusion_x_local.size()
        fusion_x_local = fusion_x_local.view(FB * FC, FT, FF)
        fusion_x_local = torch.permute(fusion_x_local, (0, 2, 1)).contiguous()
        fusion_x_local = self.mel_conv1d(fusion_x_local)
        fusion_x_local = fusion_x_local.view(FB, FC, FF, fusion_x_local.size(-1))
        fusion_x_local = (
            torch.permute(fusion_x_local, (0, 2, 1, 3)).contiguous().flatten(2)
        )
        if fusion_x_local.size(-1) < FT:
            # Allocate the padding next to the data it is concatenated with,
            # so this also works when `device` was not passed to forward().
            padding = fusion_x_local.new_zeros((FB, FF, FT - fusion_x_local.size(-1)))
            fusion_x_local = torch.cat([fusion_x_local, padding], dim=-1)
        else:
            fusion_x_local = fusion_x_local[:, :, :FT]

        # 1D fusion
        new_x = new_x.squeeze(1).permute((0, 2, 1)).contiguous()
        new_x[longer_list_idx] = self.fusion_model(
            new_x[longer_list_idx], fusion_x_local
        )
        return new_x.permute((0, 2, 1)).contiguous()[:, None, :, :]

    def forward(self, x: dict, mixup_lambda=None, infer_mode=False, device=None):
        """Encode a batch.

        Args:
            x: Mapping of inputs. Without fusion ``x["waveform"]`` is read;
                with fusion ``x["mel_fusion"]`` and ``x["longer"]`` are read.
                When fusion is enabled and no entry of ``x["longer"]`` is
                set, one random entry is set to True in place.
            mixup_lambda: If given, mixup is applied in training mode.
            infer_mode: Unused; kept for interface compatibility.
            device: Device the inputs are moved to.

        Returns:
            The dict produced by ``forward_features``.
        """
        if self.enable_fusion and x["longer"].sum() == 0:
            # if no audio is longer than 10s, then randomly select one audio to be longer
            x["longer"][torch.randint(0, x["longer"].shape[0], (1,))] = True

        if not self.enable_fusion:
            x = x["waveform"].to(device=device, non_blocking=True)
            x = self.spectrogram_extractor(x)  # (batch_size, 1, time_steps, freq_bins)
            x = self.logmel_extractor(x)  # (batch_size, 1, time_steps, mel_bins)
            x = x.transpose(1, 3)
            x = self.bn0(x)
            x = x.transpose(1, 3)
            if self.training:
                x = self.spec_augmenter(x)
            if self.training and mixup_lambda is not None:
                x = do_mixup(x, mixup_lambda)

            x = self.reshape_wav2img(x)
            output_dict = self.forward_features(x)
        else:
            longer_list = x["longer"].to(device=device, non_blocking=True)
            x = x["mel_fusion"].to(device=device, non_blocking=True)
            x = x.transpose(1, 3)
            x = self.bn0(x)
            x = x.transpose(1, 3)
            longer_list_idx = torch.where(longer_list)[0]
            if self.fusion_type in ("daf_1d", "aff_1d", "iaff_1d"):
                x = self._fuse_longer_clips_1d(x, longer_list_idx)
            # The 2D fusion types ("daf_2d", "aff_2d", "iaff_2d",
            # "channel_map") leave x unchanged here.

            if self.training:
                x = self.spec_augmenter(x)
            if self.training and mixup_lambda is not None:
                x = do_mixup(x, mixup_lambda)

            x = self.reshape_wav2img(x)
            output_dict = self.forward_features(x, longer_idx=longer_list_idx)

        # We process the data in the dataloader part, in that here we only consider the input_T < fixed_T
        return output_dict
1248
-
1249
-
1250
def create_htsat_model(audio_cfg, enable_fusion=False, fusion_type="None"):
    """Build the HTS-AT variant selected by ``audio_cfg.model_name``.

    Args:
        audio_cfg: Audio configuration. ``model_name`` (one of ``"tiny"``,
            ``"base"``, ``"large"``) and ``class_num`` are read here; the
            whole object is also passed to the model as ``config``.
        enable_fusion: Forwarded to ``HTSAT_Swin_Transformer``.
        fusion_type: Forwarded to ``HTSAT_Swin_Transformer``.

    Returns:
        The constructed ``HTSAT_Swin_Transformer``.

    Raises:
        RuntimeError: If the model name is unknown or construction fails;
            the underlying error is attached as ``__cause__``.
    """
    # The variants differ only in embedding width and per-stage depth.
    variants = {
        "tiny": {"embed_dim": 96, "depths": [2, 2, 6, 2]},
        "base": {"embed_dim": 128, "depths": [2, 2, 12, 2]},
        "large": {"embed_dim": 256, "depths": [2, 2, 12, 2]},
    }
    try:
        model_name = audio_cfg.model_name
        if model_name not in variants:
            raise ValueError("model name for HTS-AT is wrong!")
        return HTSAT_Swin_Transformer(
            spec_size=256,
            patch_size=4,
            patch_stride=(4, 4),
            num_classes=audio_cfg.class_num,
            num_heads=[4, 8, 16, 32],
            window_size=8,
            config=audio_cfg,
            enable_fusion=enable_fusion,
            fusion_type=fusion_type,
            **variants[model_name],
        )
    except Exception as err:
        # getattr keeps the handler itself from failing when the config has
        # no model_name at all; chaining preserves the real cause.
        raise RuntimeError(
            f"Import Model for {getattr(audio_cfg, 'model_name', None)} not found, or the audio cfg parameters are not enough."
        ) from err
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
audioldm2/clap/open_clip/loss.py DELETED
@@ -1,397 +0,0 @@
1
- import torch
2
- import torch.distributed.nn
3
- from torch import distributed as dist, nn as nn
4
- from torch.nn import functional as F
5
- import numpy as np
6
- from sklearn.metrics import average_precision_score, roc_auc_score, accuracy_score
7
-
8
- try:
9
- import horovod.torch as hvd
10
- except ImportError:
11
- hvd = None
12
-
13
-
14
def gather_features(
    audio_features,
    text_features,
    audio_features_mlp=None,
    text_features_mlp=None,
    local_loss=False,
    gather_with_grad=False,
    rank=0,
    world_size=1,
    use_horovod=False,
    mlp_loss=False,
):
    """Collect audio/text features from every worker.

    The tensors are gathered in the fixed order audio, text and, when
    ``mlp_loss`` is set, audio-MLP, text-MLP. When the gather does not carry
    gradients and ``local_loss`` is False, this rank's slice is swapped back
    for the local tensor so gradients still flow through it.

    Returns:
        ``(all_audio, all_text)``, or the 4-tuple that additionally holds the
        gathered MLP features when ``mlp_loss`` is True.
    """
    local_tensors = [audio_features, text_features]
    if mlp_loss:
        local_tensors += [audio_features_mlp, text_features_mlp]

    if use_horovod:
        assert hvd is not None, "Please install horovod"
        if gather_with_grad:
            gathered = [hvd.allgather(local) for local in local_tensors]
        else:
            with torch.no_grad():
                gathered = [hvd.allgather(local) for local in local_tensors]
            if not local_loss:
                # ensure grads for local rank when all_* features don't have a gradient
                patched = []
                for full, local in zip(gathered, local_tensors):
                    pieces = list(full.chunk(world_size, dim=0))
                    pieces[rank] = local
                    patched.append(torch.cat(pieces, dim=0))
                gathered = patched
    elif gather_with_grad:
        # We gather tensors from all gpus
        gathered = [
            torch.cat(torch.distributed.nn.all_gather(local), dim=0)
            for local in local_tensors
        ]
    else:
        gathered = []
        for local in local_tensors:
            buffers = [torch.zeros_like(local) for _ in range(world_size)]
            dist.all_gather(buffers, local)
            if not local_loss:
                # ensure grads for local rank when all_* features don't have a gradient
                buffers[rank] = local
            gathered.append(torch.cat(buffers, dim=0))

    return tuple(gathered)
122
-
123
-
124
class ClipLoss(nn.Module):
    """Symmetric contrastive loss between audio and text features.

    Args:
        local_loss: With several workers, score only the local rows against
            the gathered columns instead of the full gathered matrix.
        gather_with_grad: Forwarded to ``gather_features``.
        cache_labels: Cache the target indices per device.
        rank: Rank of this worker.
        world_size: Number of workers; features are gathered when > 1.
        use_horovod: Forwarded to ``gather_features``.
        mlp_loss: Use the four-term loss that pairs raw features with the
            MLP-projected features of the other modality.
        weight_loss_kappa: When non-zero, class weights derived from the
            feature self-similarity are passed to the cross-entropy terms.
    """

    def __init__(
        self,
        local_loss=False,
        gather_with_grad=False,
        cache_labels=False,
        rank=0,
        world_size=1,
        use_horovod=False,
        mlp_loss=False,
        weight_loss_kappa=0,
    ):
        super().__init__()
        self.local_loss = local_loss
        self.gather_with_grad = gather_with_grad
        self.cache_labels = cache_labels
        self.rank = rank
        self.world_size = world_size
        self.use_horovod = use_horovod
        self.mlp_loss = mlp_loss
        self.weighted_loss = bool(weight_loss_kappa != 0)
        self.weight_loss_kappa = weight_loss_kappa
        # cache state
        self.prev_num_logits = 0
        self.labels = {}

    def _get_labels(self, num_logits, device):
        """Return the target indices, reusing the per-device cache if enabled."""
        if self.prev_num_logits != num_logits or device not in self.labels:
            labels = torch.arange(num_logits, device=device, dtype=torch.long)
            if self.world_size > 1 and self.local_loss:
                # Local rows are matched against the gathered columns.
                labels = labels + num_logits * self.rank
            if self.cache_labels:
                self.labels[device] = labels
                self.prev_num_logits = num_logits
        else:
            labels = self.labels[device]
        return labels

    def _similarity_weights(self, features):
        """Return exp(row-sum of ``features @ features.T`` / (kappa * N)), detached."""
        similarity = (features @ features.T).detach()
        return torch.exp(
            torch.sum(similarity, dim=1) / (self.weight_loss_kappa * len(features))
        ).detach()

    def forward(
        self,
        audio_features,
        text_features,
        logit_scale_a,
        logit_scale_t=None,
        audio_features_mlp=None,
        text_features_mlp=None,
    ):
        """Compute the loss.

        ``logit_scale_t``, ``audio_features_mlp`` and ``text_features_mlp``
        are only used when ``mlp_loss`` is enabled.
        """
        device = audio_features.device
        distributed = self.world_size > 1

        if self.mlp_loss:
            if distributed:
                (
                    all_audio_features,
                    all_text_features,
                    all_audio_features_mlp,
                    all_text_features_mlp,
                ) = gather_features(
                    audio_features=audio_features,
                    text_features=text_features,
                    audio_features_mlp=audio_features_mlp,
                    text_features_mlp=text_features_mlp,
                    local_loss=self.local_loss,
                    gather_with_grad=self.gather_with_grad,
                    rank=self.rank,
                    world_size=self.world_size,
                    use_horovod=self.use_horovod,
                    mlp_loss=self.mlp_loss,
                )
                if self.local_loss:
                    a_logits_per_audio = (
                        logit_scale_a * audio_features @ all_text_features_mlp.T
                    )
                    a_logits_per_text = (
                        logit_scale_a * text_features_mlp @ all_audio_features.T
                    )
                    t_logits_per_audio = (
                        logit_scale_t * audio_features_mlp @ all_text_features.T
                    )
                    t_logits_per_text = (
                        logit_scale_t * text_features @ all_audio_features_mlp.T
                    )
                else:
                    a_logits_per_audio = (
                        logit_scale_a * all_audio_features @ all_text_features_mlp.T
                    )
                    a_logits_per_text = a_logits_per_audio.T
                    t_logits_per_audio = (
                        logit_scale_t * all_audio_features_mlp @ all_text_features.T
                    )
                    t_logits_per_text = t_logits_per_audio.T
            else:
                all_audio_features = audio_features
                all_text_features = text_features
                a_logits_per_audio = (
                    logit_scale_a * audio_features @ text_features_mlp.T
                )
                a_logits_per_text = logit_scale_a * text_features_mlp @ audio_features.T
                t_logits_per_audio = (
                    logit_scale_t * audio_features_mlp @ text_features.T
                )
                t_logits_per_text = logit_scale_t * text_features @ audio_features_mlp.T

            # calculated ground-truth and cache if enabled
            labels = self._get_labels(a_logits_per_audio.shape[0], device)

            if not self.weighted_loss:
                total_loss = (
                    F.cross_entropy(a_logits_per_audio, labels)
                    + F.cross_entropy(a_logits_per_text, labels)
                    + F.cross_entropy(t_logits_per_audio, labels)
                    + F.cross_entropy(t_logits_per_text, labels)
                ) / 4
            else:
                # One weight per logit column: derive the weights from the
                # gathered features (identical to the local ones on a single
                # worker) so their length matches the number of classes.
                audio_weight = self._similarity_weights(all_audio_features)
                text_weight = self._similarity_weights(all_text_features)
                total_loss = (
                    F.cross_entropy(a_logits_per_audio, labels, weight=audio_weight)
                    + F.cross_entropy(a_logits_per_text, labels, weight=audio_weight)
                    + F.cross_entropy(t_logits_per_audio, labels, weight=text_weight)
                    + F.cross_entropy(t_logits_per_text, labels, weight=text_weight)
                ) / 4
            return total_loss

        if distributed:
            all_audio_features, all_text_features = gather_features(
                audio_features=audio_features,
                text_features=text_features,
                local_loss=self.local_loss,
                gather_with_grad=self.gather_with_grad,
                rank=self.rank,
                world_size=self.world_size,
                use_horovod=self.use_horovod,
                mlp_loss=self.mlp_loss,
            )
            if self.local_loss:
                logits_per_audio = logit_scale_a * audio_features @ all_text_features.T
                logits_per_text = logit_scale_a * text_features @ all_audio_features.T
            else:
                logits_per_audio = (
                    logit_scale_a * all_audio_features @ all_text_features.T
                )
                logits_per_text = logits_per_audio.T
        else:
            # Single worker: the "gathered" features are the local ones. They
            # must be bound here because the weighted loss below reads them.
            all_audio_features = audio_features
            all_text_features = text_features
            logits_per_audio = logit_scale_a * audio_features @ text_features.T
            logits_per_text = logit_scale_a * text_features @ audio_features.T

        # calculated ground-truth and cache if enabled
        labels = self._get_labels(logits_per_audio.shape[0], device)

        if not self.weighted_loss:
            total_loss = (
                F.cross_entropy(logits_per_audio, labels)
                + F.cross_entropy(logits_per_text, labels)
            ) / 2
        else:
            audio_weight = self._similarity_weights(all_audio_features)
            text_weight = self._similarity_weights(all_text_features)
            total_loss = (
                F.cross_entropy(logits_per_audio, labels, weight=text_weight)
                + F.cross_entropy(logits_per_text, labels, weight=audio_weight)
            ) / 2
        return total_loss
316
-
317
-
318
def lp_gather_features(pred, target, world_size=1, use_horovod=False):
    """Gather predictions and targets from every worker (no gradients).

    Args:
        pred: Local prediction tensor.
        target: Local target tensor.
        world_size: Number of workers taking part in the gather.
        use_horovod: Gather through horovod instead of ``torch.distributed``.

    Returns:
        ``(all_preds, all_targets)`` concatenated along dim 0.
    """
    if use_horovod:
        assert hvd is not None, "Please install horovod"
        with torch.no_grad():
            all_preds = hvd.allgather(pred)
            # Was `hvd.allgath`, which raised AttributeError on this path.
            all_targets = hvd.allgather(target)
    else:
        gathered_preds = [torch.zeros_like(pred) for _ in range(world_size)]
        gathered_targets = [torch.zeros_like(target) for _ in range(world_size)]

        dist.all_gather(gathered_preds, pred)
        dist.all_gather(gathered_targets, target)
        all_preds = torch.cat(gathered_preds, dim=0)
        all_targets = torch.cat(gathered_targets, dim=0)

    return all_preds, all_targets
334
-
335
-
336
def get_map(pred, target):
    """Return the mean of the per-class average precision.

    ``pred`` holds logits (a sigmoid is applied here); ``target`` holds the
    ground-truth labels passed to ``average_precision_score``.
    """
    # detach/cpu first: Tensor.numpy() refuses CUDA and grad-tracking tensors.
    pred = torch.sigmoid(pred).detach().cpu().numpy()
    target = target.detach().cpu().numpy()
    return np.mean(average_precision_score(target, pred, average=None))
340
-
341
-
342
def get_acc(pred, target):
    """Return the accuracy of the argmax of ``pred`` against the argmax of ``target``.

    Both inputs are reduced along dim 1 before being compared.
    """
    # detach/cpu first: Tensor.numpy() refuses CUDA and grad-tracking tensors.
    pred = torch.argmax(pred, 1).detach().cpu().numpy()
    target = torch.argmax(target, 1).detach().cpu().numpy()
    return accuracy_score(target, pred)
346
-
347
-
348
def get_mauc(pred, target):
    """Return the mean of the per-class ROC AUC.

    ``pred`` holds logits (a sigmoid is applied here); ``target`` holds the
    ground-truth labels passed to ``roc_auc_score``.
    """
    # detach/cpu first: Tensor.numpy() refuses CUDA and grad-tracking tensors.
    pred = torch.sigmoid(pred).detach().cpu().numpy()
    target = target.detach().cpu().numpy()
    return np.mean(roc_auc_score(target, pred, average=None))
352
-
353
-
354
class LPMetrics(object):
    """Evaluate a configurable set of metrics on predictions and targets.

    Args:
        metric_names: Iterable of metric names; each must be one of
            ``"map"``, ``"acc"`` or ``"mauc"``.

    Raises:
        ValueError: If a metric name is not recognised.
    """

    def __init__(self, metric_names=("map", "acc", "mauc")):
        # Copy the names so neither the default nor a caller-owned list is
        # shared with (and mutable through) this instance.
        self.metric_names = list(metric_names)
        self.metrics = [self.get_metric(name) for name in self.metric_names]

    def get_metric(self, name):
        """Return the metric function registered under ``name``."""
        if name == "map":
            return get_map
        elif name == "acc":
            return get_acc
        elif name == "mauc":
            return get_mauc
        else:
            raise ValueError("the metric should be at least one of [map, acc, mauc]")

    def evaluate_metrics(self, pred, target):
        """Return ``{metric name: metric(pred, target)}`` for every configured metric."""
        return {
            name: metric(pred, target)
            for name, metric in zip(self.metric_names, self.metrics)
        }

    # Historical (misspelled) name, kept so existing callers keep working.
    evaluate_mertics = evaluate_metrics
376
-
377
-
378
def calc_celoss(pred, target):
    """Cross-entropy of ``pred`` against the argmax (along dim 1) of ``target``."""
    class_indices = target.argmax(dim=1).long()
    criterion = nn.CrossEntropyLoss()
    return criterion(pred, class_indices)
381
-
382
-
383
class LPLoss(nn.Module):
    """Loss wrapper selected by name: ``"bce"``, ``"ce"`` or ``"mse"``.

    Raises:
        ValueError: If ``loss_name`` is none of the supported names.
    """

    def __init__(self, loss_name):
        super().__init__()
        # Factories are resolved lazily so only the chosen loss is built.
        factories = {
            "bce": nn.BCEWithLogitsLoss,
            "ce": lambda: calc_celoss,
            "mse": nn.MSELoss,
        }
        if loss_name not in factories:
            raise ValueError("the loss func should be at least one of [bce, ce, mse]")
        self.loss_func = factories[loss_name]()

    def forward(self, pred, target):
        """Apply the selected loss to ``pred`` and ``target``."""
        return self.loss_func(pred, target)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
audioldm2/clap/open_clip/model.py DELETED
@@ -1,931 +0,0 @@
1
- """ CLAP Model
2
-
3
- Adapted from CLIP: https://github.com/openai/CLIP. Originally MIT License, Copyright (c) 2021 OpenAI.
4
- Adapted to the Audio Task.
5
- """
6
-
7
- from collections import OrderedDict
8
- from dataclasses import dataclass
9
- from typing import Tuple, Union, Callable, Optional
10
-
11
- import numpy as np
12
- import torch
13
- import torch.nn.functional as F
14
- from torch import nn
15
-
16
- import logging
17
- from .utils import freeze_batch_norm_2d
18
-
19
- from .pann_model import create_pann_model
20
- from .htsat import create_htsat_model
21
- from transformers import BertModel, RobertaModel, BartModel, RobertaConfig
22
-
23
-
24
class MLPLayers(nn.Module):
    """Stack of ``Linear -> nonlinearity -> Dropout`` stages.

    One ``nn.Linear`` is created for every consecutive pair in ``units``.  The
    nonlinearity and dropout that would follow the last linear layer are
    dropped, so the network ends on a plain linear output.

    Args:
        units: layer widths, input width first.  Defaults to three 512-wide
            units (two linear layers).
        nonlin: activation module placed after every linear layer except the
            last.  ``None`` (the default) creates a fresh ``nn.ReLU`` for this
            instance instead of sharing one object across all instances.
        dropout: dropout probability used after each activation.
    """

    def __init__(self, units=(512, 512, 512), nonlin=None, dropout=0.1):
        super(MLPLayers, self).__init__()
        # Build the default activation per instance: a default evaluated in the
        # signature would be one shared module object for every MLPLayers.
        self.nonlin = nn.ReLU() if nonlin is None else nonlin
        self.dropout = dropout

        sequence = []
        for u0, u1 in zip(units[:-1], units[1:]):
            sequence.append(nn.Linear(u0, u1))
            sequence.append(self.nonlin)
            sequence.append(nn.Dropout(self.dropout))
        # Strip the trailing activation + dropout so the output stays linear.
        sequence = sequence[:-2]

        self.sequential = nn.Sequential(*sequence)

    def forward(self, X):
        """Apply the stacked layers to ``X``."""
        return self.sequential(X)
42
-
43
-
44
class Bottleneck(nn.Module):
    """Residual bottleneck block: 1x1 conv, 3x3 conv, 1x1 conv with 4x expansion.

    Every convolution uses stride 1.  When ``stride > 1`` the spatial
    reduction is done by an average pool after the second convolution, and by
    an average pool at the head of the shortcut branch.
    """

    expansion = 4

    def __init__(self, inplanes, planes, stride=1):
        super().__init__()
        out_planes = planes * self.expansion
        strided = stride > 1

        self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)

        self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)

        # Pooling replaces a strided convolution on the main path.
        self.avgpool = nn.AvgPool2d(stride) if strided else nn.Identity()

        self.conv3 = nn.Conv2d(planes, out_planes, 1, bias=False)
        self.bn3 = nn.BatchNorm2d(out_planes)

        self.relu = nn.ReLU(inplace=True)
        self.downsample = None
        self.stride = stride

        if strided or inplanes != planes * Bottleneck.expansion:
            # The shortcut needs reshaping: pool first, then a stride-1 1x1
            # convolution and batch norm to match the main path's channels.
            shortcut_layers = OrderedDict()
            shortcut_layers["-1"] = nn.AvgPool2d(stride)
            shortcut_layers["0"] = nn.Conv2d(
                inplanes, out_planes, 1, stride=1, bias=False
            )
            shortcut_layers["1"] = nn.BatchNorm2d(out_planes)
            self.downsample = nn.Sequential(shortcut_layers)

    def forward(self, x: torch.Tensor):
        """Return ``relu(main_path(x) + shortcut(x))``."""
        features = self.conv1(x)
        features = self.relu(self.bn1(features))
        features = self.conv2(features)
        features = self.relu(self.bn2(features))
        features = self.avgpool(features)
        features = self.bn3(self.conv3(features))

        shortcut = x if self.downsample is None else self.downsample(x)

        features += shortcut
        return self.relu(features)
101
-
102
-
103
class AttentionPool2d(nn.Module):
    """Pool a 2-D feature map with one multi-head attention step.

    The map is flattened to a token sequence, its mean is prepended as an
    extra token, a learned positional embedding is added, and the attention
    output at the mean-token position is returned.
    """

    def __init__(
        self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None
    ):
        super().__init__()
        num_positions = spacial_dim**2 + 1  # one slot per cell plus the mean token
        self.positional_embedding = nn.Parameter(
            torch.randn(num_positions, embed_dim) / embed_dim**0.5
        )
        self.k_proj = nn.Linear(embed_dim, embed_dim)
        self.q_proj = nn.Linear(embed_dim, embed_dim)
        self.v_proj = nn.Linear(embed_dim, embed_dim)
        self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim)
        self.num_heads = num_heads

    def forward(self, x):
        """Pool ``x`` (NCHW) and return the attended mean token per batch item."""
        batch, channels = x.shape[0], x.shape[1]
        positions = x.shape[2] * x.shape[3]
        tokens = x.reshape(batch, channels, positions).permute(2, 0, 1)  # NCHW -> (HW)NC
        mean_token = tokens.mean(dim=0, keepdim=True)
        tokens = torch.cat([mean_token, tokens], dim=0)  # (HW+1)NC
        tokens = tokens + self.positional_embedding[:, None, :].to(tokens.dtype)

        qkv_bias = torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias])
        pooled, _ = F.multi_head_attention_forward(
            query=tokens,
            key=tokens,
            value=tokens,
            embed_dim_to_check=tokens.shape[-1],
            num_heads=self.num_heads,
            q_proj_weight=self.q_proj.weight,
            k_proj_weight=self.k_proj.weight,
            v_proj_weight=self.v_proj.weight,
            in_proj_weight=None,
            in_proj_bias=qkv_bias,
            bias_k=None,
            bias_v=None,
            add_zero_attn=False,
            dropout_p=0,
            out_proj_weight=self.c_proj.weight,
            out_proj_bias=self.c_proj.bias,
            use_separate_proj_weight=True,
            training=self.training,
            need_weights=False,
        )

        # Position 0 is the prepended mean token.
        return pooled[0]
148
-
149
-
150
class ModifiedResNet(nn.Module):
    """
    A ResNet class that is similar to torchvision's but contains the following changes:
    - There are now 3 "stem" convolutions as opposed to 1, with an average pool instead of a max pool.
    - Performs anti-aliasing strided convolutions, where an avgpool is prepended to convolutions with stride > 1
    - The final pooling layer is a QKV attention instead of an average pool
    """

    def __init__(self, layers, output_dim, heads, image_size=224, width=64):
        super().__init__()
        self.output_dim = output_dim
        self.image_size = image_size

        # Three-convolution stem; only the first convolution is strided.
        half_width = width // 2
        self.conv1 = nn.Conv2d(
            3, half_width, kernel_size=3, stride=2, padding=1, bias=False
        )
        self.bn1 = nn.BatchNorm2d(half_width)
        self.conv2 = nn.Conv2d(
            half_width, half_width, kernel_size=3, padding=1, bias=False
        )
        self.bn2 = nn.BatchNorm2d(half_width)
        self.conv3 = nn.Conv2d(half_width, width, kernel_size=3, padding=1, bias=False)
        self.bn3 = nn.BatchNorm2d(width)
        self.avgpool = nn.AvgPool2d(2)
        self.relu = nn.ReLU(inplace=True)

        # Residual stages.  ``_inplanes`` is updated by every _make_layer call.
        self._inplanes = width
        self.layer1 = self._make_layer(width, layers[0])
        self.layer2 = self._make_layer(width * 2, layers[1], stride=2)
        self.layer3 = self._make_layer(width * 4, layers[2], stride=2)
        self.layer4 = self._make_layer(width * 8, layers[3], stride=2)

        feature_dim = width * 32  # channel count after the last stage
        self.attnpool = AttentionPool2d(
            image_size // 32, feature_dim, heads, output_dim
        )

        self.init_parameters()

    def _make_layer(self, planes, blocks, stride=1):
        """Build one stage: a (possibly strided) block followed by ``blocks - 1`` more."""
        stage = [Bottleneck(self._inplanes, planes, stride)]
        self._inplanes = planes * Bottleneck.expansion
        stage.extend(Bottleneck(self._inplanes, planes) for _ in range(1, blocks))
        return nn.Sequential(*stage)

    def init_parameters(self):
        """Initialise the attention-pool projections and zero each block's last BN weight."""
        if self.attnpool is not None:
            std = self.attnpool.c_proj.in_features**-0.5
            for projection in (
                self.attnpool.q_proj,
                self.attnpool.k_proj,
                self.attnpool.v_proj,
                self.attnpool.c_proj,
            ):
                nn.init.normal_(projection.weight, std=std)

        for stage in (self.layer1, self.layer2, self.layer3, self.layer4):
            for param_name, param in stage.named_parameters():
                if param_name.endswith("bn3.weight"):
                    nn.init.zeros_(param)

    def lock(self, unlocked_groups=0, freeze_bn_stats=False):
        """Freeze all parameters and, if requested, the batch-norm statistics."""
        assert (
            unlocked_groups == 0
        ), "partial locking not currently supported for this model"
        for param in self.parameters():
            param.requires_grad = False
        if freeze_bn_stats:
            freeze_batch_norm_2d(self)

    def stem(self, x):
        """Run the three conv/BN/ReLU stem layers followed by the average pool."""
        x = self.relu(self.bn1(self.conv1(x)))
        x = self.relu(self.bn2(self.conv2(x)))
        x = self.relu(self.bn3(self.conv3(x)))
        return self.avgpool(x)

    def forward(self, x):
        """Stem, four residual stages, then attention pooling."""
        x = self.stem(x)
        for stage in (self.layer1, self.layer2, self.layer3, self.layer4):
            x = stage(x)
        return self.attnpool(x)
239
-
240
-
241
class LayerNorm(nn.LayerNorm):
    """``nn.LayerNorm`` whose output is cast back to the dtype of its input."""

    def forward(self, x: torch.Tensor):
        input_dtype = x.dtype
        normalized = F.layer_norm(
            x, self.normalized_shape, self.weight, self.bias, self.eps
        )
        return normalized.to(input_dtype)
248
-
249
-
250
class QuickGELU(nn.Module):
    """Sigmoid-based GELU approximation: ``x * sigmoid(1.702 * x)``."""

    # NOTE This is slower than nn.GELU or nn.SiLU and uses more GPU memory
    def forward(self, x: torch.Tensor):
        gate = torch.sigmoid(1.702 * x)
        return x * gate
254
-
255
-
256
class ResidualAttentionBlock(nn.Module):
    """Pre-norm transformer block: self-attention then a 4x-wide MLP, each with a residual add."""

    def __init__(self, d_model: int, n_head: int, act_layer: Callable = nn.GELU):
        super().__init__()
        hidden_dim = d_model * 4

        self.attn = nn.MultiheadAttention(d_model, n_head)
        self.ln_1 = LayerNorm(d_model)

        feed_forward = OrderedDict()
        feed_forward["c_fc"] = nn.Linear(d_model, hidden_dim)
        feed_forward["gelu"] = act_layer()
        feed_forward["c_proj"] = nn.Linear(hidden_dim, d_model)
        self.mlp = nn.Sequential(feed_forward)

        self.ln_2 = LayerNorm(d_model)

    def attention(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None):
        """Self-attention over ``x``; attention weights are not computed."""
        attended, _ = self.attn(x, x, x, need_weights=False, attn_mask=attn_mask)
        return attended

    def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None):
        """Apply the attention and MLP sub-layers, each on a layer-normed input."""
        x = x + self.attention(self.ln_1(x), attn_mask=attn_mask)
        return x + self.mlp(self.ln_2(x))
280
-
281
-
282
class Transformer(nn.Module):
    """A sequence of ``layers`` ResidualAttentionBlock modules of width ``width``."""

    def __init__(
        self, width: int, layers: int, heads: int, act_layer: Callable = nn.GELU
    ):
        super().__init__()
        self.width = width
        self.layers = layers

        blocks = []
        for _ in range(layers):
            blocks.append(ResidualAttentionBlock(width, heads, act_layer=act_layer))
        self.resblocks = nn.ModuleList(blocks)

    def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None):
        """Pass ``x`` through every block in order, forwarding ``attn_mask`` to each."""
        hidden = x
        for block in self.resblocks:
            hidden = block(hidden, attn_mask=attn_mask)
        return hidden
300
-
301
-
302
class VisualTransformer(nn.Module):
    """Patch-based vision transformer producing one projected embedding per image.

    Note that the transformer stack is stored under the attribute name
    ``text_branch``; that name is part of the parameter/state-dict naming.
    """

    def __init__(
        self,
        image_size: int,
        patch_size: int,
        width: int,
        layers: int,
        heads: int,
        output_dim: int,
        act_layer: Callable = nn.GELU,
    ):
        super().__init__()
        self.image_size = image_size
        self.output_dim = output_dim

        # Non-overlapping patch embedding: kernel size equals stride.
        self.conv1 = nn.Conv2d(
            in_channels=3,
            out_channels=width,
            kernel_size=patch_size,
            stride=patch_size,
            bias=False,
        )

        init_scale = width**-0.5
        num_tokens = (image_size // patch_size) ** 2 + 1  # patches + class token
        self.class_embedding = nn.Parameter(init_scale * torch.randn(width))
        self.positional_embedding = nn.Parameter(
            init_scale * torch.randn(num_tokens, width)
        )
        self.ln_pre = LayerNorm(width)

        self.text_branch = Transformer(width, layers, heads, act_layer=act_layer)

        self.ln_post = LayerNorm(width)
        self.proj = nn.Parameter(init_scale * torch.randn(width, output_dim))

    def lock(self, unlocked_groups=0, freeze_bn_stats=False):
        """Freeze every parameter; ``freeze_bn_stats`` is accepted but not used here."""
        assert (
            unlocked_groups == 0
        ), "partial locking not currently supported for this model"
        for param in self.parameters():
            param.requires_grad = False

    def forward(self, x: torch.Tensor):
        """Embed patches, run the transformer and return the projected class token."""
        patches = self.conv1(x)  # shape = [*, width, grid, grid]
        patches = patches.reshape(patches.shape[0], patches.shape[1], -1)
        patches = patches.permute(0, 2, 1)  # shape = [*, grid ** 2, width]

        # Broadcast the class embedding to one token per batch item.
        class_tokens = self.class_embedding.to(patches.dtype) + torch.zeros(
            patches.shape[0],
            1,
            patches.shape[-1],
            dtype=patches.dtype,
            device=patches.device,
        )
        tokens = torch.cat([class_tokens, patches], dim=1)  # [*, grid ** 2 + 1, width]
        tokens = tokens + self.positional_embedding.to(tokens.dtype)
        tokens = self.ln_pre(tokens)

        tokens = tokens.permute(1, 0, 2)  # NLD -> LND
        tokens = self.text_branch(tokens)
        tokens = tokens.permute(1, 0, 2)  # LND -> NLD

        pooled = self.ln_post(tokens[:, 0, :])

        if self.proj is not None:
            pooled = pooled @ self.proj

        return pooled
370
-
371
-
372
@dataclass
class CLAPVisionCfg:
    """Vision-tower settings: depth, width, patch/image size and timm options."""

    layers: Union[Tuple[int, int, int, int], int] = 12
    width: int = 768
    patch_size: int = 16
    image_size: Union[Tuple[int, int], int] = 224
    # a valid model name overrides layers, width, patch_size
    timm_model_name: str = None
    # use (imagenet) pretrained weights for named model
    timm_model_pretrained: bool = False
    # feature pooling for timm model ('abs_attn', 'rot_attn', 'avg', '')
    timm_pool: str = "avg"
    # linear projection for timm model output ('linear', 'mlp', '')
    timm_proj: str = "linear"
387
- timm_proj: str = (
388
- "linear" # linear projection for timm model output ('linear', 'mlp', '')
389
- )
390
-
391
-
392
- # Audio Config Class
393
@dataclass
class CLAPAudioCfp:
    """Audio-tower settings: which model to build and its signal parameters."""

    # Model selection
    model_type: str = "PANN"
    model_name: str = "Cnn14"
    sample_rate: int = 48000
    # Param
    audio_length: int = 1024
    window_size: int = 1024
    hop_size: int = 1024
    fmin: int = 50
    fmax: int = 14000
    class_num: int = 527
    mel_bins: int = 64
    clip_samples: int = 480000
407
-
408
-
409
@dataclass
class CLAPTextCfg:
    """Text-tower settings; every field is required (there are no defaults)."""

    context_length: int
    vocab_size: int
    width: int
    heads: int
    layers: int
    model_type: str
417
-
418
-
419
class CLAP(nn.Module):
    """Two-tower audio/text model projecting both modalities into a joint space.

    The audio tower is built by ``create_pann_model`` or ``create_htsat_model``
    (selected by ``audio_cfg.model_type``).  The text tower is this file's
    ``Transformer`` or a Hugging Face BERT / RoBERTa / BART model (selected by
    ``text_cfg.model_type``).  Each tower's output is mapped to
    ``joint_embed_shape`` dimensions by a Linear-activation-Linear projection;
    ``audio_transform`` / ``text_transform`` are additional MLPs that
    ``forward`` applies to the normalized features.
    """

    def __init__(
        self,
        embed_dim: int,
        audio_cfg: "Union[CLAPAudioCfp, dict]",
        text_cfg: "Union[CLAPTextCfg, dict]",
        quick_gelu: bool = False,
        enable_fusion: bool = False,
        fusion_type: str = "None",
        joint_embed_shape: int = 512,
        mlp_act: str = "relu",
    ):
        """Build both towers and their projection heads.

        Args:
            embed_dim: width of the audio tower's ``"embedding"`` output.
            audio_cfg: audio settings, as a ``CLAPAudioCfp`` or a dict of its fields.
            text_cfg: text settings, as a ``CLAPTextCfg`` or a dict of its fields.
            quick_gelu: use ``QuickGELU`` instead of ``nn.GELU`` in the
                built-in text transformer.
            enable_fusion: forwarded to the audio-tower factory.
            fusion_type: forwarded to the audio-tower factory.
            joint_embed_shape: width of the shared embedding space.
            mlp_act: activation of the projection heads, ``"relu"`` or ``"gelu"``.

        Raises:
            NotImplementedError: for an unknown ``mlp_act``.
            RuntimeError: for an unknown audio or text ``model_type``.
        """
        super().__init__()
        if isinstance(audio_cfg, dict):
            audio_cfg = CLAPAudioCfp(**audio_cfg)
        if isinstance(text_cfg, dict):
            text_cfg = CLAPTextCfg(**text_cfg)

        self.audio_cfg = audio_cfg
        self.text_cfg = text_cfg
        self.enable_fusion = enable_fusion
        self.fusion_type = fusion_type
        self.joint_embed_shape = joint_embed_shape
        self.mlp_act = mlp_act

        self.context_length = text_cfg.context_length

        # OpenAI models are pretrained w/ QuickGELU but native nn.GELU is both faster and more
        # memory efficient in recent PyTorch releases (>= 1.10).
        act_layer = QuickGELU if quick_gelu else nn.GELU

        if mlp_act == "relu":
            mlp_act_layer = nn.ReLU()
        elif mlp_act == "gelu":
            mlp_act_layer = nn.GELU()
        else:
            raise NotImplementedError

        # ---- audio tower ----
        if audio_cfg.model_type == "PANN":
            self.audio_branch = create_pann_model(audio_cfg, enable_fusion, fusion_type)
        elif audio_cfg.model_type == "HTSAT":
            self.audio_branch = create_htsat_model(
                audio_cfg, enable_fusion, fusion_type
            )
        else:
            logging.error(f"Model config for {audio_cfg.model_type} not found")
            raise RuntimeError(f"Model config for {audio_cfg.model_type} not found.")

        # ---- text tower ----
        # Each branch builds the encoder and records the width of its output;
        # the shared transform/projection heads are created right afterwards so
        # module registration order is the same for every branch.
        if text_cfg.model_type == "transformer":
            self.text_branch = Transformer(
                width=text_cfg.width,
                layers=text_cfg.layers,
                heads=text_cfg.heads,
                act_layer=act_layer,
            )
            self.vocab_size = text_cfg.vocab_size
            self.token_embedding = nn.Embedding(text_cfg.vocab_size, text_cfg.width)
            self.positional_embedding = nn.Parameter(
                torch.empty(self.context_length, text_cfg.width)
            )
            self.ln_final = LayerNorm(text_cfg.width)
            text_width = text_cfg.width
        elif text_cfg.model_type == "bert":
            self.text_branch = BertModel.from_pretrained("bert-base-uncased")
            text_width = 768
        elif text_cfg.model_type == "roberta":
            # Built from the config only; no pretrained weights are loaded here.
            self.text_branch = RobertaModel(
                RobertaConfig.from_pretrained("roberta-base")
            )
            text_width = 768
        elif text_cfg.model_type == "bart":
            self.text_branch = BartModel.from_pretrained("facebook/bart-base")
            text_width = 768
        else:
            logging.error(f"Model config for {text_cfg.model_type} not found")
            raise RuntimeError(f"Model config for {text_cfg.model_type} not found.")
        self.text_transform = self._build_transform()
        self.text_projection = self._build_projection(text_width, mlp_act_layer)
        self.text_branch_type = text_cfg.model_type

        # ---- audio heads ----
        self.audio_transform = self._build_transform()
        self.audio_projection = self._build_projection(embed_dim, mlp_act_layer)

        self.logit_scale_a = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
        self.logit_scale_t = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
        self.register_buffer("attn_mask", self.build_attention_mask(), persistent=False)

        self.init_text_branch_parameters()

    def _build_transform(self):
        """Return a fresh joint-space MLP (three equal-width units, dropout 0.1)."""
        width = self.joint_embed_shape
        return MLPLayers(units=[width, width, width], dropout=0.1)

    def _build_projection(self, in_features, act_layer):
        """Return Linear -> ``act_layer`` -> Linear mapping ``in_features`` to the joint space."""
        return nn.Sequential(
            nn.Linear(in_features, self.joint_embed_shape),
            act_layer,
            nn.Linear(self.joint_embed_shape, self.joint_embed_shape),
        )

    def _infer_device(self, audio, text):
        """Pick the device of the first non-None input.

        Inputs without a ``.device`` attribute (for example a plain dict of
        tensors) fall back to the device of the model parameters.  Returns
        ``None`` when both inputs are ``None``.
        """
        for item in (audio, text):
            if item is None:
                continue
            found = getattr(item, "device", None)
            if found is None:
                found = next(self.parameters()).device
            return found
        return None

    def init_text_branch_parameters(self):
        """Initialise the built-in text transformer (if used) and reset both logit scales."""
        if self.text_branch_type == "transformer":
            nn.init.normal_(self.token_embedding.weight, std=0.02)
            nn.init.normal_(self.positional_embedding, std=0.01)
            proj_std = (self.text_branch.width**-0.5) * (
                (2 * self.text_branch.layers) ** -0.5
            )
            attn_std = self.text_branch.width**-0.5
            fc_std = (2 * self.text_branch.width) ** -0.5
            for block in self.text_branch.resblocks:
                nn.init.normal_(block.attn.in_proj_weight, std=attn_std)
                nn.init.normal_(block.attn.out_proj.weight, std=proj_std)
                nn.init.normal_(block.mlp.c_fc.weight, std=fc_std)
                nn.init.normal_(block.mlp.c_proj.weight, std=proj_std)
        nn.init.constant_(self.logit_scale_a, np.log(1 / 0.07))
        nn.init.constant_(self.logit_scale_t, np.log(1 / 0.07))

    def build_attention_mask(self):
        """Return the additive causal mask of shape (context_length, context_length).

        Entries strictly above the diagonal are ``-inf`` (a position may not
        attend to later positions); the diagonal and everything below are 0.
        """
        mask = torch.empty(self.context_length, self.context_length)
        mask.fill_(float("-inf"))
        mask.triu_(1)  # keep -inf above the diagonal, zero the rest
        return mask

    def encode_audio(self, audio, device):
        """Run the audio tower and return its output (indexed by ``"embedding"`` by callers)."""
        return self.audio_branch(
            audio, mixup_lambda=None, device=device
        )  # mix lambda needs to add

    def encode_text(self, text, device):
        """Encode ``text`` with the configured text tower and project it.

        For the built-in transformer ``text`` is a tensor of token ids; for
        the Hugging Face towers it is a mapping holding ``"input_ids"`` and
        ``"attention_mask"`` (plus ``"token_type_ids"`` for BERT).

        Raises:
            RuntimeError: if ``self.text_branch_type`` is not recognised.
        """
        if self.text_branch_type == "transformer":
            text = text.to(device=device, non_blocking=True)
            x = self.token_embedding(text)  # [batch_size, n_ctx, d_model]

            x = x + self.positional_embedding
            x = x.permute(1, 0, 2)  # NLD -> LND
            x = self.text_branch(x, attn_mask=self.attn_mask)
            x = x.permute(1, 0, 2)  # LND -> NLD
            x = self.ln_final(x)

            # x.shape = [batch_size, n_ctx, transformer.width]
            # take features from the eot embedding (eot_token is the highest number in each sequence)
            x = self.text_projection(x[torch.arange(x.shape[0]), text.argmax(dim=-1)])
        elif self.text_branch_type == "bert":
            x = self.text_branch(
                input_ids=text["input_ids"].to(device=device, non_blocking=True),
                attention_mask=text["attention_mask"].to(
                    device=device, non_blocking=True
                ),
                token_type_ids=text["token_type_ids"].to(
                    device=device, non_blocking=True
                ),
            )["pooler_output"]
            x = self.text_projection(x)
        elif self.text_branch_type == "roberta":
            x = self.text_branch(
                input_ids=text["input_ids"].to(device=device, non_blocking=True),
                attention_mask=text["attention_mask"].to(
                    device=device, non_blocking=True
                ),
            )["pooler_output"]
            x = self.text_projection(x)
        elif self.text_branch_type == "bart":
            # BART has no pooled output here: average the encoder states over dim 1.
            x = torch.mean(
                self.text_branch(
                    input_ids=text["input_ids"].to(device=device, non_blocking=True),
                    attention_mask=text["attention_mask"].to(
                        device=device, non_blocking=True
                    ),
                )["encoder_last_hidden_state"],
                dim=1,
            )
            x = self.text_projection(x)
        else:
            logging.error(f"Model type {self.text_branch_type} not found")
            raise RuntimeError(f"Model type {self.text_branch_type} not found.")
        return x

    def forward(self, audio, text, device=None):
        """Forward audio and text into the CLAP

        Parameters
        ----------
        audio: torch.Tensor (batch_size, audio_length)
            the time-domain audio input / the batch of mel_spec and longer list.
        text: torch.Tensor or mapping of tensors
            the text token input accepted by ``encode_text``
        device: torch.device, optional
            device the inputs are moved to; inferred from the inputs when omitted

        Returns
        ----------
        * both inputs ``None``: ``(logit_scale_a.exp(), logit_scale_t.exp())``
        * only ``text``: the projected (un-normalized) text features
        * only ``audio``: the projected (un-normalized) audio features
        * both: ``(audio_features, text_features, audio_features_mlp,
          text_features_mlp, logit_scale_a.exp(), logit_scale_t.exp())`` with
          L2-normalized features
        """
        if device is None:
            device = self._infer_device(audio, text)
        if audio is None and text is None:
            # a hack to get the logit scale
            return self.logit_scale_a.exp(), self.logit_scale_t.exp()
        elif audio is None:
            return self.encode_text(text, device=device)
        elif text is None:
            return self.audio_projection(
                self.encode_audio(audio, device=device)["embedding"]
            )
        audio_features = self.audio_projection(
            self.encode_audio(audio, device=device)["embedding"]
        )
        audio_features = F.normalize(audio_features, dim=-1)

        text_features = self.encode_text(text, device=device)
        text_features = F.normalize(text_features, dim=-1)

        audio_features_mlp = self.audio_transform(audio_features)
        text_features_mlp = self.text_transform(text_features)
        # Four outputs: audio features (basic & MLP), text features (basic & MLP)
        return (
            audio_features,
            text_features,
            audio_features_mlp,
            text_features_mlp,
            self.logit_scale_a.exp(),
            self.logit_scale_t.exp(),
        )

    def get_logit_scale(self):
        """Return the exponentiated audio and text logit scales."""
        return self.logit_scale_a.exp(), self.logit_scale_t.exp()

    def get_text_embedding(self, data):
        """Get the text embedding from the model

        Parameters
        ----------
        data: mapping of torch.Tensor
            the text inputs; note that every value is moved to the model's
            device and written back into ``data`` in place

        Returns
        ----------
        text_embed: torch.Tensor
            a tensor of L2-normalized text_embeds (N, D)

        """
        device = next(self.parameters()).device
        for k in data:
            data[k] = data[k].to(device)
        text_embeds = self.encode_text(data, device=device)
        text_embeds = F.normalize(text_embeds, dim=-1)

        return text_embeds

    def get_audio_embedding(self, data):
        """Get the audio embedding from the model

        Parameters
        ----------
        data: a list of dict
            the audio input dict list from 'get_audio_feature' method

        Returns
        ----------
        audio_embed: torch.Tensor
            a tensor of L2-normalized audio_embeds (N, D)

        """
        device = next(self.parameters()).device
        audio_embeds = self.audio_projection(
            self.encode_audio(data, device=device)["embedding"]
        )
        audio_embeds = F.normalize(audio_embeds, dim=-1)

        return audio_embeds

    def audio_infer(self, audio, hopsize=None, device=None, key="embedding"):
        """Forward one audio and produce the audio embedding

        Parameters
        ----------
        audio: (audio_length)
            the time-domain audio input, notice that it must be only one input
        hopsize: int
            the overlap hopsize as the sliding window; defaults to
            ``audio_cfg.clip_samples`` (capped at the audio length)
        device: torch.device, optional
            forwarded to ``encode_audio``
        key: str
            which entry of the audio tower's output dict to return
            (``"embedding"`` by default)

        Returns
        ----------
        output_dict: {
            key: [n, (embedding_shape)] if "HTS-AT" and the audio is longer
                 than ``clip_samples``
            or
            key: [(embedding_shape)] otherwise
        }
            the list of key values of the audio branch

        """

        assert not self.training, "the inference mode must be run at eval stage"
        output_dict = {}
        # PANN
        if self.audio_cfg.model_type == "PANN":
            audio_input = audio.unsqueeze(dim=0)
            output_dict[key] = self.encode_audio(audio_input, device=device)[
                key
            ].squeeze(dim=0)
        elif self.audio_cfg.model_type == "HTSAT":
            clip_samples = self.audio_cfg.clip_samples
            # Short clips are tiled up towards one full window.
            audio_len = len(audio)
            k = clip_samples // audio_len
            if k > 1:
                audio = audio.repeat(k)
                audio_len = len(audio)

            if hopsize is None:
                # No hop given: step by one full window (never past the audio).
                hopsize = min(clip_samples, audio_len)

            if audio_len > clip_samples:
                # Sliding windows plus one final window aligned to the end.
                audio_input = [
                    audio[pos : pos + clip_samples].clone()
                    for pos in range(0, audio_len - clip_samples, hopsize)
                ]
                audio_input.append(audio[-clip_samples:].clone())
                audio_input = torch.stack(audio_input)
                output_dict[key] = self.encode_audio(audio_input, device=device)[key]
            else:
                audio_input = audio.unsqueeze(dim=0)
                output_dict[key] = self.encode_audio(audio_input, device=device)[
                    key
                ].squeeze(dim=0)

        return output_dict
837
-
838
-
839
def convert_weights_to_fp16(model: nn.Module):
    """Convert applicable model parameters to fp16, in place.

    Visits every sub-module via ``model.apply`` and halves:
    * weight and bias of ``nn.Conv1d`` / ``nn.Conv2d`` / ``nn.Linear``,
    * the projection weights and biases of ``nn.MultiheadAttention``,
    * ``text_projection`` / ``proj`` attributes that are tensors.

    Other modules (e.g. normalisation layers) are left untouched.
    """

    def _convert_weights_to_fp16(module):
        if isinstance(module, (nn.Conv1d, nn.Conv2d, nn.Linear)):
            module.weight.data = module.weight.data.half()
            if module.bias is not None:
                module.bias.data = module.bias.data.half()

        if isinstance(module, nn.MultiheadAttention):
            attention_attrs = [f"{s}_proj_weight" for s in ("in", "q", "k", "v")]
            attention_attrs += ["in_proj_bias", "bias_k", "bias_v"]
            for attr_name in attention_attrs:
                tensor = getattr(module, attr_name)
                if tensor is not None:
                    tensor.data = tensor.data.half()

        for name in ("text_projection", "proj"):
            attr = getattr(module, name, None)
            # Only raw tensors/parameters are cast here.  A projection that is
            # itself a module (e.g. an nn.Sequential) has no ``.data``; its
            # layers are converted when apply() visits them individually.
            if isinstance(attr, torch.Tensor):
                attr.data = attr.data.half()

    model.apply(_convert_weights_to_fp16)
866
-
867
-
868
- # Ignore the state dict of the vision part
869
def build_model_from_openai_state_dict(
    state_dict: dict, model_cfg, enable_fusion: bool = False, fusion_type: str = "None"
):
    """Build a ``CLAP`` model and load the non-visual weights of an OpenAI checkpoint.

    Args:
        state_dict: checkpoint weights.  It is not modified; filtering happens
            on a shallow copy.
        model_cfg: mapping with ``"embed_dim"``, ``"audio_cfg"`` and
            ``"text_cfg"`` entries (the two configs as dicts of dataclass fields).
        enable_fusion: forwarded to ``CLAP``.
        fusion_type: forwarded to ``CLAP``.

    Returns:
        The model in eval mode.

    Raises:
        KeyError: if ``model_cfg`` or ``state_dict`` lacks a required entry.
    """
    embed_dim = model_cfg["embed_dim"]
    audio_cfg = model_cfg["audio_cfg"]
    text_cfg = model_cfg["text_cfg"]

    # Fail early, before building the towers, when the checkpoint is not in
    # the expected layout.
    required_keys = (
        "positional_embedding",
        "token_embedding.weight",
        "ln_final.weight",
        "logit_scale",
    )
    missing = [key for key in required_keys if key not in state_dict]
    if missing:
        raise KeyError(f"state_dict is missing required keys: {missing}")

    audio_cfg = CLAPAudioCfp(**audio_cfg)
    text_cfg = CLAPTextCfg(**text_cfg)

    model = CLAP(
        embed_dim,
        audio_cfg=audio_cfg,
        text_cfg=text_cfg,
        quick_gelu=True,  # OpenAI models were trained with QuickGELU
        enable_fusion=enable_fusion,
        fusion_type=fusion_type,
    )

    # The single OpenAI logit scale seeds both the audio and the text scale.
    logit_scale = state_dict["logit_scale"]
    non_weight_keys = {"logit_scale", "input_resolution", "context_length", "vocab_size"}
    weights = {
        key: value
        for key, value in state_dict.items()
        # drop the visual branch and the metadata entries
        if not key.startswith("visual.") and key not in non_weight_keys
    }
    weights["logit_scale_a"] = logit_scale
    weights["logit_scale_t"] = logit_scale

    # not use fp16
    # NOTE(review): strict=False silently skips keys that do not match a module
    # name.  The text tower is registered as ``text_branch`` in CLAP, so
    # checkpoint keys under ``transformer.*`` would be skipped unless they are
    # renamed before this call -- confirm with the caller.
    model.load_state_dict(weights, strict=False)
    return model.eval()
913
-
914
-
915
def trace_model(model, batch_size=256, device=torch.device("cpu")):
    """Trace ``model`` with ``torch.jit.trace_module`` on dummy audio and text.

    The model is switched to eval mode, traced with an all-ones audio batch of
    shape ``(batch_size, audio_cfg.audio_length)`` and an all-zeros int text
    batch of shape ``(batch_size, context_length)``, and the traced module is
    returned.
    """
    model.eval()
    audio_length = model.audio_cfg.audio_length

    audio_shape = (batch_size, audio_length)
    text_shape = (batch_size, model.context_length)
    example_audio = torch.ones(audio_shape, device=device)
    example_text = torch.zeros(text_shape, dtype=torch.int, device=device)

    # NOTE(review): the CLAP class in this file defines ``encode_audio`` (not
    # ``encode_image``) and its ``encode_text`` takes a required ``device``
    # argument, so these trace inputs look inherited from the image model --
    # confirm whether this function is still used before relying on it.
    trace_inputs = {
        "forward": (example_audio, example_text),
        "encode_text": (example_text,),
        "encode_image": (example_audio,),
    }
    model = torch.jit.trace_module(model, inputs=trace_inputs)

    # Re-assign the length read before tracing onto the traced module.
    model.audio_cfg.audio_length = audio_length  # Question: what does this do?
    return model
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
audioldm2/clap/open_clip/model_configs/HTSAT-base.json DELETED
@@ -1,23 +0,0 @@
1
- {
2
- "embed_dim": 1024,
3
- "audio_cfg": {
4
- "audio_length": 1024,
5
- "clip_samples": 480000,
6
- "mel_bins": 64,
7
- "sample_rate": 48000,
8
- "window_size": 1024,
9
- "hop_size": 480,
10
- "fmin": 50,
11
- "fmax": 14000,
12
- "class_num": 527,
13
- "model_type": "HTSAT",
14
- "model_name": "base"
15
- },
16
- "text_cfg": {
17
- "context_length": 77,
18
- "vocab_size": 49408,
19
- "width": 512,
20
- "heads": 8,
21
- "layers": 12
22
- }
23
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
audioldm2/clap/open_clip/model_configs/HTSAT-large.json DELETED
@@ -1,23 +0,0 @@
1
- {
2
- "embed_dim": 2048,
3
- "audio_cfg": {
4
- "audio_length": 1024,
5
- "clip_samples": 480000,
6
- "mel_bins": 64,
7
- "sample_rate": 48000,
8
- "window_size": 1024,
9
- "hop_size": 480,
10
- "fmin": 50,
11
- "fmax": 14000,
12
- "class_num": 527,
13
- "model_type": "HTSAT",
14
- "model_name": "large"
15
- },
16
- "text_cfg": {
17
- "context_length": 77,
18
- "vocab_size": 49408,
19
- "width": 512,
20
- "heads": 8,
21
- "layers": 12
22
- }
23
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
audioldm2/clap/open_clip/model_configs/HTSAT-tiny-win-1536.json DELETED
@@ -1,23 +0,0 @@
1
- {
2
- "embed_dim": 768,
3
- "audio_cfg": {
4
- "audio_length": 1024,
5
- "clip_samples": 480000,
6
- "mel_bins": 64,
7
- "sample_rate": 48000,
8
- "window_size": 1536,
9
- "hop_size": 480,
10
- "fmin": 50,
11
- "fmax": 14000,
12
- "class_num": 527,
13
- "model_type": "HTSAT",
14
- "model_name": "tiny"
15
- },
16
- "text_cfg": {
17
- "context_length": 77,
18
- "vocab_size": 49408,
19
- "width": 512,
20
- "heads": 8,
21
- "layers": 12
22
- }
23
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
audioldm2/clap/open_clip/model_configs/HTSAT-tiny.json DELETED
@@ -1,23 +0,0 @@
1
- {
2
- "embed_dim": 768,
3
- "audio_cfg": {
4
- "audio_length": 1024,
5
- "clip_samples": 480000,
6
- "mel_bins": 64,
7
- "sample_rate": 48000,
8
- "window_size": 1024,
9
- "hop_size": 480,
10
- "fmin": 50,
11
- "fmax": 14000,
12
- "class_num": 527,
13
- "model_type": "HTSAT",
14
- "model_name": "tiny"
15
- },
16
- "text_cfg": {
17
- "context_length": 77,
18
- "vocab_size": 49408,
19
- "width": 512,
20
- "heads": 8,
21
- "layers": 12
22
- }
23
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
audioldm2/clap/open_clip/model_configs/PANN-10.json DELETED
@@ -1,23 +0,0 @@
1
- {
2
- "embed_dim": 1024,
3
- "audio_cfg": {
4
- "audio_length": 1024,
5
- "clip_samples": 480000,
6
- "mel_bins": 64,
7
- "sample_rate": 48000,
8
- "window_size": 1024,
9
- "hop_size": 480,
10
- "fmin": 50,
11
- "fmax": 14000,
12
- "class_num": 527,
13
- "model_type": "PANN",
14
- "model_name": "Cnn10"
15
- },
16
- "text_cfg": {
17
- "context_length": 77,
18
- "vocab_size": 49408,
19
- "width": 512,
20
- "heads": 8,
21
- "layers": 12
22
- }
23
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
audioldm2/clap/open_clip/model_configs/PANN-14-fmax-18k.json DELETED
@@ -1,23 +0,0 @@
1
- {
2
- "embed_dim": 2048,
3
- "audio_cfg": {
4
- "audio_length": 1024,
5
- "clip_samples": 480000,
6
- "mel_bins": 64,
7
- "sample_rate": 48000,
8
- "window_size": 1024,
9
- "hop_size": 480,
10
- "fmin": 50,
11
- "fmax": 18000,
12
- "class_num": 527,
13
- "model_type": "PANN",
14
- "model_name": "Cnn14"
15
- },
16
- "text_cfg": {
17
- "context_length": 77,
18
- "vocab_size": 49408,
19
- "width": 512,
20
- "heads": 8,
21
- "layers": 12
22
- }
23
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
audioldm2/clap/open_clip/model_configs/PANN-14-fmax-8k-20s.json DELETED
@@ -1,23 +0,0 @@
1
- {
2
- "embed_dim": 2048,
3
- "audio_cfg": {
4
- "audio_length": 1024,
5
- "clip_samples": 960000,
6
- "mel_bins": 64,
7
- "sample_rate": 48000,
8
- "window_size": 1024,
9
- "hop_size": 360,
10
- "fmin": 50,
11
- "fmax": 8000,
12
- "class_num": 527,
13
- "model_type": "PANN",
14
- "model_name": "Cnn14"
15
- },
16
- "text_cfg": {
17
- "context_length": 77,
18
- "vocab_size": 49408,
19
- "width": 512,
20
- "heads": 8,
21
- "layers": 12
22
- }
23
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
audioldm2/clap/open_clip/model_configs/PANN-14-tiny-transformer.json DELETED
@@ -1,23 +0,0 @@
1
- {
2
- "embed_dim": 2048,
3
- "audio_cfg": {
4
- "audio_length": 1024,
5
- "clip_samples": 480000,
6
- "mel_bins": 64,
7
- "sample_rate": 48000,
8
- "window_size": 1024,
9
- "hop_size": 480,
10
- "fmin": 50,
11
- "fmax": 14000,
12
- "class_num": 527,
13
- "model_type": "PANN",
14
- "model_name": "Cnn14"
15
- },
16
- "text_cfg": {
17
- "context_length": 77,
18
- "vocab_size": 49408,
19
- "width": 512,
20
- "heads": 8,
21
- "layers": 4
22
- }
23
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
audioldm2/clap/open_clip/model_configs/PANN-14-win-1536.json DELETED
@@ -1,23 +0,0 @@
1
- {
2
- "embed_dim": 2048,
3
- "audio_cfg": {
4
- "audio_length": 1024,
5
- "clip_samples": 480000,
6
- "mel_bins": 64,
7
- "sample_rate": 48000,
8
- "window_size": 1536,
9
- "hop_size": 480,
10
- "fmin": 50,
11
- "fmax": 14000,
12
- "class_num": 527,
13
- "model_type": "PANN",
14
- "model_name": "Cnn14"
15
- },
16
- "text_cfg": {
17
- "context_length": 77,
18
- "vocab_size": 49408,
19
- "width": 512,
20
- "heads": 8,
21
- "layers": 12
22
- }
23
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
audioldm2/clap/open_clip/model_configs/PANN-14.json DELETED
@@ -1,23 +0,0 @@
1
- {
2
- "embed_dim": 2048,
3
- "audio_cfg": {
4
- "audio_length": 1024,
5
- "clip_samples": 480000,
6
- "mel_bins": 64,
7
- "sample_rate": 48000,
8
- "window_size": 1024,
9
- "hop_size": 480,
10
- "fmin": 50,
11
- "fmax": 14000,
12
- "class_num": 527,
13
- "model_type": "PANN",
14
- "model_name": "Cnn14"
15
- },
16
- "text_cfg": {
17
- "context_length": 77,
18
- "vocab_size": 49408,
19
- "width": 512,
20
- "heads": 8,
21
- "layers": 12
22
- }
23
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
audioldm2/clap/open_clip/model_configs/PANN-6.json DELETED
@@ -1,23 +0,0 @@
1
- {
2
- "embed_dim": 512,
3
- "audio_cfg": {
4
- "audio_length": 1024,
5
- "clip_samples": 480000,
6
- "mel_bins": 64,
7
- "sample_rate": 48000,
8
- "window_size": 1024,
9
- "hop_size": 480,
10
- "fmin": 50,
11
- "fmax": 14000,
12
- "class_num": 527,
13
- "model_type": "PANN",
14
- "model_name": "Cnn6"
15
- },
16
- "text_cfg": {
17
- "context_length": 77,
18
- "vocab_size": 49408,
19
- "width": 512,
20
- "heads": 8,
21
- "layers": 12
22
- }
23
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
audioldm2/clap/open_clip/model_configs/RN101-quickgelu.json DELETED
@@ -1,22 +0,0 @@
1
- {
2
- "embed_dim": 512,
3
- "quick_gelu": true,
4
- "vision_cfg": {
5
- "image_size": 224,
6
- "layers": [
7
- 3,
8
- 4,
9
- 23,
10
- 3
11
- ],
12
- "width": 64,
13
- "patch_size": null
14
- },
15
- "text_cfg": {
16
- "context_length": 77,
17
- "vocab_size": 49408,
18
- "width": 512,
19
- "heads": 8,
20
- "layers": 12
21
- }
22
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
audioldm2/clap/open_clip/model_configs/RN101.json DELETED
@@ -1,21 +0,0 @@
1
- {
2
- "embed_dim": 512,
3
- "vision_cfg": {
4
- "image_size": 224,
5
- "layers": [
6
- 3,
7
- 4,
8
- 23,
9
- 3
10
- ],
11
- "width": 64,
12
- "patch_size": null
13
- },
14
- "text_cfg": {
15
- "context_length": 77,
16
- "vocab_size": 49408,
17
- "width": 512,
18
- "heads": 8,
19
- "layers": 12
20
- }
21
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
audioldm2/clap/open_clip/model_configs/RN50-quickgelu.json DELETED
@@ -1,22 +0,0 @@
1
- {
2
- "embed_dim": 1024,
3
- "quick_gelu": true,
4
- "vision_cfg": {
5
- "image_size": 224,
6
- "layers": [
7
- 3,
8
- 4,
9
- 6,
10
- 3
11
- ],
12
- "width": 64,
13
- "patch_size": null
14
- },
15
- "text_cfg": {
16
- "context_length": 77,
17
- "vocab_size": 49408,
18
- "width": 512,
19
- "heads": 8,
20
- "layers": 12
21
- }
22
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
audioldm2/clap/open_clip/model_configs/RN50.json DELETED
@@ -1,21 +0,0 @@
1
- {
2
- "embed_dim": 1024,
3
- "vision_cfg": {
4
- "image_size": 224,
5
- "layers": [
6
- 3,
7
- 4,
8
- 6,
9
- 3
10
- ],
11
- "width": 64,
12
- "patch_size": null
13
- },
14
- "text_cfg": {
15
- "context_length": 77,
16
- "vocab_size": 49408,
17
- "width": 512,
18
- "heads": 8,
19
- "layers": 12
20
- }
21
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
audioldm2/clap/open_clip/model_configs/RN50x16.json DELETED
@@ -1,21 +0,0 @@
1
- {
2
- "embed_dim": 768,
3
- "vision_cfg": {
4
- "image_size": 384,
5
- "layers": [
6
- 6,
7
- 8,
8
- 18,
9
- 8
10
- ],
11
- "width": 96,
12
- "patch_size": null
13
- },
14
- "text_cfg": {
15
- "context_length": 77,
16
- "vocab_size": 49408,
17
- "width": 768,
18
- "heads": 12,
19
- "layers": 12
20
- }
21
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
audioldm2/clap/open_clip/model_configs/RN50x4.json DELETED
@@ -1,21 +0,0 @@
1
- {
2
- "embed_dim": 640,
3
- "vision_cfg": {
4
- "image_size": 288,
5
- "layers": [
6
- 4,
7
- 6,
8
- 10,
9
- 6
10
- ],
11
- "width": 80,
12
- "patch_size": null
13
- },
14
- "text_cfg": {
15
- "context_length": 77,
16
- "vocab_size": 49408,
17
- "width": 640,
18
- "heads": 10,
19
- "layers": 12
20
- }
21
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
audioldm2/clap/open_clip/model_configs/ViT-B-16.json DELETED
@@ -1,16 +0,0 @@
1
- {
2
- "embed_dim": 512,
3
- "vision_cfg": {
4
- "image_size": 224,
5
- "layers": 12,
6
- "width": 768,
7
- "patch_size": 16
8
- },
9
- "text_cfg": {
10
- "context_length": 77,
11
- "vocab_size": 49408,
12
- "width": 512,
13
- "heads": 8,
14
- "layers": 12
15
- }
16
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
audioldm2/clap/open_clip/model_configs/ViT-B-32-quickgelu.json DELETED
@@ -1,17 +0,0 @@
1
- {
2
- "embed_dim": 512,
3
- "quick_gelu": true,
4
- "vision_cfg": {
5
- "image_size": 224,
6
- "layers": 12,
7
- "width": 768,
8
- "patch_size": 32
9
- },
10
- "text_cfg": {
11
- "context_length": 77,
12
- "vocab_size": 49408,
13
- "width": 512,
14
- "heads": 8,
15
- "layers": 12
16
- }
17
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
audioldm2/clap/open_clip/model_configs/ViT-B-32.json DELETED
@@ -1,16 +0,0 @@
1
- {
2
- "embed_dim": 512,
3
- "vision_cfg": {
4
- "image_size": 224,
5
- "layers": 12,
6
- "width": 768,
7
- "patch_size": 32
8
- },
9
- "text_cfg": {
10
- "context_length": 77,
11
- "vocab_size": 49408,
12
- "width": 512,
13
- "heads": 8,
14
- "layers": 12
15
- }
16
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
audioldm2/clap/open_clip/model_configs/ViT-L-14.json DELETED
@@ -1,16 +0,0 @@
1
- {
2
- "embed_dim": 768,
3
- "vision_cfg": {
4
- "image_size": 224,
5
- "layers": 24,
6
- "width": 1024,
7
- "patch_size": 14
8
- },
9
- "text_cfg": {
10
- "context_length": 77,
11
- "vocab_size": 49408,
12
- "width": 768,
13
- "heads": 12,
14
- "layers": 12
15
- }
16
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
audioldm2/clap/open_clip/openai.py DELETED
@@ -1,156 +0,0 @@
1
- """ OpenAI pretrained model functions
2
-
3
- Adapted from https://github.com/openai/CLIP. Originally MIT License, Copyright (c) 2021 OpenAI.
4
- """
5
-
6
- import os
7
- import warnings
8
- from typing import Union, List
9
-
10
- import torch
11
-
12
- from .model import build_model_from_openai_state_dict
13
- from .pretrained import (
14
- get_pretrained_url,
15
- list_pretrained_tag_models,
16
- download_pretrained,
17
- )
18
-
19
- __all__ = ["list_openai_models", "load_openai_model"]
20
-
21
-
22
def list_openai_models() -> List[str]:
    """Return the model names registered under the ``"openai"`` pretrained tag."""
    openai_model_names = list_pretrained_tag_models("openai")
    return openai_model_names
25
-
26
-
27
def load_openai_model(
    name: str,
    model_cfg,
    device: Union[str, torch.device] = "cuda" if torch.cuda.is_available() else "cpu",
    jit=True,
    cache_dir=os.path.expanduser("~/.cache/clip"),
    enable_fusion: bool = False,
    fusion_type: str = "None",
):
    """Load a CLIP model, preserve its text pretrained part, and set in the CLAP model

    Parameters
    ----------
    name : str
        A model name that has an "openai" pretrained URL, or the path to a
        local checkpoint file (JIT archive or saved state dict)
    model_cfg
        Model configuration, forwarded unchanged to
        `build_model_from_openai_state_dict`
    device : Union[str, torch.device]
        The device to put the loaded model
    jit : bool
        Whether to load the optimized JIT model (default) or more hackable non-JIT model.
        Forced to False (with a warning) when the file is not a JIT archive.
    cache_dir : str
        Directory handed to `download_pretrained` as its `root` when `name`
        resolves to a pretrained URL
    enable_fusion : bool
        Forwarded to `build_model_from_openai_state_dict`
    fusion_type : str
        Forwarded to `build_model_from_openai_state_dict`

    Returns
    -------
    model : torch.nn.Module
        The CLAP model. Only the model is returned; no preprocessing
        transform is produced by this function.

    Raises
    ------
    RuntimeError
        If `name` is neither a known pretrained model nor an existing file.
    """
    # Resolve the URL once instead of looking it up twice.
    pretrained_url = get_pretrained_url(name, "openai")
    if pretrained_url:
        model_path = download_pretrained(pretrained_url, root=cache_dir)
    elif os.path.isfile(name):
        model_path = name
    else:
        raise RuntimeError(
            f"Model {name} not found; available models = {list_openai_models()}"
        )

    try:
        # loading JIT archive
        model = torch.jit.load(model_path, map_location=device if jit else "cpu").eval()
        state_dict = None
    except RuntimeError:
        # loading saved state dict
        if jit:
            warnings.warn(
                f"File {model_path} is not a JIT archive. Loading as a state dict instead"
            )
            jit = False
        # NOTE(review): torch.load unpickles the file, so only checkpoints from
        # trusted sources should be passed here.
        state_dict = torch.load(model_path, map_location="cpu")

    if not jit:
        try:
            model = build_model_from_openai_state_dict(
                state_dict or model.state_dict(), model_cfg, enable_fusion, fusion_type
            ).to(device)
        except KeyError:
            # Fallback for checkpoints that nest the weights under "state_dict";
            # the first 7 characters of every key are dropped (presumably a
            # "module." prefix -- TODO confirm against the checkpoints in use).
            sd = {k[7:]: v for k, v in state_dict["state_dict"].items()}
            model = build_model_from_openai_state_dict(
                sd, model_cfg, enable_fusion, fusion_type
            ).to(device)

        if str(device) == "cpu":
            model.float()
        return model

    # patch the device names
    # Trace a tiny function so its graph contains a constant node for `device`;
    # that node is then copied over every "cuda" constant in the loaded graphs.
    device_holder = torch.jit.trace(
        lambda: torch.ones([]).to(torch.device(device)), example_inputs=[]
    )
    device_node = [
        n
        for n in device_holder.graph.findAllNodes("prim::Constant")
        if "Device" in repr(n)
    ][-1]

    def patch_device(module):
        """Rewrite every "cuda..." constant in the module's graphs to `device`."""
        try:
            graphs = [module.graph] if hasattr(module, "graph") else []
        except RuntimeError:
            graphs = []

        if hasattr(module, "forward1"):
            graphs.append(module.forward1.graph)

        for graph in graphs:
            for node in graph.findAllNodes("prim::Constant"):
                if "value" in node.attributeNames() and str(node["value"]).startswith(
                    "cuda"
                ):
                    node.copyAttributes(device_node)

    model.apply(patch_device)
    patch_device(model.encode_audio)
    patch_device(model.encode_text)

    # patch dtype to float32 on CPU
    if str(device) == "cpu":
        float_holder = torch.jit.trace(
            lambda: torch.ones([]).float(), example_inputs=[]
        )
        float_input = list(float_holder.graph.findNode("aten::to").inputs())[1]
        float_node = float_input.node()

        def patch_float(module):
            """Rewrite dtype constants equal to 5 in `aten::to` calls to float32."""
            try:
                graphs = [module.graph] if hasattr(module, "graph") else []
            except RuntimeError:
                graphs = []

            if hasattr(module, "forward1"):
                graphs.append(module.forward1.graph)

            for graph in graphs:
                for node in graph.findAllNodes("aten::to"):
                    inputs = list(node.inputs())
                    for i in [
                        1,
                        2,
                    ]:  # dtype can be the second or third argument to aten::to()
                        # NOTE(review): 5 is presumably the scalar-type code for
                        # float16 -- confirm against the torch version in use.
                        if inputs[i].node()["value"] == 5:
                            inputs[i].node().copyAttributes(float_node)

        model.apply(patch_float)
        patch_float(model.encode_audio)
        patch_float(model.encode_text)
        model.float()

    model.audio_branch.audio_length = model.audio_cfg.audio_length
    return model
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
audioldm2/clap/open_clip/pann_model.py DELETED
@@ -1,697 +0,0 @@
1
- # PANNs: Large-Scale Pretrained Audio Neural Networks for Audio Pattern Recognition
2
- # Reference from https://github.com/qiuqiangkong/audioset_tagging_cnn
3
- # Some layers are re-designed for CLAP
4
- import os
5
-
6
- os.environ["NUMBA_CACHE_DIR"] = "/tmp/"
7
-
8
- import torch
9
- import torch.nn as nn
10
- import torch.nn.functional as F
11
- from torchlibrosa.stft import Spectrogram, LogmelFilterBank
12
- from torchlibrosa.augmentation import SpecAugmentation
13
-
14
- from .utils import do_mixup, interpolate
15
- from .feature_fusion import iAFF, AFF, DAF
16
-
17
-
18
def init_layer(layer):
    """Xavier-uniform initialise ``layer.weight`` and zero ``layer.bias`` if present."""
    nn.init.xavier_uniform_(layer.weight)

    # Layers built with bias=False expose ``bias`` as None; others may lack it.
    bias = getattr(layer, "bias", None)
    if bias is not None:
        bias.data.fill_(0.0)
25
-
26
-
27
def init_bn(bn):
    """Reset a batch-norm layer: fill its bias with 0.0 and its weight with 1.0."""
    for param, fill_value in ((bn.bias, 0.0), (bn.weight, 1.0)):
        param.data.fill_(fill_value)
31
-
32
-
33
class ConvBlock(nn.Module):
    """Two 3x3 convolutions, each followed by batch norm and ReLU, then pooling.

    Both convolutions use stride 1, padding 1 and no bias. The first maps
    ``in_channels`` to ``out_channels``; the second keeps ``out_channels``.
    """

    # Pooling modes accepted by ``forward``.
    _POOL_TYPES = ("max", "avg", "avg+max")

    def __init__(self, in_channels, out_channels):
        super(ConvBlock, self).__init__()

        self.conv1 = nn.Conv2d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=(3, 3),
            stride=(1, 1),
            padding=(1, 1),
            bias=False,
        )

        self.conv2 = nn.Conv2d(
            in_channels=out_channels,
            out_channels=out_channels,
            kernel_size=(3, 3),
            stride=(1, 1),
            padding=(1, 1),
            bias=False,
        )

        self.bn1 = nn.BatchNorm2d(out_channels)
        self.bn2 = nn.BatchNorm2d(out_channels)

        self.init_weight()

    def init_weight(self):
        """Initialise both convolutions and both batch-norm layers."""
        init_layer(self.conv1)
        init_layer(self.conv2)
        init_bn(self.bn1)
        init_bn(self.bn2)

    def forward(self, input, pool_size=(2, 2), pool_type="avg"):
        """Apply conv-bn-relu twice, then pool.

        Args:
            input: tensor fed to the first convolution.
            pool_size: kernel size handed to the 2D pooling function(s).
            pool_type: ``"max"``, ``"avg"`` or ``"avg+max"`` (sum of both).

        Returns:
            The pooled feature map.

        Raises:
            ValueError: if ``pool_type`` is not one of the supported modes.
        """
        # Reject a bad mode up front, before any convolution is run and
        # before batch-norm running statistics can be touched.
        if pool_type not in self._POOL_TYPES:
            raise ValueError(
                "Incorrect argument! pool_type must be one of "
                f"{self._POOL_TYPES}, got {pool_type!r}"
            )

        x = input
        x = F.relu_(self.bn1(self.conv1(x)))
        x = F.relu_(self.bn2(self.conv2(x)))
        if pool_type == "max":
            x = F.max_pool2d(x, kernel_size=pool_size)
        elif pool_type == "avg":
            x = F.avg_pool2d(x, kernel_size=pool_size)
        else:  # "avg+max"
            x1 = F.avg_pool2d(x, kernel_size=pool_size)
            x2 = F.max_pool2d(x, kernel_size=pool_size)
            x = x1 + x2

        return x
82
-
83
-
84
class ConvBlock5x5(nn.Module):
    """One 5x5 convolution followed by batch norm, ReLU and pooling.

    The convolution uses stride 1, padding 2 and no bias, mapping
    ``in_channels`` to ``out_channels``.
    """

    # Pooling modes accepted by ``forward``.
    _POOL_TYPES = ("max", "avg", "avg+max")

    def __init__(self, in_channels, out_channels):
        super(ConvBlock5x5, self).__init__()

        self.conv1 = nn.Conv2d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=(5, 5),
            stride=(1, 1),
            padding=(2, 2),
            bias=False,
        )

        self.bn1 = nn.BatchNorm2d(out_channels)

        self.init_weight()

    def init_weight(self):
        """Initialise the convolution and the batch-norm layer."""
        init_layer(self.conv1)
        init_bn(self.bn1)

    def forward(self, input, pool_size=(2, 2), pool_type="avg"):
        """Apply conv-bn-relu, then pool.

        Args:
            input: tensor fed to the convolution.
            pool_size: kernel size handed to the 2D pooling function(s).
            pool_type: ``"max"``, ``"avg"`` or ``"avg+max"`` (sum of both).

        Returns:
            The pooled feature map.

        Raises:
            ValueError: if ``pool_type`` is not one of the supported modes.
        """
        # Reject a bad mode up front, before the convolution is run and
        # before batch-norm running statistics can be touched.
        if pool_type not in self._POOL_TYPES:
            raise ValueError(
                "Incorrect argument! pool_type must be one of "
                f"{self._POOL_TYPES}, got {pool_type!r}"
            )

        x = input
        x = F.relu_(self.bn1(self.conv1(x)))
        if pool_type == "max":
            x = F.max_pool2d(x, kernel_size=pool_size)
        elif pool_type == "avg":
            x = F.avg_pool2d(x, kernel_size=pool_size)
        else:  # "avg+max"
            x1 = F.avg_pool2d(x, kernel_size=pool_size)
            x2 = F.max_pool2d(x, kernel_size=pool_size)
            x = x1 + x2

        return x
120
-
121
-
122
class AttBlock(nn.Module):
    """Attention pooling over the last axis using two 1x1 ``Conv1d`` heads.

    ``att`` produces attention logits and ``cla`` produces per-step values;
    ``forward`` returns the attention-weighted sum of the values.

    ``temperature`` is stored and ``bn_att`` is created, but neither is used
    by ``forward`` in this class.
    """

    def __init__(self, n_in, n_out, activation="linear", temperature=1.0):
        super(AttBlock, self).__init__()

        self.activation = activation
        self.temperature = temperature
        self.att = nn.Conv1d(
            in_channels=n_in,
            out_channels=n_out,
            kernel_size=1,
            stride=1,
            padding=0,
            bias=True,
        )
        self.cla = nn.Conv1d(
            in_channels=n_in,
            out_channels=n_out,
            kernel_size=1,
            stride=1,
            padding=0,
            bias=True,
        )

        self.bn_att = nn.BatchNorm1d(n_out)
        self.init_weights()

    def init_weights(self):
        """Initialise both convolution heads and the batch-norm layer."""
        init_layer(self.att)
        init_layer(self.cla)
        init_bn(self.bn_att)

    def forward(self, x):
        """Return ``(pooled, norm_att, cla)`` for the input ``x``.

        ``norm_att`` is a softmax over the last axis of the attention logits,
        which are clamped to [-10, 10] first. ``pooled`` is
        ``sum(norm_att * cla)`` over axis 2.
        """
        # x: (n_samples, n_in, n_time)
        norm_att = torch.softmax(torch.clamp(self.att(x), -10, 10), dim=-1)
        cla = self.nonlinear_transform(self.cla(x))
        x = torch.sum(norm_att * cla, dim=2)
        return x, norm_att, cla

    def nonlinear_transform(self, x):
        """Apply the configured activation to ``x``.

        Returns ``x`` unchanged for ``"linear"`` and ``torch.sigmoid(x)`` for
        ``"sigmoid"``.

        Raises:
            ValueError: for any other ``self.activation``. Previously this
                case silently returned ``None``.
        """
        if self.activation == "linear":
            return x
        elif self.activation == "sigmoid":
            return torch.sigmoid(x)
        raise ValueError(
            f"Unsupported activation {self.activation!r}; "
            "expected 'linear' or 'sigmoid'"
        )
165
-
166
-
167
class Cnn14(nn.Module):
    """PANN-style CNN audio encoder with six ``ConvBlock`` stages.

    Without fusion the model computes a log-mel spectrogram from a waveform.
    With ``enable_fusion`` it consumes a pre-computed ``mel_fusion`` tensor
    and, depending on ``fusion_type``, merges its extra channels into the
    first one with a 1D fusion model (before the conv stack), a 2D fusion
    model (after ``conv_block1``), or feeds all four channels to
    ``conv_block1`` (``"channel_map"``).
    """

    def __init__(
        self,
        sample_rate,
        window_size,
        hop_size,
        mel_bins,
        fmin,
        fmax,
        classes_num,
        enable_fusion=False,
        fusion_type="None",
    ):
        super(Cnn14, self).__init__()

        window = "hann"
        center = True
        pad_mode = "reflect"
        ref = 1.0
        amin = 1e-10
        top_db = None

        self.enable_fusion = enable_fusion
        self.fusion_type = fusion_type

        # Spectrogram extractor
        self.spectrogram_extractor = Spectrogram(
            n_fft=window_size,
            hop_length=hop_size,
            win_length=window_size,
            window=window,
            center=center,
            pad_mode=pad_mode,
            freeze_parameters=True,
        )

        # Logmel feature extractor
        self.logmel_extractor = LogmelFilterBank(
            sr=sample_rate,
            n_fft=window_size,
            n_mels=mel_bins,
            fmin=fmin,
            fmax=fmax,
            ref=ref,
            amin=amin,
            top_db=top_db,
            freeze_parameters=True,
        )

        # Spec augmenter
        self.spec_augmenter = SpecAugmentation(
            time_drop_width=64,
            time_stripes_num=2,
            freq_drop_width=8,
            freq_stripes_num=2,
        )

        # NOTE(review): bn0 is fixed at 64 channels and is applied after
        # swapping axes 1 and 3 in forward(), so this presumably requires
        # mel_bins == 64 -- confirm.
        self.bn0 = nn.BatchNorm2d(64)

        if (self.enable_fusion) and (self.fusion_type == "channel_map"):
            self.conv_block1 = ConvBlock(in_channels=4, out_channels=64)
        else:
            self.conv_block1 = ConvBlock(in_channels=1, out_channels=64)
        self.conv_block2 = ConvBlock(in_channels=64, out_channels=128)
        self.conv_block3 = ConvBlock(in_channels=128, out_channels=256)
        self.conv_block4 = ConvBlock(in_channels=256, out_channels=512)
        self.conv_block5 = ConvBlock(in_channels=512, out_channels=1024)
        self.conv_block6 = ConvBlock(in_channels=1024, out_channels=2048)

        self.fc1 = nn.Linear(2048, 2048, bias=True)
        self.fc_audioset = nn.Linear(2048, classes_num, bias=True)

        if (self.enable_fusion) and (
            self.fusion_type in ["daf_1d", "aff_1d", "iaff_1d"]
        ):
            self.mel_conv1d = nn.Sequential(
                nn.Conv1d(64, 64, kernel_size=5, stride=3, padding=2),
                nn.BatchNorm1d(64),  # No Relu
            )
            if self.fusion_type == "daf_1d":
                self.fusion_model = DAF()
            elif self.fusion_type == "aff_1d":
                self.fusion_model = AFF(channels=64, type="1D")
            elif self.fusion_type == "iaff_1d":
                self.fusion_model = iAFF(channels=64, type="1D")

        if (self.enable_fusion) and (
            self.fusion_type in ["daf_2d", "aff_2d", "iaff_2d"]
        ):
            self.mel_conv2d = nn.Sequential(
                nn.Conv2d(1, 64, kernel_size=(5, 5), stride=(6, 2), padding=(2, 2)),
                nn.BatchNorm2d(64),
                nn.ReLU(inplace=True),
            )

            if self.fusion_type == "daf_2d":
                self.fusion_model = DAF()
            elif self.fusion_type == "aff_2d":
                self.fusion_model = AFF(channels=64, type="2D")
            elif self.fusion_type == "iaff_2d":
                self.fusion_model = iAFF(channels=64, type="2D")
        self.init_weight()

    def init_weight(self):
        """Initialise the input batch norm and the two linear layers."""
        init_bn(self.bn0)
        init_layer(self.fc1)
        init_layer(self.fc_audioset)

    def forward(self, input, mixup_lambda=None, device=None):
        """Encode a batch of audio.

        Args:
            input: mapping of tensors. Without fusion only ``"waveform"`` is
                read. With fusion ``"longer"`` and ``"mel_fusion"`` are read,
                and ``"longer"`` may be modified in place (see below).
            mixup_lambda: if not None and the module is in training mode,
                passed to ``do_mixup`` together with the spectrogram.
            device: target device for the input tensors; ``None`` leaves
                them where they are.

        Returns:
            dict with ``"clipwise_output"``, ``"embedding"`` and
            ``"fine_grained_embedding"``.
        """

        if self.enable_fusion and input["longer"].sum() == 0:
            # if no audio is longer than 10s, then randomly select one audio to be longer
            # NOTE: this writes into the caller's ``input["longer"]`` tensor.
            input["longer"][torch.randint(0, input["longer"].shape[0], (1,))] = True

        if not self.enable_fusion:
            x = self.spectrogram_extractor(
                input["waveform"].to(device=device, non_blocking=True)
            )  # (batch_size, 1, time_steps, freq_bins)
            x = self.logmel_extractor(x)  # (batch_size, 1, time_steps, mel_bins)

            x = x.transpose(1, 3)
            x = self.bn0(x)
            x = x.transpose(1, 3)
        else:
            longer_list = input["longer"].to(device=device, non_blocking=True)
            x = input["mel_fusion"].to(device=device, non_blocking=True)
            longer_list_idx = torch.where(longer_list)[0]
            x = x.transpose(1, 3)
            x = self.bn0(x)
            x = x.transpose(1, 3)
            if self.fusion_type in ["daf_1d", "aff_1d", "iaff_1d"]:
                new_x = x[:, 0:1, :, :].clone().contiguous()
                # local processing
                if len(longer_list_idx) > 0:
                    fusion_x_local = x[longer_list_idx, 1:, :, :].clone().contiguous()
                    FB, FC, FT, FF = fusion_x_local.size()
                    fusion_x_local = fusion_x_local.view(FB * FC, FT, FF)
                    fusion_x_local = torch.permute(
                        fusion_x_local, (0, 2, 1)
                    ).contiguous()
                    fusion_x_local = self.mel_conv1d(fusion_x_local)
                    fusion_x_local = fusion_x_local.view(
                        FB, FC, FF, fusion_x_local.size(-1)
                    )
                    fusion_x_local = (
                        torch.permute(fusion_x_local, (0, 2, 1, 3))
                        .contiguous()
                        .flatten(2)
                    )
                    if fusion_x_local.size(-1) < FT:
                        # Pad on the features' own device. With device=None the
                        # padding would otherwise be created on the default
                        # device and torch.cat would fail when the features
                        # live elsewhere (the 2D path below already does this).
                        fusion_x_local = torch.cat(
                            [
                                fusion_x_local,
                                torch.zeros(
                                    (FB, FF, FT - fusion_x_local.size(-1)),
                                    device=fusion_x_local.device,
                                ),
                            ],
                            dim=-1,
                        )
                    else:
                        fusion_x_local = fusion_x_local[:, :, :FT]
                    # 1D fusion
                    new_x = new_x.squeeze(1).permute((0, 2, 1)).contiguous()
                    new_x[longer_list_idx] = self.fusion_model(
                        new_x[longer_list_idx], fusion_x_local
                    )
                    x = new_x.permute((0, 2, 1)).contiguous()[:, None, :, :]
                else:
                    x = new_x
            elif self.fusion_type in ["daf_2d", "aff_2d", "iaff_2d", "channel_map"]:
                x = x  # no change

        if self.training:
            x = self.spec_augmenter(x)
        # Mixup on spectrogram
        if self.training and mixup_lambda is not None:
            x = do_mixup(x, mixup_lambda)
        if (self.enable_fusion) and (
            self.fusion_type in ["daf_2d", "aff_2d", "iaff_2d"]
        ):
            global_x = x[:, 0:1, :, :]

            # global processing
            B, C, H, W = global_x.shape
            global_x = self.conv_block1(global_x, pool_size=(2, 2), pool_type="avg")
            if len(longer_list_idx) > 0:
                local_x = x[longer_list_idx, 1:, :, :].contiguous()
                TH = global_x.size(-2)
                # local processing
                B, C, H, W = local_x.shape
                local_x = local_x.view(B * C, 1, H, W)
                local_x = self.mel_conv2d(local_x)
                local_x = local_x.view(
                    B, C, local_x.size(1), local_x.size(2), local_x.size(3)
                )
                local_x = local_x.permute((0, 2, 1, 3, 4)).contiguous().flatten(2, 3)
                TB, TC, _, TW = local_x.size()
                # Pad or crop axis -2 of the local features to the global length.
                if local_x.size(-2) < TH:
                    local_x = torch.cat(
                        [
                            local_x,
                            torch.zeros(
                                (TB, TC, TH - local_x.size(-2), TW),
                                device=global_x.device,
                            ),
                        ],
                        dim=-2,
                    )
                else:
                    local_x = local_x[:, :, :TH, :]

                global_x[longer_list_idx] = self.fusion_model(
                    global_x[longer_list_idx], local_x
                )
            x = global_x
        else:
            x = self.conv_block1(x, pool_size=(2, 2), pool_type="avg")

        x = F.dropout(x, p=0.2, training=self.training)
        x = self.conv_block2(x, pool_size=(2, 2), pool_type="avg")
        x = F.dropout(x, p=0.2, training=self.training)
        x = self.conv_block3(x, pool_size=(2, 2), pool_type="avg")
        x = F.dropout(x, p=0.2, training=self.training)
        x = self.conv_block4(x, pool_size=(2, 2), pool_type="avg")
        x = F.dropout(x, p=0.2, training=self.training)
        x = self.conv_block5(x, pool_size=(2, 2), pool_type="avg")
        x = F.dropout(x, p=0.2, training=self.training)
        x = self.conv_block6(x, pool_size=(1, 1), pool_type="avg")
        x = F.dropout(x, p=0.2, training=self.training)
        x = torch.mean(x, dim=3)

        # Fine-grained branch: smooth along the last axis, project with fc1,
        # then resample with interpolate(..., 32).
        latent_x1 = F.max_pool1d(x, kernel_size=3, stride=1, padding=1)
        latent_x2 = F.avg_pool1d(x, kernel_size=3, stride=1, padding=1)
        latent_x = latent_x1 + latent_x2
        latent_x = latent_x.transpose(1, 2)
        latent_x = F.relu_(self.fc1(latent_x))
        latent_output = interpolate(latent_x, 32)

        # Clip-level branch: max + mean over axis 2, then fc1.
        (x1, _) = torch.max(x, dim=2)
        x2 = torch.mean(x, dim=2)
        x = x1 + x2
        x = F.dropout(x, p=0.5, training=self.training)
        x = F.relu_(self.fc1(x))
        embedding = F.dropout(x, p=0.5, training=self.training)
        clipwise_output = torch.sigmoid(self.fc_audioset(x))

        output_dict = {
            "clipwise_output": clipwise_output,
            "embedding": embedding,
            "fine_grained_embedding": latent_output,
        }
        return output_dict
422
-
423
-
424
class Cnn6(nn.Module):
    """PANN-style CNN audio encoder with four ``ConvBlock5x5`` stages.

    ``enable_fusion`` and ``fusion_type`` are stored but not otherwise used
    by this class.
    """

    def __init__(
        self,
        sample_rate,
        window_size,
        hop_size,
        mel_bins,
        fmin,
        fmax,
        classes_num,
        enable_fusion=False,
        fusion_type="None",
    ):
        super(Cnn6, self).__init__()

        window = "hann"
        center = True
        pad_mode = "reflect"
        ref = 1.0
        amin = 1e-10
        top_db = None

        self.enable_fusion = enable_fusion
        self.fusion_type = fusion_type

        # Spectrogram extractor
        self.spectrogram_extractor = Spectrogram(
            n_fft=window_size,
            hop_length=hop_size,
            win_length=window_size,
            window=window,
            center=center,
            pad_mode=pad_mode,
            freeze_parameters=True,
        )

        # Logmel feature extractor
        self.logmel_extractor = LogmelFilterBank(
            sr=sample_rate,
            n_fft=window_size,
            n_mels=mel_bins,
            fmin=fmin,
            fmax=fmax,
            ref=ref,
            amin=amin,
            top_db=top_db,
            freeze_parameters=True,
        )

        # Spec augmenter
        self.spec_augmenter = SpecAugmentation(
            time_drop_width=64,
            time_stripes_num=2,
            freq_drop_width=8,
            freq_stripes_num=2,
        )

        # NOTE(review): bn0 is fixed at 64 channels and is applied after
        # swapping axes 1 and 3 in forward(), so this presumably requires
        # mel_bins == 64 -- confirm.
        self.bn0 = nn.BatchNorm2d(64)

        self.conv_block1 = ConvBlock5x5(in_channels=1, out_channels=64)
        self.conv_block2 = ConvBlock5x5(in_channels=64, out_channels=128)
        self.conv_block3 = ConvBlock5x5(in_channels=128, out_channels=256)
        self.conv_block4 = ConvBlock5x5(in_channels=256, out_channels=512)

        self.fc1 = nn.Linear(512, 512, bias=True)
        self.fc_audioset = nn.Linear(512, classes_num, bias=True)

        self.init_weight()

    def init_weight(self):
        """Initialise the input batch norm and the two linear layers."""
        init_bn(self.bn0)
        init_layer(self.fc1)
        init_layer(self.fc_audioset)

    def forward(self, input, mixup_lambda=None, device=None):
        """Encode a batch of audio.

        Args:
            input: either a waveform tensor, (batch_size, data_length), or a
                dict holding it under ``"waveform"`` (the form that
                ``Cnn14.forward`` in this file reads). A raw tensor is used
                as-is; a dict's waveform is moved to ``device`` first.
            mixup_lambda: if not None and the module is in training mode,
                passed to ``do_mixup`` together with the spectrogram.
            device: target device for a dict-supplied waveform; ``None``
                leaves it where it is.

        Returns:
            dict with ``"clipwise_output"``, ``"embedding"`` and
            ``"fine_grained_embedding"``.
        """

        # Accept the same dict-style input as Cnn14 so both encoders can be
        # called the same way; tensor input behaves exactly as before.
        if isinstance(input, dict):
            input = input["waveform"].to(device=device, non_blocking=True)

        x = self.spectrogram_extractor(input)  # (batch_size, 1, time_steps, freq_bins)
        x = self.logmel_extractor(x)  # (batch_size, 1, time_steps, mel_bins)

        x = x.transpose(1, 3)
        x = self.bn0(x)
        x = x.transpose(1, 3)

        if self.training:
            x = self.spec_augmenter(x)

        # Mixup on spectrogram
        if self.training and mixup_lambda is not None:
            x = do_mixup(x, mixup_lambda)

        x = self.conv_block1(x, pool_size=(2, 2), pool_type="avg")
        x = F.dropout(x, p=0.2, training=self.training)
        x = self.conv_block2(x, pool_size=(2, 2), pool_type="avg")
        x = F.dropout(x, p=0.2, training=self.training)
        x = self.conv_block3(x, pool_size=(2, 2), pool_type="avg")
        x = F.dropout(x, p=0.2, training=self.training)
        x = self.conv_block4(x, pool_size=(2, 2), pool_type="avg")
        x = F.dropout(x, p=0.2, training=self.training)
        x = torch.mean(x, dim=3)

        # Fine-grained branch: smooth along the last axis, project with fc1,
        # then resample with interpolate(..., 16).
        latent_x1 = F.max_pool1d(x, kernel_size=3, stride=1, padding=1)
        latent_x2 = F.avg_pool1d(x, kernel_size=3, stride=1, padding=1)
        latent_x = latent_x1 + latent_x2
        latent_x = latent_x.transpose(1, 2)
        latent_x = F.relu_(self.fc1(latent_x))
        latent_output = interpolate(latent_x, 16)

        # Clip-level branch: max + mean over axis 2, then fc1.
        (x1, _) = torch.max(x, dim=2)
        x2 = torch.mean(x, dim=2)
        x = x1 + x2
        x = F.dropout(x, p=0.5, training=self.training)
        x = F.relu_(self.fc1(x))
        embedding = F.dropout(x, p=0.5, training=self.training)
        clipwise_output = torch.sigmoid(self.fc_audioset(x))

        output_dict = {
            "clipwise_output": clipwise_output,
            "embedding": embedding,
            "fine_grained_embedding": latent_output,
        }

        return output_dict
548
-
549
-
550
class Cnn10(nn.Module):
    """Convolutional audio encoder: log-mel front end, five ConvBlock stages, two linear heads."""

    def __init__(
        self,
        sample_rate,
        window_size,
        hop_size,
        mel_bins,
        fmin,
        fmax,
        classes_num,
        enable_fusion=False,
        fusion_type="None",
    ):
        """Build the feature extractors, convolutional stages and output layers.

        ``enable_fusion`` and ``fusion_type`` are only stored on the instance;
        nothing else in this class reads them.
        """
        super().__init__()

        self.enable_fusion = enable_fusion
        self.fusion_type = fusion_type

        # Spectrogram extractor
        self.spectrogram_extractor = Spectrogram(
            n_fft=window_size,
            hop_length=hop_size,
            win_length=window_size,
            window="hann",
            center=True,
            pad_mode="reflect",
            freeze_parameters=True,
        )

        # Logmel feature extractor
        self.logmel_extractor = LogmelFilterBank(
            sr=sample_rate,
            n_fft=window_size,
            n_mels=mel_bins,
            fmin=fmin,
            fmax=fmax,
            ref=1.0,
            amin=1e-10,
            top_db=None,
            freeze_parameters=True,
        )

        # Spec augmenter
        self.spec_augmenter = SpecAugmentation(
            time_drop_width=64,
            time_stripes_num=2,
            freq_drop_width=8,
            freq_stripes_num=2,
        )

        # NOTE(review): bn0 is fixed at 64 features and is applied after swapping
        # axes 1 and 3 in forward(), so mel_bins presumably has to be 64 - confirm.
        self.bn0 = nn.BatchNorm2d(64)

        # conv_block1 .. conv_block5, each widening the channel count.
        widths = (1, 64, 128, 256, 512, 1024)
        for index, (n_in, n_out) in enumerate(zip(widths, widths[1:]), start=1):
            setattr(
                self,
                f"conv_block{index}",
                ConvBlock(in_channels=n_in, out_channels=n_out),
            )

        self.fc1 = nn.Linear(1024, 1024, bias=True)
        self.fc_audioset = nn.Linear(1024, classes_num, bias=True)

        self.init_weight()

    def init_weight(self):
        """Initialise the input batch-norm layer and the two linear layers."""
        init_bn(self.bn0)
        for linear in (self.fc1, self.fc_audioset):
            init_layer(linear)

    def forward(self, input, mixup_lambda=None, device=None):
        """Encode raw audio into clip-level and fine-grained embeddings.

        Args:
            input: Waveform batch of shape (batch_size, data_length).
            mixup_lambda: Optional mixup coefficients; only applied in training mode.
            device: Accepted but not used by this method.

        Returns:
            dict with keys ``"clipwise_output"``, ``"embedding"`` and
            ``"fine_grained_embedding"``.
        """
        training = self.training

        spectrogram = self.spectrogram_extractor(input)  # (batch_size, 1, time_steps, freq_bins)
        feats = self.logmel_extractor(spectrogram)  # (batch_size, 1, time_steps, mel_bins)

        # bn0 is applied with axes 1 and 3 swapped, then the layout is restored.
        feats = self.bn0(feats.transpose(1, 3)).transpose(1, 3)

        if training:
            feats = self.spec_augmenter(feats)
            # Mixup on spectrogram
            if mixup_lambda is not None:
                feats = do_mixup(feats, mixup_lambda)

        conv_stages = (
            self.conv_block1,
            self.conv_block2,
            self.conv_block3,
            self.conv_block4,
            self.conv_block5,
        )
        for stage in conv_stages:
            feats = stage(feats, pool_size=(2, 2), pool_type="avg")
            feats = F.dropout(feats, p=0.2, training=training)
        feats = torch.mean(feats, dim=3)

        # Fine-grained branch: sum of max- and avg-pooled features, projected by fc1.
        pooled = F.max_pool1d(
            feats, kernel_size=3, stride=1, padding=1
        ) + F.avg_pool1d(feats, kernel_size=3, stride=1, padding=1)
        projected = F.relu_(self.fc1(pooled.transpose(1, 2)))
        latent_output = interpolate(projected, 32)

        # Clip-level branch: max plus mean over axis 2, then the same fc1 layer.
        peak, _ = torch.max(feats, dim=2)
        clip = peak + torch.mean(feats, dim=2)
        clip = F.dropout(clip, p=0.5, training=training)
        clip = F.relu_(self.fc1(clip))
        embedding = F.dropout(clip, p=0.5, training=training)
        # The classifier reads the pre-dropout activations, not ``embedding``.
        clipwise_output = torch.sigmoid(self.fc_audioset(clip))

        return {
            "clipwise_output": clipwise_output,
            "embedding": embedding,
            "fine_grained_embedding": latent_output,
        }
677
-
678
-
679
def create_pann_model(audio_cfg, enable_fusion=False, fusion_type="None"):
    """Instantiate the audio model named by ``audio_cfg.model_name``.

    Args:
        audio_cfg: Object exposing ``model_name``, ``sample_rate``,
            ``window_size``, ``hop_size``, ``mel_bins``, ``fmin``, ``fmax``
            and ``class_num`` attributes.
        enable_fusion: Forwarded unchanged to the model constructor.
        fusion_type: Forwarded unchanged to the model constructor.

    Returns:
        The constructed model instance.

    Raises:
        RuntimeError: If the model name cannot be resolved or construction
            fails. The underlying exception is attached as ``__cause__``.
    """
    try:
        # NOTE(security): eval() executes ``model_name`` as arbitrary Python.
        # Only pass names that come from trusted configuration files.
        ModelProto = eval(audio_cfg.model_name)
        return ModelProto(
            sample_rate=audio_cfg.sample_rate,
            window_size=audio_cfg.window_size,
            hop_size=audio_cfg.hop_size,
            mel_bins=audio_cfg.mel_bins,
            fmin=audio_cfg.fmin,
            fmax=audio_cfg.fmax,
            classes_num=audio_cfg.class_num,
            enable_fusion=enable_fusion,
            fusion_type=fusion_type,
        )
    except Exception as err:
        # ``Exception`` (not a bare except) lets KeyboardInterrupt/SystemExit
        # propagate; chaining keeps the real failure visible in the traceback.
        raise RuntimeError(
            f"Import Model for {audio_cfg.model_name} not found, or the audio cfg parameters are not enough."
        ) from err
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
audioldm2/clap/open_clip/pretrained.py DELETED
@@ -1,167 +0,0 @@
1
- import hashlib
2
- import os
3
- import urllib
4
- import warnings
5
-
6
- from tqdm import tqdm
7
-
8
- _RN50 = dict(
9
- openai="https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt",
10
- yfcc15m="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn50-quickgelu-yfcc15m-455df137.pt",
11
- cc12m="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn50-quickgelu-cc12m-f000538c.pt",
12
- )
13
-
14
- _RN50_quickgelu = dict(
15
- openai="https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt",
16
- yfcc15m="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn50-quickgelu-yfcc15m-455df137.pt",
17
- cc12m="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn50-quickgelu-cc12m-f000538c.pt",
18
- )
19
-
20
- _RN101 = dict(
21
- openai="https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt",
22
- yfcc15m="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn101-quickgelu-yfcc15m-3e04b30e.pt",
23
- )
24
-
25
- _RN101_quickgelu = dict(
26
- openai="https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt",
27
- yfcc15m="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/rn101-quickgelu-yfcc15m-3e04b30e.pt",
28
- )
29
-
30
- _RN50x4 = dict(
31
- openai="https://openaipublic.azureedge.net/clip/models/7e526bd135e493cef0776de27d5f42653e6b4c8bf9e0f653bb11773263205fdd/RN50x4.pt",
32
- )
33
-
34
- _RN50x16 = dict(
35
- openai="https://openaipublic.azureedge.net/clip/models/52378b407f34354e150460fe41077663dd5b39c54cd0bfd2b27167a4a06ec9aa/RN50x16.pt",
36
- )
37
-
38
- _RN50x64 = dict(
39
- openai="https://openaipublic.azureedge.net/clip/models/be1cfb55d75a9666199fb2206c106743da0f6468c9d327f3e0d0a543a9919d9c/RN50x64.pt",
40
- )
41
-
42
- _VITB32 = dict(
43
- openai="https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt",
44
- laion400m_e31="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_e31-d867053b.pt",
45
- laion400m_e32="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_e32-46683a32.pt",
46
- laion400m_avg="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_avg-8a00ab3c.pt",
47
- )
48
-
49
- _VITB32_quickgelu = dict(
50
- openai="https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt",
51
- laion400m_e31="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_e31-d867053b.pt",
52
- laion400m_e32="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_e32-46683a32.pt",
53
- laion400m_avg="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_avg-8a00ab3c.pt",
54
- )
55
-
56
- _VITB16 = dict(
57
- openai="https://openaipublic.azureedge.net/clip/models/5806e77cd80f8b59890b7e101eabd078d9fb84e6937f9e85e4ecb61988df416f/ViT-B-16.pt",
58
- )
59
-
60
- _VITL14 = dict(
61
- openai="https://openaipublic.azureedge.net/clip/models/b8cca3fd41ae0c99ba7e8951adf17d267cdb84cd88be6f7c2e0eca1737a03836/ViT-L-14.pt",
62
- )
63
-
64
- _PRETRAINED = {
65
- "RN50": _RN50,
66
- "RN50-quickgelu": _RN50_quickgelu,
67
- "RN101": _RN101,
68
- "RN101-quickgelu": _RN101_quickgelu,
69
- "RN50x4": _RN50x4,
70
- "RN50x16": _RN50x16,
71
- "ViT-B-32": _VITB32,
72
- "ViT-B-32-quickgelu": _VITB32_quickgelu,
73
- "ViT-B-16": _VITB16,
74
- "ViT-L-14": _VITL14,
75
- }
76
-
77
-
78
def list_pretrained(as_str: bool = False):
    """List every registered (architecture, pretrain tag) combination.

    Each entry is a ``(model_name, pretrain_tag)`` tuple, or a
    ``'name:tag'`` string when ``as_str`` is true.
    """
    entries = []
    for arch, tags in _PRETRAINED.items():
        for tag in tags:
            entries.append(f"{arch}:{tag}" if as_str else (arch, tag))
    return entries
87
-
88
-
89
def list_pretrained_tag_models(tag: str):
    """Return the architectures that have a checkpoint for ``tag``."""
    return [arch for arch, tags in _PRETRAINED.items() if tag in tags]
96
-
97
-
98
def list_pretrained_model_tags(model: str):
    """Return the pretrain tags registered for ``model`` (empty list if unknown)."""
    return list(_PRETRAINED.get(model, ()))
104
-
105
-
106
def get_pretrained_url(model: str, tag: str):
    """Look up the checkpoint URL for ``model``/``tag``; ``""`` when either is unknown."""
    return _PRETRAINED.get(model, {}).get(tag, "")
113
-
114
-
115
- def download_pretrained(url: str, root: str = os.path.expanduser("~/.cache/clip")):
116
- os.makedirs(root, exist_ok=True)
117
- filename = os.path.basename(url)
118
-
119
- if "openaipublic" in url:
120
- expected_sha256 = url.split("/")[-2]
121
- else:
122
- expected_sha256 = ""
123
-
124
- download_target = os.path.join(root, filename)
125
-
126
- if os.path.exists(download_target) and not os.path.isfile(download_target):
127
- raise RuntimeError(f"{download_target} exists and is not a regular file")
128
-
129
- if os.path.isfile(download_target):
130
- if expected_sha256:
131
- if (
132
- hashlib.sha256(open(download_target, "rb").read()).hexdigest()
133
- == expected_sha256
134
- ):
135
- return download_target
136
- else:
137
- warnings.warn(
138
- f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file"
139
- )
140
- else:
141
- return download_target
142
-
143
- with urllib.request.urlopen(url) as source, open(download_target, "wb") as output:
144
- with tqdm(
145
- total=int(source.info().get("Content-Length")),
146
- ncols=80,
147
- unit="iB",
148
- unit_scale=True,
149
- ) as loop:
150
- while True:
151
- buffer = source.read(8192)
152
- if not buffer:
153
- break
154
-
155
- output.write(buffer)
156
- loop.update(len(buffer))
157
-
158
- if (
159
- expected_sha256
160
- and hashlib.sha256(open(download_target, "rb").read()).hexdigest()
161
- != expected_sha256
162
- ):
163
- raise RuntimeError(
164
- f"Model has been downloaded but the SHA256 checksum does not not match"
165
- )
166
-
167
- return download_target
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
audioldm2/clap/open_clip/timm_model.py DELETED
@@ -1,112 +0,0 @@
1
- """ timm model adapter
2
-
3
- Wraps timm (https://github.com/rwightman/pytorch-image-models) models for use as a vision tower in CLIP model.
4
- """
5
- from collections import OrderedDict
6
-
7
- import torch.nn as nn
8
-
9
- try:
10
- import timm
11
- from timm.models.layers import Mlp, to_2tuple
12
- from timm.models.layers.attention_pool2d import RotAttentionPool2d
13
- from timm.models.layers.attention_pool2d import (
14
- AttentionPool2d as AbsAttentionPool2d,
15
- )
16
- except ImportError:
17
- timm = None
18
-
19
- from .utils import freeze_batch_norm_2d
20
-
21
-
22
class TimmModel(nn.Module):
    """Adapter exposing a timm backbone (``trunk``) followed by a projection ``head``.

    # FIXME this adapter is a work in progress, may change in ways that break weight compat
    """

    def __init__(
        self,
        model_name,
        embed_dim,
        image_size=224,
        pool="avg",
        proj="linear",
        drop=0.0,
        pretrained=False,
    ):
        """Create the timm trunk and assemble the pooling/projection head.

        ``pool`` may be ``"abs_attn"`` or ``"rot_attn"`` for attention pooling;
        any other value is passed to the trunk as its global pool setting.
        ``proj`` selects ``"linear"`` or ``"mlp"`` projection.
        """
        super().__init__()
        if timm is None:
            raise RuntimeError("Please `pip install timm` to use timm models.")

        self.image_size = to_2tuple(image_size)
        self.trunk = timm.create_model(model_name, pretrained=pretrained)
        feat_size = self.trunk.default_cfg.get("pool_size", None)
        feature_ndim = 2 if feat_size else 1
        uses_attn_pool = pool in ("abs_attn", "rot_attn")

        if uses_attn_pool:
            assert feature_ndim == 2
            # if attn pooling used, remove both classifier and default pool
            self.trunk.reset_classifier(0, global_pool="")
        elif pool:
            # reset global pool because a pool config was given
            self.trunk.reset_classifier(0, global_pool=pool)
        else:
            # no pool config: leave the network default in place
            self.trunk.reset_classifier(0)
        width = self.trunk.num_features

        layers = OrderedDict()
        if pool == "abs_attn":
            layers["pool"] = AbsAttentionPool2d(
                width, feat_size=feat_size, out_features=embed_dim
            )
            width = embed_dim
        elif pool == "rot_attn":
            layers["pool"] = RotAttentionPool2d(width, out_features=embed_dim)
            width = embed_dim
        else:
            assert proj, "projection layer needed if non-attention pooling is used."

        # NOTE attention pool ends with a projection layer, so proj should usually be set to '' if such pooling is used
        if proj == "linear":
            layers["drop"] = nn.Dropout(drop)
            layers["proj"] = nn.Linear(width, embed_dim)
        elif proj == "mlp":
            layers["mlp"] = Mlp(width, 2 * embed_dim, embed_dim, drop=drop)

        self.head = nn.Sequential(layers)

    def lock(self, unlocked_groups=0, freeze_bn_stats=False):
        """Freeze trunk parameters.

        Args:
            unlocked_groups (int): leave last n layer groups unlocked (default: 0)
            freeze_bn_stats (bool): also pass the frozen part of the trunk to
                ``freeze_batch_norm_2d``.
        """
        if not unlocked_groups:
            # lock full model
            for weight in self.trunk.parameters():
                weight.requires_grad = False
            if freeze_bn_stats:
                freeze_batch_norm_2d(self.trunk)
            return

        # NOTE: partial freeze requires latest timm (master) branch and is subject to change
        try:
            # FIXME import here until API stable and in an official release
            from timm.models.helpers import group_parameters, group_modules
        except ImportError:
            raise RuntimeError(
                "Please install latest timm `pip install git+https://github.com/rwightman/pytorch-image-models`"
            )
        matcher = self.trunk.group_matcher()
        grouped_params = group_parameters(self.trunk, matcher)
        last_frozen = max(grouped_params) - unlocked_groups
        for group_id in range(last_frozen + 1):
            for param_name in grouped_params[group_id]:
                self.trunk.get_parameter(param_name).requires_grad = False
        if freeze_bn_stats:
            grouped_modules = group_modules(self.trunk, matcher, reverse=True)
            frozen_names = {
                name
                for name, group_id in grouped_modules.items()
                if group_id <= last_frozen
            }
            freeze_batch_norm_2d(self.trunk, frozen_names)

    def forward(self, x):
        """Run ``x`` through the trunk and then the head."""
        return self.head(self.trunk(x))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
audioldm2/clap/open_clip/tokenizer.py DELETED
@@ -1,197 +0,0 @@
1
- """ CLIP tokenizer
2
-
3
- Copied from https://github.com/openai/CLIP. Originally MIT License, Copyright (c) 2021 OpenAI.
4
- """
5
- import gzip
6
- import html
7
- import os
8
- from functools import lru_cache
9
- from typing import Union, List
10
-
11
- import ftfy
12
- import regex as re
13
- import torch
14
-
15
-
16
@lru_cache()
def default_bpe():
    """Return the path of ``bpe_simple_vocab_16e6.txt.gz`` next to this module."""
    module_dir = os.path.dirname(os.path.abspath(__file__))
    return os.path.join(module_dir, "bpe_simple_vocab_16e6.txt.gz")
21
-
22
-
23
@lru_cache()
def bytes_to_unicode():
    """Build the byte -> unicode-character table used by the reversible BPE.

    Bytes in the ranges ``!``..``~``, ``¡``..``¬`` and ``®``..``ÿ`` map to the
    character with the same code point. Every other byte value is mapped, in
    increasing order, to the code points 256, 257, ... so that no byte maps to
    a whitespace/control character the BPE code cannot handle.
    """
    identity_bytes = [
        *range(ord("!"), ord("~") + 1),
        *range(ord("¡"), ord("¬") + 1),
        *range(ord("®"), ord("ÿ") + 1),
    ]
    table = {byte: chr(byte) for byte in identity_bytes}
    shifted = 0
    for byte in range(2**8):
        if byte in table:
            continue
        table[byte] = chr(2**8 + shifted)
        shifted += 1
    return table
48
-
49
-
50
def get_pairs(word):
    """Return the set of adjacent symbol pairs in a word.

    Word is represented as a tuple of symbols (symbols being variable-length
    strings). An empty or single-symbol word has no pairs and yields an empty
    set.
    """
    # zip() stops at the shorter input, so empty and one-symbol words
    # naturally produce no pairs instead of raising IndexError.
    return set(zip(word, word[1:]))
60
-
61
-
62
def basic_clean(text):
    """Run ``ftfy.fix_text``, unescape HTML entities twice, and strip the ends."""
    cleaned = ftfy.fix_text(text)
    # Two passes so that doubly-escaped entities are unescaped as well.
    for _ in range(2):
        cleaned = html.unescape(cleaned)
    return cleaned.strip()
66
-
67
-
68
def whitespace_clean(text):
    """Collapse every whitespace run to a single space and strip both ends."""
    return re.sub(r"\s+", " ", text).strip()
72
-
73
-
74
class SimpleTokenizer(object):
    """Byte-level BPE tokenizer driven by a gzip-compressed merges file."""

    def __init__(self, bpe_path: str = default_bpe(), special_tokens=None):
        """Load the merge table and build the encoder/decoder vocabularies.

        Args:
            bpe_path: Path to the gzip-compressed merges file.
            special_tokens: Optional list of extra special tokens, appended
                after ``<start_of_text>`` and ``<end_of_text>``.
        """
        self.byte_encoder = bytes_to_unicode()
        self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
        # Context manager so the gzip handle is closed deterministically.
        with gzip.open(bpe_path) as bpe_file:
            merges = bpe_file.read().decode("utf-8").split("\n")
        # The first line of the file is skipped; a fixed number of merges is kept.
        merges = merges[1 : 49152 - 256 - 2 + 1]
        merges = [tuple(merge.split()) for merge in merges]
        # Vocabulary order: byte symbols, their "</w>" variants, merges, specials.
        vocab = list(bytes_to_unicode().values())
        vocab = vocab + [v + "</w>" for v in vocab]
        for merge in merges:
            vocab.append("".join(merge))
        if not special_tokens:
            special_tokens = ["<start_of_text>", "<end_of_text>"]
        else:
            special_tokens = ["<start_of_text>", "<end_of_text>"] + special_tokens
        vocab.extend(special_tokens)
        self.encoder = dict(zip(vocab, range(len(vocab))))
        self.decoder = {v: k for k, v in self.encoder.items()}
        self.bpe_ranks = dict(zip(merges, range(len(merges))))
        # Special tokens are pre-seeded in the cache so bpe() returns them unsplit.
        self.cache = {t: t for t in special_tokens}
        special = "|".join(special_tokens)
        self.pat = re.compile(
            special + r"""|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""",
            re.IGNORECASE,
        )

        self.vocab_size = len(self.encoder)
        self.all_special_ids = [self.encoder[t] for t in special_tokens]

    def bpe(self, token):
        """Apply the learned merges to ``token``.

        Returns the resulting symbols joined by single spaces; the last symbol
        carries the ``</w>`` end-of-word marker. Results are cached per token.
        """
        if token in self.cache:
            return self.cache[token]
        word = tuple(token[:-1]) + (token[-1] + "</w>",)
        pairs = get_pairs(word)

        if not pairs:
            return token + "</w>"

        while True:
            # Pick the adjacent pair with the lowest merge rank.
            bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float("inf")))
            if bigram not in self.bpe_ranks:
                break
            first, second = bigram
            new_word = []
            i = 0
            while i < len(word):
                try:
                    j = word.index(first, i)
                except ValueError:
                    # ``first`` does not occur again: copy the remainder.
                    new_word.extend(word[i:])
                    break
                new_word.extend(word[i:j])
                i = j

                if word[i] == first and i < len(word) - 1 and word[i + 1] == second:
                    new_word.append(first + second)
                    i += 2
                else:
                    new_word.append(word[i])
                    i += 1
            new_word = tuple(new_word)
            word = new_word
            if len(word) == 1:
                break
            else:
                pairs = get_pairs(word)
        word = " ".join(word)
        self.cache[token] = word
        return word

    def encode(self, text):
        """Clean and lower-case ``text`` and return its list of BPE token ids."""
        bpe_tokens = []
        text = whitespace_clean(basic_clean(text)).lower()
        for token in re.findall(self.pat, text):
            # Map the token's UTF-8 bytes to their printable stand-ins first.
            token = "".join(self.byte_encoder[b] for b in token.encode("utf-8"))
            bpe_tokens.extend(
                self.encoder[bpe_token] for bpe_token in self.bpe(token).split(" ")
            )
        return bpe_tokens

    def decode(self, tokens):
        """Turn token ids back into text; ``</w>`` markers become spaces."""
        text = "".join([self.decoder[token] for token in tokens])
        text = (
            bytearray([self.byte_decoder[c] for c in text])
            .decode("utf-8", errors="replace")
            .replace("</w>", " ")
        )
        return text
162
-
163
-
164
_tokenizer = SimpleTokenizer()


def tokenize(
    texts: Union[str, List[str]], context_length: int = 77
) -> torch.LongTensor:
    """
    Returns the tokenized representation of given input string(s)

    Each row is ``<start_of_text>``, the encoded text, ``<end_of_text>``,
    zero-padded on the right. Sequences longer than ``context_length`` are
    truncated, and the last kept position is overwritten with
    ``<end_of_text>`` so that a truncated row still contains that token.

    Parameters
    ----------
    texts : Union[str, List[str]]
        An input string or a list of input strings to tokenize
    context_length : int
        The context length to use; all CLIP models use 77 as the context length

    Returns
    -------
    A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length]
    """
    if isinstance(texts, str):
        texts = [texts]

    sot_token = _tokenizer.encoder["<start_of_text>"]
    eot_token = _tokenizer.encoder["<end_of_text>"]
    all_tokens = [[sot_token] + _tokenizer.encode(text) + [eot_token] for text in texts]
    result = torch.zeros(len(all_tokens), context_length, dtype=torch.long)

    for i, tokens in enumerate(all_tokens):
        if len(tokens) > context_length:
            tokens = tokens[:context_length]  # Truncate
            if tokens:
                # Plain truncation would cut off the end-of-text marker.
                tokens[-1] = eot_token
        result[i, : len(tokens)] = torch.tensor(tokens)

    return result
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
audioldm2/clap/open_clip/transform.py DELETED
@@ -1,45 +0,0 @@
1
- from torchvision.transforms import (
2
- Normalize,
3
- Compose,
4
- RandomResizedCrop,
5
- InterpolationMode,
6
- ToTensor,
7
- Resize,
8
- CenterCrop,
9
- )
10
-
11
-
12
- def _convert_to_rgb(image):
13
- return image.convert("RGB")
14
-
15
-
16
def image_transform(
    image_size: int,
    is_train: bool,
    mean=(0.48145466, 0.4578275, 0.40821073),
    std=(0.26862954, 0.26130258, 0.27577711),
):
    """Compose the image preprocessing pipeline.

    Training uses a random resized crop (scale 0.9-1.0); evaluation uses a
    resize followed by a centre crop. Both use bicubic interpolation and then
    convert to RGB, convert to a tensor and normalise with ``mean``/``std``.
    """
    normalize = Normalize(mean=mean, std=std)
    if is_train:
        sizing = [
            RandomResizedCrop(
                image_size,
                scale=(0.9, 1.0),
                interpolation=InterpolationMode.BICUBIC,
            ),
        ]
    else:
        sizing = [
            Resize(image_size, interpolation=InterpolationMode.BICUBIC),
            CenterCrop(image_size),
        ]
    return Compose([*sizing, _convert_to_rgb, ToTensor(), normalize])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
audioldm2/clap/open_clip/utils.py DELETED
@@ -1,356 +0,0 @@
1
- import numpy as np
2
- import torch
3
- from torch import nn as nn
4
- from torchvision.ops.misc import FrozenBatchNorm2d
5
- import logging
6
- import h5py
7
- from tqdm import tqdm
8
- import random
9
- import json
10
- import os
11
- import pathlib
12
-
13
# TODO: (yusong) this is not a good place to store this information and it does
# not scale. Needs to be fixed later.
# Maps each dataset name to the split names available for it. Every entry gets
# its own list object, so mutating one entry never affects another.
dataset_split = {
    dataset_name: list(split_names)
    for dataset_name, split_names in (
        ("audiocaps", ("train", "valid", "test")),
        ("audioset", ("balanced_train", "unbalanced_train", "eval")),
        ("BBCSoundEffects", ("train", "test")),
        ("Clotho", ("train", "test", "valid")),
        ("free_to_use_sounds", ("train", "test")),
        ("paramount_motion", ("train", "test")),
        ("sonniss_game_effects", ("train", "test")),
        ("wesoundeffects", ("train", "test")),
        ("MACS", ("train", "test")),
        ("freesound", ("train", "test")),
        ("FSD50K", ("train", "test", "valid")),
        ("fsd50k_class_label", ("train", "test", "valid")),
        ("esc50", ("train", "test")),
        ("audiostock", ("train", "test")),
        ("freesound_no_overlap_noesc50", ("train", "test")),
        ("epidemic_sound_effects", ("train", "test")),
        ("VGGSound", ("train", "test")),
        ("urbansound8k_class_label", ("train", "test")),
        ("audioset_t5", ("balanced_train", "unbalanced_train", "eval")),
        ("epidemic_sound_effects_t5", ("train", "test")),
        ("WavText5K", ("train", "test")),
        ("esc50_no_overlap", ("train", "test")),
        ("usd8k_no_overlap", ("train", "test")),
        ("fsd50k_200_class_label", ("train", "test", "valid")),
    )
}
40
-
41
-
42
def freeze_batch_norm_2d(module, module_match=None, name=""):
    """
    Converts all `BatchNorm2d` and `SyncBatchNorm` layers of provided module into `FrozenBatchNorm2d`. If `module` is
    itself an instance of either `BatchNorm2d` or `SyncBatchNorm`, it is converted into `FrozenBatchNorm2d` and
    returned. Otherwise, the module is walked recursively and submodules are converted in place.

    Args:
        module (torch.nn.Module): Any PyTorch module.
        module_match (dict): Dictionary of full module names to freeze (all if empty or None)
        name (str): Full module name (prefix)

    Returns:
        torch.nn.Module: Resulting module

    Inspired by https://github.com/pytorch/pytorch/blob/a5895f85be0f10212791145bfedc0261d364f103/torch/nn/modules/batchnorm.py#L762
    """
    # ``None`` replaces the former mutable ``{}`` default; both mean "no filter".
    is_match = name in module_match if module_match else True

    batch_norm_types = (
        nn.modules.batchnorm.BatchNorm2d,
        nn.modules.batchnorm.SyncBatchNorm,
    )
    if is_match and isinstance(module, batch_norm_types):
        res = FrozenBatchNorm2d(module.num_features)
        res.num_features = module.num_features
        res.affine = module.affine
        if module.affine:
            res.weight.data = module.weight.data.clone().detach()
            res.bias.data = module.bias.data.clone().detach()
        # Running statistics are shared with the source layer, not cloned.
        res.running_mean.data = module.running_mean.data
        res.running_var.data = module.running_var.data
        res.eps = module.eps
        return res

    for child_name, child in module.named_children():
        full_child_name = ".".join([name, child_name]) if name else child_name
        new_child = freeze_batch_norm_2d(child, module_match, full_child_name)
        if new_child is not child:
            # Re-registering under the same name swaps the child in place.
            module.add_module(child_name, new_child)
    return module
81
-
82
-
83
def exist(dataset_name, dataset_type):
    """
    Check if a dataset split exists.

    Returns True when ``dataset_type`` is one of the splits listed for
    ``dataset_name`` in ``dataset_split``. An unknown ``dataset_name`` yields
    False instead of raising ``KeyError``.
    """
    return dataset_type in dataset_split.get(dataset_name, ())
91
-
92
-
93
def get_tar_path_from_dataset_name(
    dataset_names, dataset_types, islocal, dataset_path, proportion=1, full_dataset=None
):
    """
    Get tar paths from dataset names and split types.

    Args:
        dataset_names: iterable of dataset names.
        dataset_types: split names used for every dataset not listed in
            ``full_dataset``.
        islocal: when True, emit ``{dataset_path}/{name}/{split}/{tar}``
            paths; otherwise emit ``pipe:aws s3 ... cp`` commands.
        dataset_path: root directory of the local datasets.
        proportion: fraction of tars to keep per (dataset, split); values
            other than 1 trigger ``random.sample``.
        full_dataset: optional collection of dataset names for which all
            splits from ``dataset_split`` are used.

    Returns:
        A flat list of tar paths / pipe commands. (dataset, split) pairs with
        no ``sizes.json`` are skipped.
    """
    output = []
    for n in dataset_names:
        if full_dataset is not None and n in full_dataset:
            current_dataset_types = dataset_split[n]
        else:
            current_dataset_types = dataset_types
        for s in current_dataset_types:
            fallback_sizefile = f"./json_files/{n}/{s}/sizes.json"
            sizefilepath_ = fallback_sizefile
            if islocal:
                local_sizefile = f"{dataset_path}/{n}/{s}/sizes.json"
                if os.path.exists(local_sizefile):
                    sizefilepath_ = local_sizefile
            if not os.path.exists(sizefilepath_):
                continue
            # Context manager so the handle is closed right after parsing.
            with open(sizefilepath_, "r") as size_file:
                sizes = json.load(size_file)
            if islocal:
                tmp = [f"{dataset_path}/{n}/{s}/{k}" for k in sizes]
            else:
                tmp = [
                    f"pipe:aws s3 --cli-connect-timeout 0 cp s3://s-laion-audio/webdataset_tar/{n}/{s}/{k} -"
                    for k in sizes
                ]
            if proportion != 1:
                tmp = random.sample(tmp, int(proportion * len(tmp)))
            # extend keeps this linear; sum(list_of_lists, []) is quadratic.
            output.extend(tmp)
    return output
127
-
128
-
129
def get_tar_path_from_txts(txt_path, islocal, proportion=1):
    """
    Get tar paths from one or more txt files.

    Args:
        txt_path: path of a text file with one entry per line, or a
            list/tuple of such paths (results are concatenated in order).
        islocal: when True, the ``pipe:aws s3 cp s3://s-laion-audio/`` prefix
            is replaced with ``/mnt/audio_clip/``; otherwise ``.tar`` is
            rewritten to ``.tar -``.
        proportion: fraction of lines to keep per file; values other than 1
            trigger ``random.sample``.

    Returns:
        List of tar paths / pipe commands.

    Raises:
        TypeError: if ``txt_path`` is neither a str nor a list/tuple.
    """
    if isinstance(txt_path, (list, tuple)):
        merged = []
        for single_path in txt_path:
            merged.extend(
                get_tar_path_from_txts(
                    single_path, islocal=islocal, proportion=proportion
                )
            )
        return merged
    if not isinstance(txt_path, str):
        # Without this guard the code fell through to an unbound ``lines``.
        raise TypeError(
            f"txt_path must be a str, list or tuple, got {type(txt_path).__name__}"
        )

    with open(txt_path) as f:
        raw_lines = f.readlines()
    if islocal:
        lines = [
            raw.split("\n")[0].replace(
                "pipe:aws s3 cp s3://s-laion-audio/", "/mnt/audio_clip/"
            )
            for raw in raw_lines
        ]
    else:
        lines = [raw.split("\n")[0].replace(".tar", ".tar -") for raw in raw_lines]
    if proportion != 1:
        print("Sampling tars with proportion of {}".format(proportion))
        lines = random.sample(lines, int(proportion * len(lines)))
    return lines
162
-
163
-
164
def get_mix_lambda(mixup_alpha, batch_size):
    """Draw ``batch_size`` values from Beta(mixup_alpha, mixup_alpha) as a float32 array.

    Each coefficient comes from its own size-1 ``np.random.beta`` call.
    """
    coefficients = []
    for _ in range(batch_size):
        draw = np.random.beta(mixup_alpha, mixup_alpha, 1)
        coefficients.append(draw[0])
    return np.array(coefficients).astype(np.float32)
169
-
170
-
171
def do_mixup(x, mixup_lambda):
    """Blend each batch item with its mirror item from the reversed batch.

    Args:
        x: (batch_size , ...)
        mixup_lambda: (batch_size,)
    Returns:
        out: (batch_size, ...)
    """
    # Swapping the first and last axes puts the batch axis last, so the
    # per-item weights broadcast against it.
    weighted_self = x.transpose(0, -1) * mixup_lambda
    weighted_mirror = torch.flip(x, dims=[0]).transpose(0, -1) * (1 - mixup_lambda)
    mixed = weighted_self + weighted_mirror
    return mixed.transpose(0, -1)
184
-
185
-
186
def interpolate(x, ratio):
    """Upsample along the time axis by repeating every frame ``ratio`` times.

    This is used to compensate the resolution reduction in downsampling of a CNN.

    Args:
        x: (batch_size, time_steps, classes_num)
        ratio: int, ratio to interpolate
    Returns:
        upsampled: (batch_size, time_steps * ratio, classes_num)
    """
    batch_size, time_steps, classes_num = x.shape
    # Insert a repeat axis after time, tile it, then fold it back into time.
    repeated = x.unsqueeze(2).repeat(1, 1, ratio, 1)
    return repeated.reshape(batch_size, time_steps * ratio, classes_num)
200
-
201
-
202
def pad_framewise_output(framewise_output, frames_num):
    """Pad framewise_output to the same length as input frames. The pad value
    is the same as the value of the last frame.

    Args:
        framewise_output: (batch_size, frames_num, classes_num)
        frames_num: int, number of frames to pad
    Returns:
        output: (batch_size, frames_num, classes_num)
    """
    # Repeat the last frame as many times as there are missing frames.
    pad = framewise_output[:, -1:, :].repeat(
        1, frames_num - framewise_output.shape[1], 1
    )
    output = torch.cat((framewise_output, pad), dim=1)
    # The padded tensor was previously built but never returned.
    return output
218
-
219
-
220
def process_ipc(index_path, classes_num, filename):
    """Build and save the per-class sample index lists for an h5 index file.

    Args:
        index_path: path of the h5 file; its "target" dataset is read row by row.
        classes_num: number of classes (length of the saved array).
        filename: destination passed to ``np.save``.

    The saved array has one entry per class, each a list of the row indices
    whose target is non-zero for that class. It is stored as an object array,
    so it must be loaded with ``allow_pickle=True``.
    """
    # load data
    logging.info("Load Data...............")
    ipc = [[] for _ in range(classes_num)]
    with h5py.File(index_path, "r") as f:
        targets = f["target"]
        for i in tqdm(range(len(targets))):
            for t in np.where(targets[i])[0]:
                ipc[t].append(i)
    print(ipc)
    # The per-class lists are ragged; recent NumPy versions refuse to infer an
    # object dtype for them, so build the object array explicitly.
    ipc_array = np.empty(classes_num, dtype=object)
    for class_id, sample_ids in enumerate(ipc):
        ipc_array[class_id] = sample_ids
    np.save(filename, ipc_array)
    logging.info("Load Data Succeed...............")
232
-
233
-
234
def save_to_dict(s, o_=None):
    """Parse a ``"key: value"`` string and store ``float(value)`` under ``key``.

    Args:
        s: text of the form ``"<key>: <number>"``.
        o_: dictionary to update in place; a fresh dict is created when omitted
            (the former ``{}`` default was shared across calls).

    Returns:
        The updated dictionary (``o_`` itself when one was passed).
    """
    if o_ is None:
        o_ = {}
    sp = s.split(": ")
    o_[sp[0]] = float(sp[1])
    return o_
238
-
239
-
240
def get_data_from_log(txt_path):
    """
    Output dictionaries parsed from an out.txt log file.

    Only lines containing ``"| INFO |"`` are considered.

    * ``Eval Epoch: <n> <key>: <value>\\t<key>: <value>...`` lines that mention
      ``val_loss`` fill ``val_data[n]`` with the metrics as floats.
      NOTE(review): metric fields are assumed to be tab-separated; confirm
      against the code that writes the log.
    * ``Train Epoch: <n> ... Loss: <value> (...`` lines add one entry to
      ``train_data``, keyed by order of appearance.

    Returns:
        (train_data, val_data)
    """
    with open(txt_path) as f:
        lines = f.readlines()
    val_data = {}
    train_losses = []
    train_losses_epoch = []
    for log_line in lines:
        if "| INFO |" not in log_line:
            continue
        if "Eval Epoch" in log_line:
            if "val_loss" in log_line:
                fields = log_line.split("Eval Epoch: ")[-1].split("\t")
                # First field looks like "<epoch> <metric>: <value>".
                head = fields[0].split(" ")
                num_epoch = int(head[0])
                metrics = {head[1].replace(":", ""): float(head[-1])}
                for field in fields[1:]:
                    sp = field.split(": ")
                    metrics[sp[0]] = float(sp[1])
                val_data[num_epoch] = metrics
        elif "Train Epoch" in log_line:
            # Take the whole epoch token, not just its first character, so
            # epochs >= 10 are parsed correctly.
            num_epoch = int(log_line.split("Train Epoch: ")[1].split()[0])
            loss = float(log_line.split("Loss: ")[-1].split(" (")[0])
            train_losses.append(loss)
            train_losses_epoch.append(num_epoch)
    train_data = {
        step: {"num_epoch": epoch, "train_loss": loss}
        for step, (epoch, loss) in enumerate(zip(train_losses_epoch, train_losses))
    }
    return train_data, val_data
276
-
277
-
278
def save_p(obj, filename):
    """Pickle ``obj`` to ``filename`` and, when possible, verify the round trip.

    The verification reloads the file and compares it with ``obj`` using
    ``deepdiff.DeepDiff(..., ignore_string_case=True)``. If ``deepdiff`` is
    not installed, the check is skipped with a warning; the function no longer
    tries to ``pip install`` it at runtime.

    Raises:
        AssertionError: if the reloaded object differs from ``obj``.
    """
    import pickle

    try:
        from deepdiff import DeepDiff
    except ImportError:
        DeepDiff = None

    with open(filename, "wb") as file:
        pickle.dump(obj, file, protocol=pickle.HIGHEST_PROTOCOL)  # highest protocol

    if DeepDiff is None:
        logging.warning(
            "deepdiff is not installed; skipping the round-trip check for %s",
            filename,
        )
        return

    with open(filename, "rb") as file:
        z = pickle.load(file)
    # Explicit raise so the check survives ``python -O``; the exception type
    # matches what the former ``assert`` produced.
    if DeepDiff(obj, z, ignore_string_case=True) != {}:
        raise AssertionError("there is something wrong with the saving process")
    return
294
-
295
-
296
def load_p(filename):
    """Unpickle and return the object stored in ``filename``.

    NOTE: unpickling executes code embedded in the file; only use it on
    trusted files.
    """
    from pickle import load as _unpickle

    with open(filename, "rb") as handle:
        return _unpickle(handle)
302
-
303
-
304
def save_json(data, name="data.json"):
    """Serialize ``data`` as JSON into the file ``name`` (overwriting it)."""
    from json import dump as _dump

    with open(name, "w") as out_file:
        _dump(data, out_file)
310
-
311
-
312
def load_json(name):
    """Parse the JSON file ``name`` and return its content."""
    from json import load as _load

    with open(name, "r") as in_file:
        return _load(in_file)
318
-
319
-
320
def load_class_label(path):
    """Load a class-label table, choosing the loader from the file suffix.

    * ``.pkl`` / ``.pickle`` -> ``load_p``
    * ``.json`` / ``.txt``   -> ``load_json``
    * ``.npy`` / ``.npz``    -> ``np.load``
    * ``.csv``               -> ``pandas.read_csv``

    Returns None when ``path`` is None or the suffix is not recognised.
    """
    # https://stackoverflow.com/questions/48004243/how-to-share-large-read-only-dictionary-list-across-processes-in-multiprocessing
    # https://stackoverflow.com/questions/45693949/storing-strings-in-a-multiprocessing-sharedctypes-array
    if path is None:
        return None

    suffix = pathlib.Path(path).suffix
    if suffix in [".pkl", ".pickle"]:
        return load_p(path)
    if suffix in [".json", ".txt"]:
        return load_json(path)
    if suffix in [".npy", ".npz"]:
        return np.load(path)
    if suffix in [".csv"]:
        import pandas as pd

        return pd.read_csv(path)
    return None
342
-
343
-
344
- from torch import optim
345
-
346
-
347
def get_optimizer(params, lr, betas, eps, momentum, optimizer_name):
    """Create the optimizer selected by ``optimizer_name`` (case-insensitive).

    * "adamw" -> ``optim.AdamW(params, lr, betas, eps)``
    * "sgd"   -> ``optim.SGD(params, lr, momentum)``
    * "adam"  -> ``optim.Adam(params, lr, betas, eps)``

    Raises:
        ValueError: for any other name.
    """
    kind = optimizer_name.lower()
    if kind == "adamw":
        return optim.AdamW(params, lr=lr, betas=betas, eps=eps)
    if kind == "sgd":
        return optim.SGD(params, lr=lr, momentum=momentum)
    if kind == "adam":
        return optim.Adam(params, lr=lr, betas=betas, eps=eps)
    raise ValueError("optimizer name is not correct")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
audioldm2/clap/training/__init__.py DELETED
File without changes
audioldm2/clap/training/audioset_textmap.npy DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:bada103070d92f9eadd33e1b4f45ec8583f59080ef218c966b43294bd4c86d5b
3
- size 84448
 
 
 
 
audioldm2/clap/training/bpe_simple_vocab_16e6.txt.gz DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:924691ac288e54409236115652ad4aa250f48203de50a9e4722a6ecd48d6804a
3
- size 1356917
 
 
 
 
audioldm2/clap/training/data.py DELETED
@@ -1,865 +0,0 @@
1
- import json
2
- import logging
3
- import os
4
- import random
5
- import h5py
6
- from dataclasses import dataclass
7
- import numpy as np
8
- import pandas as pd
9
- import torch
10
- import torchvision.datasets as datasets
11
- from PIL import Image
12
- from torch.utils.data import Dataset, DataLoader, SubsetRandomSampler
13
- from torch.utils.data.distributed import DistributedSampler
14
- import soundfile as sf
15
- import io
16
- from pathlib import Path
17
- # import wget
18
-
19
- from audioldm2.clap.open_clip.utils import get_tar_path_from_dataset_name
20
- from audioldm2.clap.open_clip.utils import load_class_label
21
-
22
- try:
23
- import horovod.torch as hvd
24
- except ImportError:
25
- hvd = None
26
-
27
- try:
28
- import torchaudio
29
- except ImportError:
30
- torchaudio = None
31
-
32
- from audioldm2.clap.open_clip import tokenize
33
-
34
-
35
- def tokenizer(text):
36
- return tokenize(text).squeeze(0)
37
-
38
-
39
- from transformers import RobertaTokenizer
40
-
41
- tokenize = RobertaTokenizer.from_pretrained("roberta-base")
42
-
43
-
44
- def tokenizer(text):
45
- result = tokenize(
46
- text,
47
- padding="max_length",
48
- truncation=True,
49
- max_length=77,
50
- return_tensors="pt",
51
- )
52
- return {k: v.squeeze(0) for k, v in result.items()}
53
-
54
-
55
- # initizlied the audioset map
56
- _AUDIOSET_MAP_PATH = os.path.join(Path(__file__).parent, "audioset_textmap.npy")
57
- _AUDIOSET_MAP = np.load(_AUDIOSET_MAP_PATH, allow_pickle=True)
58
-
59
-
60
- def int16_to_float32(x):
61
- return (x / 32767.0).astype(np.float32)
62
-
63
-
64
- def float32_to_int16(x):
65
- x = np.clip(x, a_min=-1.0, a_max=1.0)
66
- return (x * 32767.0).astype(np.int16)
67
-
68
-
69
- # For Toy Dataset
70
- class ToyDataset(Dataset):
71
- def __init__(self, index_path, ipc, config, eval_mode=False):
72
- """Toy Dataset for testing the audioset input with text labels
73
- Parameters
74
- ----------
75
- index_path: str
76
- the link to the h5 file of each audio
77
- idc: str
78
- the link to the npy file, the number of samples in each class
79
- config: dict
80
- the audio cfg file
81
- eval_model (bool): to indicate if the dataset is a testing dataset
82
- """
83
- self.audio_cfg = config["audio_cfg"]
84
- self.text_cfg = config["text_cfg"]
85
- self.fp = h5py.File(index_path, "r")
86
- self.ipc = np.load(ipc, allow_pickle=True)
87
- self.total_size = len(self.fp["audio_name"])
88
- self.classes_num = self.audio_cfg["class_num"]
89
- self.eval_mode = eval_mode
90
-
91
- if not eval_mode:
92
- self.generate_queue()
93
- else:
94
- self.queue = []
95
- for i in range(self.total_size):
96
- target = self.fp["target"][i]
97
- if np.sum(target) > 0:
98
- self.queue.append(i)
99
- self.total_size = len(self.queue)
100
- logging.info("total dataset size: %d" % (self.total_size))
101
- logging.info("class num: %d" % (self.classes_num))
102
-
103
- def time_shifting(self, x):
104
- frame_num = len(x)
105
- shift_len = random.randint(0, frame_num - 1)
106
- new_sample = np.concatenate([x[shift_len:], x[:shift_len]], axis=0)
107
- return new_sample
108
-
109
- def generate_queue(self):
110
- self.queue = []
111
- while len(self.queue) < self.total_size:
112
- class_set = [*range(self.classes_num)]
113
- random.shuffle(class_set)
114
- self.queue += [
115
- self.ipc[d][random.randint(0, len(self.ipc[d]) - 1)] for d in class_set
116
- ]
117
- self.queue = self.queue[: self.total_size]
118
-
119
- logging.info("queue regenerated:%s" % (self.queue[-5:]))
120
-
121
- def crop_wav(self, x):
122
- crop_size = self.audio_cfg["crop_size"]
123
- crop_pos = random.randint(0, len(x) - crop_size - 1)
124
- return x[crop_pos : crop_pos + crop_size]
125
-
126
- def prompt_text(self, target):
127
- events = _AUDIOSET_MAP[np.where(target > 0)]
128
- event_text = "The sounds of " + ", ".join(events[:-1]) + " and " + events[-1]
129
- text = tokenize(event_text)[0]
130
- return text
131
-
132
- def __getitem__(self, index):
133
- """Load waveform, text, and target of an audio clip
134
-
135
- Parameters
136
- ----------
137
- index: int
138
- the index number
139
- Return
140
- ------
141
- output: dict {
142
- "hdf5_path": str,
143
- "index_in_hdf5": int,
144
- "audio_name": str,
145
- "waveform": list (audio_length,),
146
- "target": list (class_num, ),
147
- "text": torch.tensor (context_length,)
148
- }
149
- the output dictionary
150
- """
151
- s_index = self.queue[index]
152
-
153
- audio_name = self.fp["audio_name"][s_index].decode()
154
- # Hardcode here CHANGE
155
- hdf5_path = (
156
- self.fp["hdf5_path"][s_index]
157
- .decode()
158
- .replace(
159
- "../workspace",
160
- "/home/la/kechen/Research/ke_zsasp/workspace",
161
- )
162
- )
163
- r_idx = self.fp["index_in_hdf5"][s_index]
164
- target = self.fp["target"][s_index].astype(np.float32)
165
- text = self.prompt_text(target)
166
- with h5py.File(hdf5_path, "r") as f:
167
- waveform = int16_to_float32(f["waveform"][r_idx])[
168
- : self.audio_cfg["clip_samples"]
169
- ]
170
- assert (
171
- len(waveform) == self.audio_cfg["clip_samples"]
172
- ), "The sample length is not match"
173
- # Time shift
174
- # if (self.config.enable_time_shift) and (not self.eval_mode):
175
- # waveform = self.time_shifting(waveform)
176
- # # Label Enhance
177
- # if (self.config.crop_size is not None) and (not self.eval_mode):
178
- # waveform = self.crop_wav(waveform)
179
- # # the label enhance rate is fixed 0.5
180
- # if (self.config.enable_label_enhance) and (not self.eval_mode) and random.random() < 0.5:
181
- # kidx = np.where(target)[0]
182
- # for k in kidx:
183
- # for add_key in self.class_map[k][1]:
184
- # target[add_key] = 1.0
185
- # if len(self.class_map[k][2]) > 0:
186
- # add_key = random.choice(self.class_map[k][2])
187
- # target[add_key] = 1.0
188
-
189
- # missing the text input
190
- mel_spec = get_mel(torch.from_numpy(waveform), self.audio_cfg)[None, :, :]
191
- mel_spec = (
192
- torch.cat(
193
- [mel_spec, mel_spec.clone(), mel_spec.clone(), mel_spec.clone()], dim=0
194
- )
195
- .cpu()
196
- .numpy()
197
- )
198
- longer = random.choice([True, False])
199
- if longer == False:
200
- mel_spec[1:, :, :] = 0.0
201
- data_dict = {
202
- "hdf5_path": hdf5_path,
203
- "index_in_hdf5": r_idx,
204
- "audio_name": audio_name,
205
- "waveform": waveform,
206
- "class_label": target,
207
- "text": text,
208
- "longer": longer,
209
- "mel_fusion": mel_spec,
210
- }
211
- return data_dict
212
-
213
- def __len__(self):
214
- return self.total_size
215
-
216
-
217
- class CsvDataset(Dataset):
218
- def __init__(self, input_filename, transforms, img_key, caption_key, sep="\t"):
219
- logging.debug(f"Loading csv data from {input_filename}.")
220
- df = pd.read_csv(input_filename, sep=sep)
221
-
222
- self.images = df[img_key].tolist()
223
- self.captions = df[caption_key].tolist()
224
- self.transforms = transforms
225
- logging.debug("Done loading data.")
226
-
227
- def __len__(self):
228
- return len(self.captions)
229
-
230
- def __getitem__(self, idx):
231
- images = self.transforms(Image.open(str(self.images[idx])))
232
- texts = tokenize([str(self.captions[idx])])[0]
233
- return images, texts
234
-
235
-
236
- @dataclass
237
- class DataInfo:
238
- dataloader: DataLoader
239
- sampler: DistributedSampler
240
-
241
-
242
- def preprocess_txt(text):
243
- return tokenize([str(text)])[0]
244
-
245
-
246
- # def get_dataset_size(shards, sizefilepath_=None, is_local=True):
247
- # if isinstance(shards, list):
248
- # size_list = []
249
- # for s in shards:
250
- # size_list.append(
251
- # get_dataset_size(s, sizefilepath_=sizefilepath_, is_local=is_local)[0]
252
- # )
253
- # else:
254
- # if not is_local:
255
- # for n in dataset_split.keys():
256
- # if n in shards.split("/"):
257
- # break
258
- # for s in dataset_split[n]:
259
- # if s in shards.split("/"):
260
- # break
261
- # sizefilepath_ = f"./json_files/{n}/{s}/sizes.json"
262
- # shards_list = list(braceexpand.braceexpand(shards))
263
- # dir_path = os.path.dirname(shards)
264
- # if sizefilepath_ is not None:
265
- # sizes = json.load(open(sizefilepath_, "r"))
266
- # total_size = sum(
267
- # [
268
- # int(sizes[os.path.basename(shard.replace(".tar -", ".tar"))])
269
- # for shard in shards_list
270
- # ]
271
- # )
272
- # else:
273
- # sizes_filename = os.path.join(dir_path, "sizes.json")
274
- # len_filename = os.path.join(dir_path, "__len__")
275
- # if os.path.exists(sizes_filename):
276
- # sizes = json.load(open(sizes_filename, "r"))
277
- # total_size = sum(
278
- # [int(sizes[os.path.basename(shard)]) for shard in shards_list]
279
- # )
280
- # elif os.path.exists(len_filename):
281
- # # FIXME this used to be eval(open(...)) but that seemed rather unsafe
282
- # total_size = ast.literal_eval(open(len_filename, "r").read())
283
- # else:
284
- # raise Exception(
285
- # "Cannot find sizes file for dataset. Please specify the path to the file."
286
- # )
287
- # # total_size = None # num samples undefined
288
- # # some common dataset sizes (at time of authors last download)
289
- # # cc3m-train: 2905954
290
- # # cc12m: 10968539
291
- # # LAION-400m: 407332084
292
- # num_shards = len(shards_list)
293
- # if isinstance(shards, list):
294
- # return sum(size_list), len(shards)
295
- # else:
296
- # return total_size, num_shards
297
-
298
-
299
- def get_imagenet(args, preprocess_fns, split):
300
- assert split in ["train", "val", "v2"]
301
- is_train = split == "train"
302
- preprocess_train, preprocess_val = preprocess_fns
303
-
304
- if split == "v2":
305
- from imagenetv2_pytorch import ImageNetV2Dataset
306
-
307
- dataset = ImageNetV2Dataset(location=args.imagenet_v2, transform=preprocess_val)
308
- else:
309
- if is_train:
310
- data_path = args.imagenet_train
311
- preprocess_fn = preprocess_train
312
- else:
313
- data_path = args.imagenet_val
314
- preprocess_fn = preprocess_val
315
- assert data_path
316
-
317
- dataset = datasets.ImageFolder(data_path, transform=preprocess_fn)
318
-
319
- if is_train:
320
- idxs = np.zeros(len(dataset.targets))
321
- target_array = np.array(dataset.targets)
322
- k = 50
323
- for c in range(1000):
324
- m = target_array == c
325
- n = len(idxs[m])
326
- arr = np.zeros(n)
327
- arr[:k] = 1
328
- np.random.shuffle(arr)
329
- idxs[m] = arr
330
-
331
- idxs = idxs.astype("int")
332
- sampler = SubsetRandomSampler(np.where(idxs)[0])
333
- else:
334
- sampler = None
335
-
336
- dataloader = torch.utils.data.DataLoader(
337
- dataset,
338
- batch_size=args.batch_size,
339
- num_workers=args.workers,
340
- sampler=sampler,
341
- )
342
-
343
- return DataInfo(dataloader, sampler)
344
-
345
-
346
- def count_samples(dataloader):
347
- os.environ["WDS_EPOCH"] = "0"
348
- n_elements, n_batches = 0, 0
349
- for images, texts in dataloader:
350
- n_batches += 1
351
- n_elements += len(images)
352
- assert len(images) == len(texts)
353
- return n_elements, n_batches
354
-
355
-
356
- def filter_no_caption(sample):
357
- return "txt" in sample
358
-
359
-
360
- def log_and_continue(exn):
361
- """Call in an exception handler to ignore any exception, isssue a warning, and continue."""
362
- logging.warning(f"Handling webdataset error ({repr(exn)}). Ignoring.")
363
- return True
364
-
365
-
366
- _SHARD_SHUFFLE_SIZE = 2000
367
- _SHARD_SHUFFLE_INITIAL = 500
368
- _SAMPLE_SHUFFLE_SIZE = 5000
369
- _SAMPLE_SHUFFLE_INITIAL = 1000
370
-
371
-
372
- # def sample_prop(sizefile, inputs, proportion, is_local=True):
373
- # """
374
- # Sample a proportion of the data.
375
- # """
376
- # file_path_dict = {
377
- # os.path.split(inputs[i])[1]: os.path.split(inputs[i])[0]
378
- # for i in range(len(inputs))
379
- # }
380
- # sampled_filepath_dict = {}
381
- # sampled_size_dict = {}
382
- # if not is_local:
383
- # if os.path.exists("sizes.json"):
384
- # os.remove("sizes.json")
385
- # wget.download(sizefile, "sizes.json")
386
- # sizefile = "sizes.json"
387
- # with open(sizefile, "r", encoding="UTF-8") as f:
388
- # load_dict = json.load(f)
389
- # L = int(len(file_path_dict) * proportion)
390
- # subkeys = random.sample(file_path_dict.keys(), L)
391
- # for k in subkeys:
392
- # sampled_size_dict[k] = load_dict[k]
393
- # sampled_filepath_dict[k] = file_path_dict[k]
394
- # return (
395
- # sum(sampled_size_dict.values()),
396
- # L,
397
- # [os.path.join(v, k) for k, v in sampled_filepath_dict.items()],
398
- # sampled_size_dict,
399
- # )
400
-
401
-
402
- def get_mel(audio_data, audio_cfg):
403
- # mel shape: (n_mels, T)
404
- mel = torchaudio.transforms.MelSpectrogram(
405
- sample_rate=audio_cfg["sample_rate"],
406
- n_fft=audio_cfg["window_size"],
407
- win_length=audio_cfg["window_size"],
408
- hop_length=audio_cfg["hop_size"],
409
- center=True,
410
- pad_mode="reflect",
411
- power=2.0,
412
- norm=None,
413
- onesided=True,
414
- n_mels=64,
415
- f_min=audio_cfg["fmin"],
416
- f_max=audio_cfg["fmax"],
417
- ).to(audio_data.device)
418
- mel = mel(audio_data)
419
- # we use log mel spectrogram as input
420
- mel = torchaudio.transforms.AmplitudeToDB(top_db=None)(mel)
421
- return mel.T # (T, n_mels)
422
-
423
-
424
- def get_audio_features(
425
- audio_data, mel, max_len, data_truncating, data_filling, audio_cfg
426
- ):
427
- """
428
- Calculate and add audio features to sample.
429
- Sample: a dict containing all the data of current sample.
430
- audio_data: a tensor of shape (T) containing audio data.
431
- max_len: the maximum length of audio data.
432
- data_truncating: the method of truncating data.
433
- data_filling: the method of filling data.
434
- audio_cfg: a dict containing audio configuration. Comes from model_cfg['audio_cfg'].
435
- """
436
- sample = {}
437
-
438
- # assert audio_data.size(-1) <= max_len, str(audio_data.size())
439
-
440
- # split to three parts
441
- chunk_frames = (
442
- max_len // audio_cfg["hop_size"] + 1
443
- ) # the +1 related to how the spectrogram is computed
444
- mel = mel[:chunk_frames]
445
-
446
- audio_data = audio_data[..., :max_len]
447
- sample["mel_fusion"] = mel
448
- longer = torch.tensor([True])
449
-
450
- sample["longer"] = longer
451
- sample["waveform"] = audio_data
452
-
453
- return sample
454
-
455
-
456
- def preprocess(
457
- sample,
458
- audio_ext,
459
- text_ext,
460
- max_len,
461
- audio_cfg,
462
- class_index_dict=None,
463
- data_filling="pad",
464
- data_truncating="rand_trunc",
465
- text_augment_selection=None,
466
- ):
467
- """
468
- Preprocess a single sample for wdsdataloader.
469
- """
470
- audio_data, orig_sr = sf.read(io.BytesIO(sample[audio_ext]))
471
- audio_data = int16_to_float32(float32_to_int16(audio_data))
472
- audio_data = torch.tensor(audio_data).float()
473
-
474
- # TODO: (yusong) to be include in the future
475
- # # if torchaudio not installed, use soundfile to load audio
476
- # if torchaudio is None:
477
- # audio_data, orig_sr = sf.read(io.BytesIO(sample[audio_ext]))
478
- # audio_data = torch.tensor(audio_data).float()
479
- # else:
480
- # # https://github.com/webdataset/webdataset/blob/main/webdataset/autodecode.py
481
- # with tempfile.TemporaryDirectory() as dirname:
482
- # os.makedirs(dirname, exist_ok=True)
483
- # fname = os.path.join(dirname, f"file.flac")
484
- # with open(fname, "wb") as stream:
485
- # stream.write(sample[audio_ext])
486
- # audio_data, orig_sr = torchaudio.load(fname)
487
- # audio_data = audio_data[0, :].float()
488
-
489
- sample = get_audio_features(
490
- sample, audio_data, max_len, data_truncating, data_filling, audio_cfg
491
- )
492
- del sample[audio_ext]
493
-
494
- try:
495
- json_dict_raw = json.loads(sample[text_ext].decode("utf-8"))
496
- except:
497
- print("sample[__url__]:", sample["__url__"])
498
-
499
- # For selecting augmented text from dataset
500
- if text_augment_selection is None or text_augment_selection == "none":
501
- texts = json_dict_raw["text"]
502
- elif text_augment_selection == "all":
503
- if "text_augment_all" in json_dict_raw.keys():
504
- texts = json_dict_raw["text_augment_all"]
505
- else:
506
- texts = json_dict_raw["text"]
507
- elif text_augment_selection == "augment_only":
508
- if "text_augment_all" in json_dict_raw.keys():
509
- if json_dict_raw["text_augment_t5"] is None:
510
- texts = json_dict_raw["text"]
511
- else:
512
- texts = json_dict_raw["text_augment_t5"]
513
- else:
514
- texts = json_dict_raw["text"]
515
- else:
516
- raise NotImplementedError(
517
- f"text_augment_selection {text_augment_selection} not implemented"
518
- )
519
- sample["full_text"] = texts
520
-
521
- if isinstance(texts, list) and isinstance(texts[0], str) and len(texts) > 1:
522
- texts = random.choice(texts)
523
- sample["raw_text"] = texts
524
- sample["text"] = tokenizer(texts) # text shape: [num_token]
525
- if class_index_dict is not None:
526
- # https://stackoverflow.com/questions/48004243/how-to-share-large-read-only-dictionary-list-across-processes-in-multiprocessing
527
- # https://stackoverflow.com/questions/45693949/storing-strings-in-a-multiprocessing-sharedctypes-array
528
- # key, val = class_index_dict
529
- # key = key[:].split('\n')
530
- # _dict = {k: v for k, v in zip(key, val)}
531
- sample["class_label"] = np.zeros(len(class_index_dict.keys()))
532
- for x in json_dict_raw["tag"]:
533
- sample["class_label"][class_index_dict[x]] = 1
534
- sample["class_label"] = torch.tensor(sample["class_label"]).float()
535
- del sample[text_ext]
536
- sample["audio_name"] = sample["__key__"].split("/")[-1] + "." + audio_ext
537
- sample["text_name"] = sample["__key__"].split("/")[-1] + "." + text_ext
538
- sample["audio_orig_sr"] = orig_sr
539
- return sample
540
-
541
-
542
- def collate_fn(batch):
543
- """
544
- Collate function for wdsdataloader.
545
- batch: a list of dict, each dict is a sample
546
- """
547
- # concatenate values in each dictionary. if it is a tensor, concatenate. if it is a list, extend.
548
- batch_dict = {}
549
- for k in batch[0].keys():
550
- if isinstance(batch[0][k], dict): # dealwith bert tokenizer output
551
- batch_dict[k] = {}
552
- for kk in batch[0][k].keys():
553
- tmp = []
554
- for i in range(len(batch)):
555
- tmp.append(batch[i][k][kk])
556
- batch_dict[k][kk] = torch.vstack(tmp)
557
- elif isinstance(batch[0][k], torch.Tensor):
558
- batch_dict[k] = torch.stack([sample[k] for sample in batch])
559
- elif isinstance(batch[0][k], np.ndarray):
560
- batch_dict[k] = torch.tensor(np.stack([sample[k] for sample in batch]))
561
- else:
562
- batch_dict[k] = [sample[k] for sample in batch]
563
- return batch_dict
564
-
565
-
566
- # def get_wds_dataset(
567
- # args,
568
- # model_cfg,
569
- # is_train,
570
- # audio_ext="flac",
571
- # text_ext="json",
572
- # max_len=480000,
573
- # proportion=1.0,
574
- # sizefilepath_=None,
575
- # is_local=None,
576
- # ):
577
- # """
578
- # Get a dataset for wdsdataloader.
579
- # """
580
- # if is_local is None and (not args.remotedata is None):
581
- # is_local = not args.remotedata
582
-
583
- # input_shards = args.train_data if is_train else args.val_data
584
- # assert input_shards is not None
585
-
586
- # if not sizefilepath_ is None:
587
- # sizefilepath = sizefilepath_
588
- # else:
589
- # sizefilepath = os.path.join(os.path.dirname(input_shards[0]), "sizes.json")
590
-
591
- # if proportion != 1.0:
592
- # num_samples, num_shards, input_shards, _ = sample_prop(
593
- # sizefilepath, input_shards, proportion, is_local=is_local
594
- # )
595
- # else:
596
- # num_samples, num_shards = get_dataset_size(
597
- # input_shards, sizefilepath_=sizefilepath_, is_local=is_local
598
- # )
599
-
600
- # if not num_samples:
601
- # if is_train:
602
- # num_samples = args.train_num_samples
603
- # if not num_samples:
604
- # raise RuntimeError(
605
- # "Currently, number of dataset samples must be specified for training dataset. "
606
- # "Please specify via `--train-num-samples` if no dataset length info present."
607
- # )
608
- # else:
609
- # num_samples = (
610
- # args.val_num_samples or 0
611
- # ) # eval will just exhaust the iterator if not specified
612
-
613
- # pipeline = [wds.SimpleShardList(input_shards)]
614
- # # at this point we have an iterator over all the shards
615
- # # TODO: (yusong): add a if statement of distributed. If not, we don't need to split_by_node
616
- # if is_train or args.parallel_eval:
617
- # pipeline.extend(
618
- # [
619
- # wds.detshuffle(
620
- # bufsize=_SHARD_SHUFFLE_SIZE,
621
- # initial=_SHARD_SHUFFLE_INITIAL,
622
- # seed=args.seed,
623
- # ),
624
- # wds.split_by_node,
625
- # wds.split_by_worker,
626
- # # at this point, we have an iterator over the shards assigned to each worker at each node
627
- # wds.tarfile_to_samples(handler=log_and_continue),
628
- # wds.shuffle(
629
- # bufsize=_SAMPLE_SHUFFLE_SIZE,
630
- # initial=_SAMPLE_SHUFFLE_INITIAL,
631
- # rng=random.Random(args.seed),
632
- # ),
633
- # # wds.repeatedly, # FIXME determine if this is beneficial
634
- # ]
635
- # )
636
- # else:
637
- # pipeline.extend(
638
- # [
639
- # wds.split_by_worker,
640
- # # at this point, we have an iterator over the shards assigned to each worker
641
- # wds.tarfile_to_samples(handler=log_and_continue),
642
- # ]
643
- # )
644
- # pipeline.append(
645
- # wds.map(
646
- # partial(
647
- # preprocess,
648
- # audio_ext=audio_ext,
649
- # text_ext=text_ext,
650
- # max_len=max_len,
651
- # audio_cfg=model_cfg["audio_cfg"],
652
- # class_index_dict=copy.deepcopy(args.class_index_dict),
653
- # data_filling=args.data_filling,
654
- # data_truncating=args.data_truncating,
655
- # text_augment_selection=args.text_augment_selection,
656
- # )
657
- # ),
658
- # )
659
-
660
- # pipeline.append(
661
- # wds.batched(
662
- # args.batch_size,
663
- # partial=not (is_train or args.parallel_eval),
664
- # collation_fn=collate_fn,
665
- # )
666
- # )
667
-
668
- # dataset = wds.DataPipeline(*pipeline)
669
- # if is_train or args.parallel_eval:
670
- # # (yusong): Currently parallel evaluation will be not precise as we are repeat the last few samples.
671
- # # (yusong): See comments below.
672
- # # roll over and repeat a few samples to get same number of full batches on each node
673
- # global_batch_size = args.batch_size * args.world_size
674
- # num_batches = math.ceil(num_samples / global_batch_size)
675
- # num_workers = max(1, args.workers)
676
- # num_worker_batches = math.ceil(
677
- # num_batches / num_workers
678
- # ) # per dataloader worker
679
- # num_batches = num_worker_batches * num_workers
680
- # num_samples = num_batches * global_batch_size
681
- # dataset = dataset.with_epoch(
682
- # num_worker_batches
683
- # ) # each worker is iterating over this
684
- # else:
685
- # # last batches are partial, eval is done on single (master) node
686
- # num_batches = math.ceil(num_samples / args.batch_size)
687
-
688
- # kwargs = {}
689
- # if args.horovod: # multi-node training on summit
690
- # kwargs["multiprocessing_context"] = "forkserver"
691
-
692
- # dataloader = wds.WebLoader(
693
- # dataset, batch_size=None, shuffle=False, num_workers=args.workers, **kwargs
694
- # )
695
-
696
- # # FIXME not clear which approach is better, with_epoch before vs after dataloader?
697
- # # hoping to resolve via https://github.com/webdataset/webdataset/issues/169
698
- # # if is_train:
699
- # # # roll over and repeat a few samples to get same number of full batches on each node
700
- # # global_batch_size = args.batch_size * args.world_size
701
- # # num_batches = math.ceil(num_samples / global_batch_size)
702
- # # num_workers = max(1, args.workers)
703
- # # num_batches = math.ceil(num_batches / num_workers) * num_workers
704
- # # num_samples = num_batches * global_batch_size
705
- # # dataloader = dataloader.with_epoch(num_batches)
706
- # # else:
707
- # # # last batches are partial, eval is done on single (master) node
708
- # # num_batches = math.ceil(num_samples / args.batch_size)
709
-
710
- # # add meta-data to dataloader instance for convenience
711
- # dataloader.num_batches = num_batches
712
- # dataloader.num_samples = num_samples
713
-
714
- # return DataInfo(dataloader, None)
715
-
716
-
717
def wds_batch_list2dict(
    batch,
    keys=(
        "__url__",
        "__key__",
        "waveform",
        "text",
        "raw_text",
        "audio_name",
        "text_name",
        "audio_orig_sr",
    ),
):
    """
    Return a dictionary of the batch, with keys as the names of the fields.

    Args:
        batch: sequence of field values, in the same order as ``keys``.
        keys: field names, one per entry of ``batch``. The default is an
            immutable tuple so it cannot be modified between calls.

    Returns:
        dict mapping ``keys[i]`` to ``batch[i]``.

    Raises:
        AssertionError: if ``batch`` and ``keys`` differ in length.
    """
    assert len(keys) == len(
        batch
    ), "batch must have same number of keys as keys argument"
    # Lengths are equal at this point, so zip pairs every key with its value.
    return dict(zip(keys, batch))
737
-
738
-
739
def get_csv_dataset(args, preprocess_fn, is_train):
    """
    Build a ``CsvDataset`` and wrap it in a ``DataLoader``.

    Reads ``args.train_data`` when ``is_train`` is true, otherwise
    ``args.val_data``. A ``DistributedSampler`` is used only for distributed
    training; shuffling is requested only for training without a sampler.

    Returns:
        ``DataInfo(dataloader, sampler)``; the dataloader additionally carries
        ``num_samples`` and ``num_batches`` attributes.
    """
    if is_train:
        input_filename = args.train_data
    else:
        input_filename = args.val_data
    assert input_filename

    dataset = CsvDataset(
        input_filename,
        preprocess_fn,
        img_key=args.csv_img_key,
        caption_key=args.csv_caption_key,
        sep=args.csv_separator,
    )
    num_samples = len(dataset)

    sampler = None
    if args.distributed and is_train:
        sampler = DistributedSampler(dataset)
    shuffle = is_train and sampler is None

    loader_kwargs = dict(
        batch_size=args.batch_size,
        shuffle=shuffle,
        num_workers=args.workers,
        pin_memory=True,
        sampler=sampler,
        drop_last=is_train,
    )
    dataloader = DataLoader(dataset, **loader_kwargs)
    # Attach size metadata to the dataloader instance.
    dataloader.num_samples = num_samples
    dataloader.num_batches = len(dataloader)

    return DataInfo(dataloader, sampler)
766
-
767
-
768
def get_toy_dataset(args, model_cfg, is_train):
    """
    Build a ``ToyDataset`` and wrap it in a non-shuffling ``DataLoader``.

    The index and ipc paths come from ``args.train_data`` / ``args.train_ipc``
    when ``is_train`` is true, otherwise from ``args.val_data`` /
    ``args.val_ipc``. The dataset is created with ``eval_mode`` set to the
    negation of ``is_train``.

    Returns:
        ``DataInfo(dataloader, sampler)``; the dataloader additionally carries
        ``num_samples`` and ``num_batches`` attributes.
    """
    if is_train:
        index_path, ipc_path = args.train_data, args.train_ipc
    else:
        index_path, ipc_path = args.val_data, args.val_ipc
    assert index_path and ipc_path

    dataset = ToyDataset(index_path, ipc_path, model_cfg, eval_mode=not is_train)
    num_samples = len(dataset)

    sampler = None
    if args.distributed and is_train:
        sampler = DistributedSampler(dataset, shuffle=False)

    loader_kwargs = dict(
        batch_size=args.batch_size,
        shuffle=False,
        num_workers=args.workers,
        sampler=sampler,
        drop_last=is_train,
    )
    dataloader = DataLoader(dataset, **loader_kwargs)
    # Attach size metadata to the dataloader instance.
    dataloader.num_samples = num_samples
    dataloader.num_batches = len(dataloader)

    return DataInfo(dataloader, sampler)
794
-
795
-
796
def get_dataset_fn(data_path, dataset_type):
    """
    Pick the dataset-builder function for ``dataset_type``.

    ``"webdataset"``, ``"csv"`` and ``"toy"`` map directly to their builder.
    ``"auto"`` decides from the text after the last ``"."`` in ``data_path``:
    ``csv``/``tsv`` select the csv builder and ``tar`` the webdataset builder.

    Raises:
        ValueError: for an unknown ``dataset_type``, or when ``"auto"`` cannot
            map the extension of ``data_path`` to a builder.
    """
    # Builders are looked up by name only on the branch that returns them.
    # NOTE(review): get_wds_dataset appears to be commented out in this
    # module -- confirm it is defined before relying on the webdataset paths.
    if dataset_type == "webdataset":
        return get_wds_dataset
    if dataset_type == "csv":
        return get_csv_dataset
    if dataset_type == "toy":
        return get_toy_dataset
    if dataset_type != "auto":
        raise ValueError(f"Unsupported dataset type: {dataset_type}")

    ext = data_path.rsplit(".", 1)[-1]
    if ext in ("csv", "tsv"):
        return get_csv_dataset
    if ext == "tar":
        return get_wds_dataset
    raise ValueError(
        f"Tried to figure out dataset type, but failed for extention {ext}."
    )
815
-
816
-
817
def get_data(args, model_cfg):
    """
    Build the train/val dataset entries described by ``args``.

    Side effects on ``args``: ``class_index_dict`` is always set, and
    ``datasetinfos`` gets a default when it is ``None``. For the
    ``"webdataset"`` dataset type, ``train_data``, ``val_data`` and
    ``val_dataset_names`` are derived from the dataset names, and
    ``full_train_dataset`` / ``exclude_eval_dataset`` default to empty lists.

    Returns:
        dict that may contain a ``"train"`` and/or a ``"val"`` entry, each
        produced by the builder selected through ``get_dataset_fn``.
    """
    args.class_index_dict = load_class_label(args.class_label_path)

    if args.datasetinfos is None:
        args.datasetinfos = ["train", "unbalanced_train", "balanced_train"]

    if args.dataset_type == "webdataset":
        # full_train_dataset is passed on before it is defaulted to [] below.
        args.train_data = get_tar_path_from_dataset_name(
            args.datasetnames,
            args.datasetinfos,
            islocal=not args.remotedata,
            proportion=args.dataset_proportion,
            dataset_path=args.datasetpath,
            full_dataset=args.full_train_dataset,
        )

        if args.full_train_dataset is None:
            args.full_train_dataset = []
        if args.exclude_eval_dataset is None:
            args.exclude_eval_dataset = []
        excluded_eval_datasets = args.full_train_dataset + args.exclude_eval_dataset

        # Datasets listed for full training or exclusion are not evaluated.
        if excluded_eval_datasets:
            val_dataset_names = [
                n for n in args.datasetnames if n not in excluded_eval_datasets
            ]
        else:
            val_dataset_names = args.datasetnames
        args.val_dataset_names = val_dataset_names
        args.val_data = get_tar_path_from_dataset_name(
            val_dataset_names,
            ["valid", "test", "eval"],
            islocal=not args.remotedata,
            proportion=1,
            dataset_path=args.datasetpath,
            full_dataset=None,
        )

    data = {}
    for split, training in (("train", True), ("val", False)):
        split_data = getattr(args, split + "_data")
        if split_data:
            build = get_dataset_fn(split_data, args.dataset_type)
            data[split] = build(args, model_cfg, is_train=training)

    return data
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
audioldm2/clap/training/params.py DELETED
@@ -1,563 +0,0 @@
1
- import argparse
2
-
3
-
4
def get_default_params(model_name):
    """
    Return default optimizer hyper-parameters for a model name.

    Args:
        model_name: model name; matched case-insensitively.

    Returns:
        dict with ``lr``, ``beta1``, ``beta2`` and ``eps``. Names containing
        ``"vit"`` get ``beta2=0.98`` / ``eps=1e-6``; all other names get
        ``beta2=0.999`` / ``eps=1e-8``.
    """
    # Params from paper (https://arxiv.org/pdf/2103.00020.pdf)
    model_name = model_name.lower()
    if "vit" in model_name:
        return {"lr": 5.0e-4, "beta1": 0.9, "beta2": 0.98, "eps": 1.0e-6}
    else:
        return {"lr": 5.0e-4, "beta1": 0.9, "beta2": 0.999, "eps": 1.0e-8}


def parse_args(args=None):
    """
    Parse the training command-line options.

    Args:
        args: optional list of argument strings. ``None`` (the default) makes
            argparse read ``sys.argv[1:]``, which is the previous behavior.

    Returns:
        argparse.Namespace with every option. ``lr``, ``beta1``, ``beta2`` and
        ``eps`` that were not given on the command line are filled in from
        ``get_default_params(args.amodel)``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--train-data",
        type=str,
        default=None,
        help="Path to h5 filewith training data",
    )
    parser.add_argument(
        "--val-data",
        type=str,
        default=None,
        help="Path to h5 file with validation data",
    )
    parser.add_argument(
        "--freeze-text",
        default=False,
        action="store_true",
        help="if you need to freeze the text encoder, make this True",
    )
    parser.add_argument(
        "--freeze-text-after",
        type=int,
        default=-1,
        help="if you need to freeze the text encoder after (include) epoch x, set this param to x. Set -1 to disable it",
    )
    parser.add_argument(
        "--train-ipc",
        type=str,
        default=None,
        help="Path to npy file of the number of instance per class in training data",
    )
    parser.add_argument(
        "--val-ipc",
        type=str,
        default=None,
        help="Path to npy file of the number of instance per class in validation data",
    )
    parser.add_argument(
        "--train-num-samples",
        type=int,
        default=None,
        help="Number of samples in dataset. Required for webdataset if not available in info file.",
    )
    parser.add_argument(
        "--val-num-samples",
        type=int,
        default=None,
        help="Number of samples in dataset. Useful for webdataset if not available in info file.",
    )
    parser.add_argument(
        "--dataset-type",
        choices=["webdataset", "csv", "auto", "toy"],
        default="auto",
        help="Which type of dataset to process.",
    )
    parser.add_argument(
        "--csv-separator",
        type=str,
        default="\t",
        help="For csv-like datasets, which separator to use.",
    )
    parser.add_argument(
        "--csv-img-key",
        type=str,
        default="filepath",
        help="For csv-like datasets, the name of the key for the image paths.",
    )
    parser.add_argument(
        "--csv-caption-key",
        type=str,
        default="title",
        help="For csv-like datasets, the name of the key for the captions.",
    )
    parser.add_argument(
        "--imagenet-val",
        type=str,
        default=None,
        help="Path to imagenet val set for conducting zero shot evaluation.",
    )
    parser.add_argument(
        "--imagenet-v2",
        type=str,
        default=None,
        help="Path to imagenet v2 for conducting zero shot evaluation.",
    )
    parser.add_argument(
        "--datasetnames",
        nargs="+",
        default=None,
        help="If loading webdataset, spedify the dataset names to load. Can be some of these: Clotho, audioset, audiocaps, BBCSoundEffects",
    )
    parser.add_argument(
        "--full-train-dataset",
        nargs="+",
        default=None,
        help="Which dataset will be trained with all the subsets. (train+test)",
    )
    parser.add_argument(
        "--exclude-eval-dataset",
        nargs="+",
        default=None,
        help="Which dataset will be excluded with evaluation",
    )
    parser.add_argument(
        "--datasetinfos",
        nargs="+",
        default=None,
        help="If loading webdataset, spedify the dataset types to load. Can be some of these: train, test, valid, unbalanced_train, balanced_train, eval",
    )
    parser.add_argument(
        "--dataset-proportion",
        type=float,
        default=1.0,
        help="How much proportion of dataset we want to train.",
    )
    parser.add_argument(
        "--remotedata",
        default=False,
        action="store_true",
        help="if the dataset is remote, set this flag",
    )
    parser.add_argument(
        "--class-label-path",
        type=str,
        default=None,
        help="The path of the class label pickle or csv.",
    )
    parser.add_argument(
        "--datasetpath",
        type=str,
        default="/mnt/audio_clip/webdataset_tar",
        help="The path to the dataset",
    )
    parser.add_argument(
        "--logs",
        type=str,
        default="./logs/",
        help="Where to store tensorboard logs. Use None to avoid storing logs.",
    )
    parser.add_argument(
        "--log-local",
        action="store_true",
        default=False,
        help="log files on local master, otherwise global master only.",
    )
    parser.add_argument(
        "--name",
        type=str,
        default=None,
        help="Optional identifier for the experiment when storing logs. Otherwise use current time.",
    )
    parser.add_argument(
        "--workers", type=int, default=1, help="Number of workers per GPU."
    )
    parser.add_argument(
        "--batch-size", type=int, default=64, help="Batch size per GPU."
    )
    parser.add_argument(
        "--epochs", type=int, default=32, help="Number of epochs to train for."
    )
    parser.add_argument("--lr", type=float, default=None, help="Learning rate.")
    parser.add_argument("--beta1", type=float, default=None, help="Adam beta 1.")
    parser.add_argument("--beta2", type=float, default=None, help="Adam beta 2.")
    parser.add_argument("--eps", type=float, default=None, help="Adam epsilon.")
    # The help text used to read "SGD epsilon.", which described the wrong
    # quantity for a momentum option.
    parser.add_argument("--momentum", type=float, default=None, help="SGD momentum.")
    parser.add_argument("--wd", type=float, default=0.2, help="Weight decay.")

    # NOTE(review): the help text below is identical to --skip-scheduler's and
    # looks copy-pasted; confirm what --split-opt controls and reword it.
    parser.add_argument(
        "--split-opt",
        action="store_true",
        default=False,
        help="Use this flag to skip the learning rate decay.",
    )
    parser.add_argument(
        "--lr-pretrained", type=float, default=None, help="Learning rate for text."
    )
    parser.add_argument(
        "--beta1-pretrained", type=float, default=None, help="Adam beta 1 for text."
    )
    parser.add_argument(
        "--beta2-pretrained", type=float, default=None, help="Adam beta 2 for text."
    )
    parser.add_argument(
        "--eps-pretrained", type=float, default=None, help="Adam epsilon for text."
    )
    parser.add_argument(
        "--wd-pretrained", type=float, default=0.2, help="Weight decay for text."
    )
    parser.add_argument(
        "--momentum-pretrained", type=float, default=0.9, help="Momentum for text."
    )
    parser.add_argument(
        "--lr-new", type=float, default=None, help="Learning rate for audio."
    )
    parser.add_argument(
        "--beta1-new", type=float, default=None, help="Adam beta 1 for audio."
    )
    parser.add_argument(
        "--beta2-new", type=float, default=None, help="Adam beta 2 for audio."
    )
    parser.add_argument(
        "--eps-new", type=float, default=None, help="Adam epsilon for audio."
    )
    parser.add_argument(
        "--wd-new", type=float, default=0.2, help="Weight decay for audio."
    )
    parser.add_argument(
        "--momentum-new", type=float, default=0.9, help="Momentum for audio."
    )
    parser.add_argument(
        "--warmup", type=int, default=10000, help="Number of steps to warmup for."
    )
    parser.add_argument(
        "--use-bn-sync",
        default=False,
        action="store_true",
        help="Whether to use batch norm sync.",
    )
    parser.add_argument(
        "--skip-scheduler",
        action="store_true",
        default=False,
        help="Use this flag to skip the learning rate decay.",
    )
    parser.add_argument(
        "--save-frequency", type=int, default=1, help="How often to save checkpoints."
    )
    parser.add_argument(
        "--save-top-performance",
        type=int,
        default=0,
        help="Save the top x performance weights if the value >0",
    )
    parser.add_argument(
        "--save-most-recent",
        action="store_true",
        default=False,
        help="Always save the most recent model trained to epoch_latest.pt.",
    )
    parser.add_argument(
        "--zeroshot-frequency", type=int, default=2, help="How often to run zero shot."
    )
    parser.add_argument(
        "--val-frequency",
        type=int,
        default=1,
        help="How often to run evaluation with val data.",
    )
    parser.add_argument(
        "--resume",
        default=None,
        type=str,
        help="path to latest checkpoint (default: none)",
    )
    parser.add_argument(
        "--precision",
        choices=["amp", "fp16", "fp32"],
        default="amp",
        help="Floating point precision.",
    )
    parser.add_argument(
        "--amodel",
        type=str,
        default="RN50",
        help="Name of the audio backbone to use.",
    )
    parser.add_argument(
        "--tmodel",
        type=str,
        default="transformer",
        help="Name of the text backbone to use. Can be [transformer, bert, roberta, bart]",
    )
    parser.add_argument(
        "--pretrained-audio",
        default="",
        type=str,
        help="Use a pretrained audio model weights for the audio encoder of CLAP",
    )
    parser.add_argument(
        "--pretrained-text",
        default="",
        type=str,
        help="Use a pretrained text model weights for the text encoder of CLAP",
    )
    parser.add_argument(
        "--pretrained",
        default="",
        type=str,
        help="Use a pretrained CLIP model weights with the specified tag or file path.",
    )
    parser.add_argument(
        "--pretrained-image",
        default=False,
        action="store_true",
        help="Load imagenet pretrained weights for image tower backbone if available.",
    )
    parser.add_argument(
        "--lock-image",
        default=False,
        action="store_true",
        help="Lock full image tower by disabling gradients.",
    )
    parser.add_argument(
        "--lock-image-unlocked-groups",
        type=int,
        default=0,
        help="Leave last n image tower layer groups unlocked.",
    )
    parser.add_argument(
        "--lock-image-freeze-bn-stats",
        default=False,
        action="store_true",
        help="Freeze BatchNorm running stats in image tower for any locked layers.",
    )
    parser.add_argument(
        "--local-loss",
        default=False,
        action="store_true",
        help="calculate loss w/ local features @ global (instead of realizing full global @ global matrix)",
    )
    parser.add_argument(
        "--gather-with-grad",
        default=False,
        action="store_true",
        help="enable full distributed gradient for feature gather",
    )
    parser.add_argument(
        "--force-quick-gelu",
        default=False,
        action="store_true",
        help="Force use of QuickGELU activation for non-OpenAI transformer models.",
    )
    parser.add_argument(
        "--torchscript",
        default=False,
        action="store_true",
        help="torch.jit.script the model, also uses jit version of OpenAI models if pretrained=='openai'",
    )
    parser.add_argument(
        "--trace",
        default=False,
        action="store_true",
        help="torch.jit.trace the model for inference / eval only",
    )
    # arguments for distributed training
    parser.add_argument(
        "--dist-url",
        default="env://",
        type=str,
        help="url used to set up distributed training",
    )
    parser.add_argument(
        "--dist-backend", default="nccl", type=str, help="distributed backend"
    )
    parser.add_argument(
        "--report-to",
        default="",
        type=str,
        help="Options are ['wandb', 'tensorboard', 'wandb,tensorboard']",
    )
    parser.add_argument(
        "--wandb-notes", default="", type=str, help="Notes if logging with wandb"
    )
    parser.add_argument(
        "--C", type=float, default=3.16, help="inverse regularizer for logistic reg."
    )
    parser.add_argument(
        "--debug",
        default=False,
        action="store_true",
        help="If true, more information is logged.",
    )
    parser.add_argument(
        "--copy-codebase",
        default=False,
        action="store_true",
        help="If true, we copy the entire base on the log diretory, and execute from there.",
    )
    parser.add_argument(
        "--horovod",
        default=False,
        action="store_true",
        help="Use horovod for distributed training.",
    )
    parser.add_argument(
        "--ddp-static-graph",
        default=False,
        action="store_true",
        help="Enable static graph optimization for DDP in PyTorch >= 1.11.",
    )
    parser.add_argument(
        "--no-set-device-rank",
        default=False,
        action="store_true",
        help="Don't set device index from local rank (when CUDA_VISIBLE_DEVICES restricted to one per proc).",
    )
    parser.add_argument("--seed", type=int, default=4242, help="Default random seed.")

    parser.add_argument(
        "--top-k-checkpoint-select-dataset",
        type=str,
        default="all",
        help="The dataset of selecting top-k checkpoint.",
    )

    # @R10, @R@5, @R1, mAP@10
    parser.add_argument(
        "--top-k-checkpoint-select-metric",
        type=str,
        default="_R@10",
        help="The metric for selecting top-k checkpoint.",
    )
    parser.add_argument(
        "--openai-model-cache-dir",
        type=str,
        default="~/.cache/clip",
        help="Directory to download OpenAI models.",
    )
    parser.add_argument(
        "--optimizer",
        type=str,
        default="adamw",
        help="can be AdamW or SGD",
    )
    parser.add_argument(
        "--parallel-eval",
        default=False,
        action="store_true",
        help="Eval in parallel (multi-GPU, multi-node).",
    )

    parser.add_argument(
        "--no-eval",
        default=False,
        action="store_true",
        help="Training without evaluation.",
    )

    parser.add_argument(
        "--lp-mlp",
        default=False,
        action="store_true",
        help="Linear Probe using MLP layer or not.",
    )

    parser.add_argument(
        "--lp-freeze",
        default=False,
        action="store_true",
        help="Linear Probe using Freeze CLAP or not",
    )

    parser.add_argument(
        "--lp-act",
        default="None",
        type=str,
        help="Options are ['relu','elu','prelu','softmax','sigmoid']",
    )

    parser.add_argument(
        "--lp-loss", type=str, default="bce", help="Loss func of Linear Probe."
    )

    parser.add_argument(
        "--lp-metrics",
        type=str,
        default="map,mauc,acc",
        help="Metrics of Linear Probe.",
    )

    parser.add_argument(
        "--lp-lr", type=float, default=1e-4, help="learning rate of linear probe"
    )
    parser.add_argument(
        "--kappa",
        type=float,
        default=0,
        help="the kappa in the weighted contrastive loss, default is to turn off the weighted contrastive loss",
    )

    parser.add_argument(
        "--data-filling",
        type=str,
        default="pad",
        help="type of data filling when the audio length is shorter than the max length."
        "Can be one of the following: repeat, repeatpad, pad",
    )
    parser.add_argument(
        "--data-truncating",
        type=str,
        default="rand_trunc",
        help="type of data truncation when the audio length is longer than the max length."
        "Can be one of the following: rand_trunc, fusion",
    )

    parser.add_argument(
        "--clap-mlploss",
        default=False,
        action="store_true",
        help="Using MLP loss for CLAP model or not",
    )

    parser.add_argument(
        "--wandb-id",
        type=str,
        default=None,
        help="the id of wandb experiment to restore.",
    )

    parser.add_argument(
        "--sleep", type=float, default=0, help="sleep n seconds before start training"
    )

    # variable length processing
    parser.add_argument(
        "--enable-fusion",
        default=False,
        action="store_true",
        help="Enable feature funsion for variable-length data",
    )

    parser.add_argument(
        "--fusion-type",
        type=str,
        default="None",
        help="Type is among ['channel_map', 'daf_1d','aff_1d','iaff_1d','daf_2d','aff_2d','iaff_2d']",
    )

    parser.add_argument(
        "--mixup",
        default=False,
        action="store_true",
        help="Enable mixup in finetuning training.",
    )
    parser.add_argument(
        "--text-augment-selection",
        type=str,
        default=None,
        help="For selecting levels of augmented text. Type is among ['all', 'augment_only', 'none']",
    )

    # ``args`` is rebound to the parsed namespace from here on.
    args = parser.parse_args(args)

    # If some params are not passed, we use the default values based on model name.
    default_params = get_default_params(args.amodel)
    for name, val in default_params.items():
        if getattr(args, name) is None:
            setattr(args, name, val)

    return args
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
audioldm2/hifigan/LICENSE DELETED
@@ -1,21 +0,0 @@
1
- MIT License
2
-
3
- Copyright (c) 2020 Jungil Kong
4
-
5
- Permission is hereby granted, free of charge, to any person obtaining a copy
6
- of this software and associated documentation files (the "Software"), to deal
7
- in the Software without restriction, including without limitation the rights
8
- to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
- copies of the Software, and to permit persons to whom the Software is
10
- furnished to do so, subject to the following conditions:
11
-
12
- The above copyright notice and this permission notice shall be included in all
13
- copies or substantial portions of the Software.
14
-
15
- THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
- IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
- FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
- AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
- LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
- OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
- SOFTWARE.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
audioldm2/hifigan/__init__.py DELETED
@@ -1,8 +0,0 @@
1
- from .models_v2 import Generator
2
- from .models import Generator as Generator_old
3
-
4
-
5
class AttrDict(dict):
    """Dictionary whose items can also be read and written as attributes."""

    def __init__(self, *args, **kwargs):
        """Initialise like ``dict`` and expose the items as attributes."""
        super().__init__(*args, **kwargs)
        # Using the dict itself as the instance __dict__ makes attribute
        # access and item access operate on the same storage.
        self.__dict__ = self
 
 
 
 
 
 
 
 
 
audioldm2/hifigan/models.py DELETED
@@ -1,174 +0,0 @@
1
- import torch
2
- import torch.nn as nn
3
- import torch.nn.functional as F
4
- from torch.nn import Conv1d, ConvTranspose1d
5
- from torch.nn.utils import weight_norm, remove_weight_norm
6
-
7
- LRELU_SLOPE = 0.1
8
-
9
-
10
- def init_weights(m, mean=0.0, std=0.01):
11
- classname = m.__class__.__name__
12
- if classname.find("Conv") != -1:
13
- m.weight.data.normal_(mean, std)
14
-
15
-
16
- def get_padding(kernel_size, dilation=1):
17
- return int((kernel_size * dilation - dilation) / 2)
18
-
19
-
20
- class ResBlock(torch.nn.Module):
21
- def __init__(self, h, channels, kernel_size=3, dilation=(1, 3, 5)):
22
- super(ResBlock, self).__init__()
23
- self.h = h
24
- self.convs1 = nn.ModuleList(
25
- [
26
- weight_norm(
27
- Conv1d(
28
- channels,
29
- channels,
30
- kernel_size,
31
- 1,
32
- dilation=dilation[0],
33
- padding=get_padding(kernel_size, dilation[0]),
34
- )
35
- ),
36
- weight_norm(
37
- Conv1d(
38
- channels,
39
- channels,
40
- kernel_size,
41
- 1,
42
- dilation=dilation[1],
43
- padding=get_padding(kernel_size, dilation[1]),
44
- )
45
- ),
46
- weight_norm(
47
- Conv1d(
48
- channels,
49
- channels,
50
- kernel_size,
51
- 1,
52
- dilation=dilation[2],
53
- padding=get_padding(kernel_size, dilation[2]),
54
- )
55
- ),
56
- ]
57
- )
58
- self.convs1.apply(init_weights)
59
-
60
- self.convs2 = nn.ModuleList(
61
- [
62
- weight_norm(
63
- Conv1d(
64
- channels,
65
- channels,
66
- kernel_size,
67
- 1,
68
- dilation=1,
69
- padding=get_padding(kernel_size, 1),
70
- )
71
- ),
72
- weight_norm(
73
- Conv1d(
74
- channels,
75
- channels,
76
- kernel_size,
77
- 1,
78
- dilation=1,
79
- padding=get_padding(kernel_size, 1),
80
- )
81
- ),
82
- weight_norm(
83
- Conv1d(
84
- channels,
85
- channels,
86
- kernel_size,
87
- 1,
88
- dilation=1,
89
- padding=get_padding(kernel_size, 1),
90
- )
91
- ),
92
- ]
93
- )
94
- self.convs2.apply(init_weights)
95
-
96
- def forward(self, x):
97
- for c1, c2 in zip(self.convs1, self.convs2):
98
- xt = F.leaky_relu(x, LRELU_SLOPE)
99
- xt = c1(xt)
100
- xt = F.leaky_relu(xt, LRELU_SLOPE)
101
- xt = c2(xt)
102
- x = xt + x
103
- return x
104
-
105
- def remove_weight_norm(self):
106
- for l in self.convs1:
107
- remove_weight_norm(l)
108
- for l in self.convs2:
109
- remove_weight_norm(l)
110
-
111
-
112
- class Generator(torch.nn.Module):
113
- def __init__(self, h):
114
- super(Generator, self).__init__()
115
- self.h = h
116
- self.num_kernels = len(h.resblock_kernel_sizes)
117
- self.num_upsamples = len(h.upsample_rates)
118
- self.conv_pre = weight_norm(
119
- Conv1d(h.num_mels, h.upsample_initial_channel, 7, 1, padding=3)
120
- )
121
- resblock = ResBlock
122
-
123
- self.ups = nn.ModuleList()
124
- for i, (u, k) in enumerate(zip(h.upsample_rates, h.upsample_kernel_sizes)):
125
- self.ups.append(
126
- weight_norm(
127
- ConvTranspose1d(
128
- h.upsample_initial_channel // (2**i),
129
- h.upsample_initial_channel // (2 ** (i + 1)),
130
- k,
131
- u,
132
- padding=(k - u) // 2,
133
- )
134
- )
135
- )
136
-
137
- self.resblocks = nn.ModuleList()
138
- for i in range(len(self.ups)):
139
- ch = h.upsample_initial_channel // (2 ** (i + 1))
140
- for j, (k, d) in enumerate(
141
- zip(h.resblock_kernel_sizes, h.resblock_dilation_sizes)
142
- ):
143
- self.resblocks.append(resblock(h, ch, k, d))
144
-
145
- self.conv_post = weight_norm(Conv1d(ch, 1, 7, 1, padding=3))
146
- self.ups.apply(init_weights)
147
- self.conv_post.apply(init_weights)
148
-
149
- def forward(self, x):
150
- x = self.conv_pre(x)
151
- for i in range(self.num_upsamples):
152
- x = F.leaky_relu(x, LRELU_SLOPE)
153
- x = self.ups[i](x)
154
- xs = None
155
- for j in range(self.num_kernels):
156
- if xs is None:
157
- xs = self.resblocks[i * self.num_kernels + j](x)
158
- else:
159
- xs += self.resblocks[i * self.num_kernels + j](x)
160
- x = xs / self.num_kernels
161
- x = F.leaky_relu(x)
162
- x = self.conv_post(x)
163
- x = torch.tanh(x)
164
-
165
- return x
166
-
167
- def remove_weight_norm(self):
168
- # print("Removing weight norm...")
169
- for l in self.ups:
170
- remove_weight_norm(l)
171
- for l in self.resblocks:
172
- l.remove_weight_norm()
173
- remove_weight_norm(self.conv_pre)
174
- remove_weight_norm(self.conv_post)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
audioldm2/hifigan/models_v2.py DELETED
@@ -1,395 +0,0 @@
1
- import torch
2
- import torch.nn.functional as F
3
- import torch.nn as nn
4
- from torch.nn import Conv1d, ConvTranspose1d
5
- from torch.nn.utils import weight_norm, remove_weight_norm
6
-
7
- LRELU_SLOPE = 0.1
8
-
9
-
10
- def init_weights(m, mean=0.0, std=0.01):
11
- classname = m.__class__.__name__
12
- if classname.find("Conv") != -1:
13
- m.weight.data.normal_(mean, std)
14
-
15
-
16
- def get_padding(kernel_size, dilation=1):
17
- return int((kernel_size * dilation - dilation) / 2)
18
-
19
-
20
- class ResBlock1(torch.nn.Module):
21
- def __init__(self, h, channels, kernel_size=3, dilation=(1, 3, 5)):
22
- super(ResBlock1, self).__init__()
23
- self.h = h
24
- self.convs1 = nn.ModuleList(
25
- [
26
- weight_norm(
27
- Conv1d(
28
- channels,
29
- channels,
30
- kernel_size,
31
- 1,
32
- dilation=dilation[0],
33
- padding=get_padding(kernel_size, dilation[0]),
34
- )
35
- ),
36
- weight_norm(
37
- Conv1d(
38
- channels,
39
- channels,
40
- kernel_size,
41
- 1,
42
- dilation=dilation[1],
43
- padding=get_padding(kernel_size, dilation[1]),
44
- )
45
- ),
46
- weight_norm(
47
- Conv1d(
48
- channels,
49
- channels,
50
- kernel_size,
51
- 1,
52
- dilation=dilation[2],
53
- padding=get_padding(kernel_size, dilation[2]),
54
- )
55
- ),
56
- ]
57
- )
58
- self.convs1.apply(init_weights)
59
-
60
- self.convs2 = nn.ModuleList(
61
- [
62
- weight_norm(
63
- Conv1d(
64
- channels,
65
- channels,
66
- kernel_size,
67
- 1,
68
- dilation=1,
69
- padding=get_padding(kernel_size, 1),
70
- )
71
- ),
72
- weight_norm(
73
- Conv1d(
74
- channels,
75
- channels,
76
- kernel_size,
77
- 1,
78
- dilation=1,
79
- padding=get_padding(kernel_size, 1),
80
- )
81
- ),
82
- weight_norm(
83
- Conv1d(
84
- channels,
85
- channels,
86
- kernel_size,
87
- 1,
88
- dilation=1,
89
- padding=get_padding(kernel_size, 1),
90
- )
91
- ),
92
- ]
93
- )
94
- self.convs2.apply(init_weights)
95
-
96
- def forward(self, x):
97
- for c1, c2 in zip(self.convs1, self.convs2):
98
- xt = F.leaky_relu(x, LRELU_SLOPE)
99
- xt = c1(xt)
100
- xt = F.leaky_relu(xt, LRELU_SLOPE)
101
- xt = c2(xt)
102
- x = xt + x
103
- return x
104
-
105
- def remove_weight_norm(self):
106
- for l in self.convs1:
107
- remove_weight_norm(l)
108
- for l in self.convs2:
109
- remove_weight_norm(l)
110
-
111
-
112
- class ResBlock2(torch.nn.Module):
113
- def __init__(self, h, channels, kernel_size=3, dilation=(1, 3)):
114
- super(ResBlock2, self).__init__()
115
- self.h = h
116
- self.convs = nn.ModuleList(
117
- [
118
- weight_norm(
119
- Conv1d(
120
- channels,
121
- channels,
122
- kernel_size,
123
- 1,
124
- dilation=dilation[0],
125
- padding=get_padding(kernel_size, dilation[0]),
126
- )
127
- ),
128
- weight_norm(
129
- Conv1d(
130
- channels,
131
- channels,
132
- kernel_size,
133
- 1,
134
- dilation=dilation[1],
135
- padding=get_padding(kernel_size, dilation[1]),
136
- )
137
- ),
138
- ]
139
- )
140
- self.convs.apply(init_weights)
141
-
142
- def forward(self, x):
143
- for c in self.convs:
144
- xt = F.leaky_relu(x, LRELU_SLOPE)
145
- xt = c(xt)
146
- x = xt + x
147
- return x
148
-
149
- def remove_weight_norm(self):
150
- for l in self.convs:
151
- remove_weight_norm(l)
152
-
153
-
154
- class Generator(torch.nn.Module):
155
- def __init__(self, h):
156
- super(Generator, self).__init__()
157
- self.h = h
158
- self.num_kernels = len(h.resblock_kernel_sizes)
159
- self.num_upsamples = len(h.upsample_rates)
160
- self.conv_pre = weight_norm(
161
- Conv1d(256, h.upsample_initial_channel, 7, 1, padding=3)
162
- )
163
- resblock = ResBlock1 if h.resblock == "1" else ResBlock2
164
-
165
- self.ups = nn.ModuleList()
166
- for i, (u, k) in enumerate(zip(h.upsample_rates, h.upsample_kernel_sizes)):
167
- self.ups.append(
168
- weight_norm(
169
- ConvTranspose1d(
170
- h.upsample_initial_channel // (2**i),
171
- h.upsample_initial_channel // (2 ** (i + 1)),
172
- u * 2,
173
- u,
174
- padding=u // 2 + u % 2,
175
- output_padding=u % 2,
176
- )
177
- )
178
- )
179
-
180
- self.resblocks = nn.ModuleList()
181
- for i in range(len(self.ups)):
182
- ch = h.upsample_initial_channel // (2 ** (i + 1))
183
- for j, (k, d) in enumerate(
184
- zip(h.resblock_kernel_sizes, h.resblock_dilation_sizes)
185
- ):
186
- self.resblocks.append(resblock(h, ch, k, d))
187
-
188
- self.conv_post = weight_norm(Conv1d(ch, 1, 7, 1, padding=3))
189
- self.ups.apply(init_weights)
190
- self.conv_post.apply(init_weights)
191
-
192
- def forward(self, x):
193
- # import ipdb; ipdb.set_trace()
194
- x = self.conv_pre(x)
195
- for i in range(self.num_upsamples):
196
- x = F.leaky_relu(x, LRELU_SLOPE)
197
- x = self.ups[i](x)
198
- xs = None
199
- for j in range(self.num_kernels):
200
- if xs is None:
201
- xs = self.resblocks[i * self.num_kernels + j](x)
202
- else:
203
- xs += self.resblocks[i * self.num_kernels + j](x)
204
- x = xs / self.num_kernels
205
- x = F.leaky_relu(x)
206
- x = self.conv_post(x)
207
- x = torch.tanh(x)
208
-
209
- return x
210
-
211
- def remove_weight_norm(self):
212
- # print('Removing weight norm...')
213
- for l in self.ups:
214
- remove_weight_norm(l)
215
- for l in self.resblocks:
216
- l.remove_weight_norm()
217
- remove_weight_norm(self.conv_pre)
218
- remove_weight_norm(self.conv_post)
219
-
220
-
221
- ##################################################################################################
222
-
223
- # import torch
224
- # import torch.nn as nn
225
- # import torch.nn.functional as F
226
- # from torch.nn import Conv1d, ConvTranspose1d
227
- # from torch.nn.utils import weight_norm, remove_weight_norm
228
-
229
- # LRELU_SLOPE = 0.1
230
-
231
-
232
- # def init_weights(m, mean=0.0, std=0.01):
233
- # classname = m.__class__.__name__
234
- # if classname.find("Conv") != -1:
235
- # m.weight.data.normal_(mean, std)
236
-
237
-
238
- # def get_padding(kernel_size, dilation=1):
239
- # return int((kernel_size * dilation - dilation) / 2)
240
-
241
-
242
- # class ResBlock(torch.nn.Module):
243
- # def __init__(self, h, channels, kernel_size=3, dilation=(1, 3, 5)):
244
- # super(ResBlock, self).__init__()
245
- # self.h = h
246
- # self.convs1 = nn.ModuleList(
247
- # [
248
- # weight_norm(
249
- # Conv1d(
250
- # channels,
251
- # channels,
252
- # kernel_size,
253
- # 1,
254
- # dilation=dilation[0],
255
- # padding=get_padding(kernel_size, dilation[0]),
256
- # )
257
- # ),
258
- # weight_norm(
259
- # Conv1d(
260
- # channels,
261
- # channels,
262
- # kernel_size,
263
- # 1,
264
- # dilation=dilation[1],
265
- # padding=get_padding(kernel_size, dilation[1]),
266
- # )
267
- # ),
268
- # weight_norm(
269
- # Conv1d(
270
- # channels,
271
- # channels,
272
- # kernel_size,
273
- # 1,
274
- # dilation=dilation[2],
275
- # padding=get_padding(kernel_size, dilation[2]),
276
- # )
277
- # ),
278
- # ]
279
- # )
280
- # self.convs1.apply(init_weights)
281
-
282
- # self.convs2 = nn.ModuleList(
283
- # [
284
- # weight_norm(
285
- # Conv1d(
286
- # channels,
287
- # channels,
288
- # kernel_size,
289
- # 1,
290
- # dilation=1,
291
- # padding=get_padding(kernel_size, 1),
292
- # )
293
- # ),
294
- # weight_norm(
295
- # Conv1d(
296
- # channels,
297
- # channels,
298
- # kernel_size,
299
- # 1,
300
- # dilation=1,
301
- # padding=get_padding(kernel_size, 1),
302
- # )
303
- # ),
304
- # weight_norm(
305
- # Conv1d(
306
- # channels,
307
- # channels,
308
- # kernel_size,
309
- # 1,
310
- # dilation=1,
311
- # padding=get_padding(kernel_size, 1),
312
- # )
313
- # ),
314
- # ]
315
- # )
316
- # self.convs2.apply(init_weights)
317
-
318
- # def forward(self, x):
319
- # for c1, c2 in zip(self.convs1, self.convs2):
320
- # xt = F.leaky_relu(x, LRELU_SLOPE)
321
- # xt = c1(xt)
322
- # xt = F.leaky_relu(xt, LRELU_SLOPE)
323
- # xt = c2(xt)
324
- # x = xt + x
325
- # return x
326
-
327
- # def remove_weight_norm(self):
328
- # for l in self.convs1:
329
- # remove_weight_norm(l)
330
- # for l in self.convs2:
331
- # remove_weight_norm(l)
332
-
333
- # class Generator(torch.nn.Module):
334
- # def __init__(self, h):
335
- # super(Generator, self).__init__()
336
- # self.h = h
337
- # self.num_kernels = len(h.resblock_kernel_sizes)
338
- # self.num_upsamples = len(h.upsample_rates)
339
- # self.conv_pre = weight_norm(
340
- # Conv1d(h.num_mels, h.upsample_initial_channel, 7, 1, padding=3)
341
- # )
342
- # resblock = ResBlock
343
-
344
- # self.ups = nn.ModuleList()
345
- # for i, (u, k) in enumerate(zip(h.upsample_rates, h.upsample_kernel_sizes)):
346
- # self.ups.append(
347
- # weight_norm(
348
- # ConvTranspose1d(
349
- # h.upsample_initial_channel // (2**i),
350
- # h.upsample_initial_channel // (2 ** (i + 1)),
351
- # k,
352
- # u,
353
- # padding=(k - u) // 2,
354
- # )
355
- # )
356
- # )
357
-
358
- # self.resblocks = nn.ModuleList()
359
- # for i in range(len(self.ups)):
360
- # ch = h.upsample_initial_channel // (2 ** (i + 1))
361
- # for j, (k, d) in enumerate(
362
- # zip(h.resblock_kernel_sizes, h.resblock_dilation_sizes)
363
- # ):
364
- # self.resblocks.append(resblock(h, ch, k, d))
365
-
366
- # self.conv_post = weight_norm(Conv1d(ch, 1, 7, 1, padding=3))
367
- # self.ups.apply(init_weights)
368
- # self.conv_post.apply(init_weights)
369
-
370
- # def forward(self, x):
371
- # x = self.conv_pre(x)
372
- # for i in range(self.num_upsamples):
373
- # x = F.leaky_relu(x, LRELU_SLOPE)
374
- # x = self.ups[i](x)
375
- # xs = None
376
- # for j in range(self.num_kernels):
377
- # if xs is None:
378
- # xs = self.resblocks[i * self.num_kernels + j](x)
379
- # else:
380
- # xs += self.resblocks[i * self.num_kernels + j](x)
381
- # x = xs / self.num_kernels
382
- # x = F.leaky_relu(x)
383
- # x = self.conv_post(x)
384
- # x = torch.tanh(x)
385
-
386
- # return x
387
-
388
- # def remove_weight_norm(self):
389
- # print("Removing weight norm...")
390
- # for l in self.ups:
391
- # remove_weight_norm(l)
392
- # for l in self.resblocks:
393
- # l.remove_weight_norm()
394
- # remove_weight_norm(self.conv_pre)
395
- # remove_weight_norm(self.conv_post)