amiriparian commited on
Commit
cf70fdf
·
verified ·
1 Parent(s): fec72ec

Upload ExHuBERT

Browse files
Files changed (2) hide show
  1. ExHuBERT_model.py +451 -0
  2. config.json +3 -0
ExHuBERT_model.py ADDED
@@ -0,0 +1,451 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from dataclasses import dataclass
2
+ from typing import Optional, Tuple, Union
3
+
4
+ import torch
5
+ import torch.nn as nn
6
+ from transformers import HubertForSequenceClassification
7
+ from transformers.activations import ACT2FN
8
+ from transformers.deepspeed import is_deepspeed_zero3_enabled
9
+ from transformers.file_utils import ModelOutput
10
+ from transformers.modeling_outputs import BaseModelOutput
11
+ from transformers.models.hubert import HubertConfig
12
+ from transformers.models.hubert.modeling_hubert import HubertPreTrainedModel, HubertFeatureEncoder, \
13
+ HubertFeatureProjection, _compute_mask_indices, \
14
+ HubertPositionalConvEmbedding, HubertAttention
15
+ import torch.nn.functional as F
16
+ from huggingface_hub import PyTorchModelHubMixin
17
+
18
+ ######
19
+ #
20
+ #######
21
+
22
+
23
+
24
+ _HIDDEN_STATES_START_POSITION = 1
25
+
26
+ # General docstring
27
+ _CONFIG_FOR_DOC = "HubertConfig"
28
+
29
+ # Base docstring
30
+ _CHECKPOINT_FOR_DOC = "facebook/hubert-large-ls960-ft"
31
+ _EXPECTED_OUTPUT_SHAPE = [1, 292, 768]
32
+
33
+ # CTC docstring
34
+ _CTC_EXPECTED_OUTPUT = "'MISTER QUILTER IS THE APOSTLE OF THE MIDDLE CLASSES AND WE ARE GLAD TO WELCOME HIS GOSPEL'"
35
+ _CTC_EXPECTED_LOSS = 22.68
36
+
37
+ # Audio class docstring
38
+ _SEQ_CLASS_CHECKPOINT = "superb/hubert-base-superb-ks"
39
+ _SEQ_CLASS_EXPECTED_OUTPUT = "'_unknown_'"
40
+ _SEQ_CLASS_EXPECTED_LOSS = 8.53
41
+
42
+ HUBERT_PRETRAINED_MODEL_ARCHIVE_LIST = [
43
+ "facebook/hubert-base-ls960",
44
+ # See all Hubert models at https://huggingface.co/models?filter=hubert
45
+ ]
46
+
47
+
48
+ # SwiGLU function
49
+ # From """GLU Variants Improve Transformer """
50
+ # https://doi.org/10.48550/arXiv.2002.05202
51
+ class SwiGLU(nn.Module):
52
+ def forward(self, x):
53
+ x, gate = x.chunk(2, dim=-1)
54
+ return F.silu(gate) * x
55
+
56
+
57
+ @dataclass
58
+ class SpeechClassifierOutput(ModelOutput):
59
+ """
60
+ Speech Classifier Output dataclass
61
+ """
62
+ loss: Optional[torch.FloatTensor] = None
63
+ logits: torch.FloatTensor = None
64
+ hidden_states: Optional[Tuple[torch.FloatTensor]] = None
65
+ attentions: Optional[Tuple[torch.FloatTensor]] = None
66
+
67
+
68
+ class ExHuBERTFeedForward(nn.Module):
69
+ def __init__(self, config):
70
+ super().__init__()
71
+ self.intermediate_dropout = nn.Dropout(config.activation_dropout)
72
+
73
+ self.intermediate_dense = nn.Linear(config.hidden_size, config.intermediate_size)
74
+ if isinstance(config.hidden_act, str):
75
+ self.intermediate_act_fn = ACT2FN[config.hidden_act]
76
+ else:
77
+ self.intermediate_act_fn = config.hidden_act
78
+
79
+ self.output_dense = nn.Linear(config.intermediate_size, config.hidden_size)
80
+ self.output_dropout = nn.Dropout(config.hidden_dropout)
81
+
82
+ def forward(self, hidden_states):
83
+ hidden_states = self.intermediate_dense(hidden_states)
84
+ hidden_states = self.intermediate_act_fn(hidden_states)
85
+ hidden_states = self.intermediate_dropout(hidden_states)
86
+
87
+ hidden_states = self.output_dense(hidden_states)
88
+ hidden_states = self.output_dropout(hidden_states)
89
+ return hidden_states
90
+
91
+
92
+ # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2EncoderLayer with Wav2Vec2->Hubert
93
+ class ExHuBERTEncoderLayer(nn.Module):
94
+ def __init__(self, config):
95
+ super().__init__()
96
+ self.attention = HubertAttention(
97
+ embed_dim=config.hidden_size,
98
+ num_heads=config.num_attention_heads,
99
+ dropout=config.attention_dropout,
100
+ is_decoder=False,
101
+ )
102
+ self.dropout = nn.Dropout(config.hidden_dropout)
103
+ self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
104
+ self.feed_forward = ExHuBERTFeedForward(config)
105
+ self.final_layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
106
+ self.gate_bb_linear = nn.Linear(config.hidden_size, config.hidden_size)
107
+
108
+ def forward(
109
+ self,
110
+ hidden_states: torch.Tensor,
111
+ attention_mask: Optional[torch.Tensor] = None,
112
+ output_attentions: bool = False,
113
+ ):
114
+ attn_residual = hidden_states
115
+ hidden_states = self.layer_norm(hidden_states)
116
+ hidden_states, attn_weights, _ = self.attention(
117
+ hidden_states, attention_mask=attention_mask, output_attentions=output_attentions
118
+ )
119
+ hidden_states = self.dropout(hidden_states)
120
+ hidden_states = attn_residual + hidden_states
121
+ hidden_states = hidden_states + self.feed_forward(self.final_layer_norm(hidden_states))
122
+
123
+ hidden_states = self.gate_bb_linear(hidden_states)
124
+ outputs = (hidden_states,)
125
+
126
+ if output_attentions:
127
+ outputs += (attn_weights,)
128
+
129
+ return outputs
130
+
131
+
132
+ class ExHuBERTEncoder(nn.Module):
133
+ def __init__(self, config):
134
+ super().__init__()
135
+ self.config = config
136
+ self.pos_conv_embed = HubertPositionalConvEmbedding(config)
137
+ self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
138
+ self.dropout = nn.Dropout(config.hidden_dropout)
139
+ self.layers = nn.ModuleList(
140
+ [ExHuBERTEncoderLayer(config) for _ in range(config.num_hidden_layers)]
141
+ )
142
+ self.gradient_checkpointing = False
143
+
144
+ def forward(
145
+ self,
146
+ hidden_states,
147
+ attention_mask=None,
148
+ output_attentions=False,
149
+ output_hidden_states=False,
150
+ return_dict=True,
151
+ ):
152
+ all_hidden_states = () if output_hidden_states else None
153
+ all_self_attentions = () if output_attentions else None
154
+
155
+ if attention_mask is not None:
156
+ # make sure padded tokens are not attended to
157
+ expand_attention_mask = attention_mask.unsqueeze(-1).repeat(1, 1, hidden_states.shape[2])
158
+ hidden_states[~expand_attention_mask] = 0
159
+
160
+ # extend attention_mask
161
+ attention_mask = 1.0 - attention_mask[:, None, None, :].to(dtype=hidden_states.dtype)
162
+ attention_mask = attention_mask * torch.finfo(hidden_states.dtype).min
163
+ attention_mask = attention_mask.expand(
164
+ attention_mask.shape[0], 1, attention_mask.shape[-1], attention_mask.shape[-1]
165
+ )
166
+
167
+ position_embeddings = self.pos_conv_embed(hidden_states)
168
+ hidden_states = hidden_states + position_embeddings
169
+ hidden_states = self.dropout(hidden_states)
170
+
171
+ deepspeed_zero3_is_enabled = is_deepspeed_zero3_enabled()
172
+
173
+ skip = torch.zeros_like(hidden_states)
174
+ skip_bool = False
175
+ for layer in self.layers:
176
+
177
+ if output_hidden_states:
178
+ all_hidden_states = all_hidden_states + (hidden_states,)
179
+
180
+ # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
181
+ dropout_probability = torch.rand([])
182
+
183
+ # skip_the_layer = True if self.training and (dropout_probability < self.config.layerdrop) else False
184
+ skip_the_layer = False
185
+ if not skip_the_layer or deepspeed_zero3_is_enabled:
186
+ # under deepspeed zero3 all gpus must run in sync
187
+ # XXX: could optimize this like synced_gpus in generate_utils but not sure if it's worth the code complication
188
+ if self.gradient_checkpointing and self.training:
189
+ # create gradient checkpointing function
190
+ def create_custom_forward(module):
191
+ def custom_forward(*inputs):
192
+ return module(*inputs, output_attentions)
193
+
194
+ return custom_forward
195
+
196
+ layer_outputs = torch.utils.checkpoint.checkpoint(
197
+ create_custom_forward(layer),
198
+ hidden_states,
199
+ attention_mask,
200
+ )
201
+ else:
202
+ layer_outputs = layer(
203
+ hidden_states, attention_mask=attention_mask, output_attentions=output_attentions
204
+ )
205
+ hidden_states = layer_outputs[0]
206
+
207
+ if skip_the_layer:
208
+ layer_outputs = (None, None)
209
+
210
+ if output_attentions:
211
+ all_self_attentions = all_self_attentions + (layer_outputs[1],)
212
+ if skip_bool is True:
213
+ hidden_states = hidden_states + skip
214
+
215
+ skip_bool = False
216
+ else:
217
+ skip = hidden_states
218
+ skip_bool = True
219
+
220
+ hidden_states = self.layer_norm(hidden_states)
221
+
222
+ if output_hidden_states:
223
+ all_hidden_states = all_hidden_states + (hidden_states,)
224
+
225
+ if not return_dict:
226
+ return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None)
227
+ return BaseModelOutput(
228
+ last_hidden_state=hidden_states,
229
+ hidden_states=all_hidden_states,
230
+ attentions=all_self_attentions,
231
+ )
232
+
233
+
234
+ class ExHuBERT_model_(HubertPreTrainedModel):
235
+ def __init__(self, config: HubertConfig):
236
+ super().__init__(config)
237
+ setattr(config, 'num_hidden_layers', 48)
238
+ self.config = config
239
+ self.feature_extractor = HubertFeatureEncoder(config)
240
+ self.feature_projection = HubertFeatureProjection(config)
241
+
242
+ if config.mask_time_prob > 0.0 or config.mask_feature_prob > 0.0:
243
+ self.masked_spec_embed = nn.Parameter(torch.FloatTensor(config.hidden_size).uniform_())
244
+
245
+ self.encoder = ExHuBERTEncoder(config)
246
+
247
+ # Initialize weights and apply final processing
248
+ self.post_init()
249
+
250
+ # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2Model._mask_hidden_states
251
+ def _mask_hidden_states(
252
+ self,
253
+ hidden_states: torch.FloatTensor,
254
+ mask_time_indices: Optional[torch.FloatTensor] = None,
255
+ attention_mask: Optional[torch.LongTensor] = None,
256
+ ):
257
+ """
258
+ Masks extracted features along time axis and/or along feature axis according to
259
+ [SpecAugment](https://arxiv.org/abs/1904.08779).
260
+ """
261
+
262
+ # `config.apply_spec_augment` can set masking to False
263
+ if not getattr(self.config, "apply_spec_augment", True):
264
+ return hidden_states
265
+
266
+ # generate indices & apply SpecAugment along time axis
267
+ batch_size, sequence_length, hidden_size = hidden_states.size()
268
+
269
+ if mask_time_indices is not None:
270
+ # apply SpecAugment along time axis with given mask_time_indices
271
+ hidden_states[mask_time_indices] = self.masked_spec_embed.to(hidden_states.dtype)
272
+ elif self.config.mask_time_prob > 0 and self.training:
273
+ mask_time_indices = _compute_mask_indices(
274
+ (batch_size, sequence_length),
275
+ mask_prob=self.config.mask_time_prob,
276
+ mask_length=self.config.mask_time_length,
277
+ attention_mask=attention_mask,
278
+ min_masks=self.config.mask_time_min_masks,
279
+ )
280
+ mask_time_indices = torch.tensor(mask_time_indices, device=hidden_states.device, dtype=torch.bool)
281
+ hidden_states[mask_time_indices] = self.masked_spec_embed.to(hidden_states.dtype)
282
+
283
+ if self.config.mask_feature_prob > 0 and self.training:
284
+ # generate indices & apply SpecAugment along feature axis
285
+ mask_feature_indices = _compute_mask_indices(
286
+ (batch_size, hidden_size),
287
+ mask_prob=self.config.mask_feature_prob,
288
+ mask_length=self.config.mask_feature_length,
289
+ min_masks=self.config.mask_feature_min_masks,
290
+ )
291
+ mask_feature_indices = torch.tensor(mask_feature_indices, device=hidden_states.device, dtype=torch.bool)
292
+ mask_feature_indices = mask_feature_indices[:, None].expand(-1, sequence_length, -1)
293
+ hidden_states[mask_feature_indices] = 0
294
+
295
+ return hidden_states
296
+
297
+ def forward(
298
+ self,
299
+ input_values: Optional[torch.Tensor],
300
+ attention_mask: Optional[torch.Tensor] = None,
301
+ mask_time_indices: Optional[torch.FloatTensor] = None,
302
+ output_attentions: Optional[bool] = None,
303
+ output_hidden_states: Optional[bool] = None,
304
+ return_dict: Optional[bool] = None,
305
+ ) -> Union[Tuple, BaseModelOutput]:
306
+
307
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
308
+ output_hidden_states = (
309
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
310
+ )
311
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
312
+
313
+ extract_features = self.feature_extractor(input_values)
314
+ extract_features = extract_features.transpose(1, 2)
315
+
316
+ if attention_mask is not None:
317
+ # compute reduced attention_mask corresponding to feature vectors
318
+ attention_mask = self._get_feature_vector_attention_mask(extract_features.shape[1], attention_mask)
319
+
320
+ hidden_states = self.feature_projection(extract_features)
321
+ hidden_states = self._mask_hidden_states(hidden_states, mask_time_indices=mask_time_indices)
322
+
323
+ encoder_outputs = self.encoder(
324
+ hidden_states,
325
+ attention_mask=attention_mask,
326
+ output_attentions=output_attentions,
327
+ output_hidden_states=output_hidden_states,
328
+ return_dict=return_dict,
329
+ )
330
+
331
+ hidden_states = encoder_outputs[0]
332
+
333
+ if not return_dict:
334
+ return (hidden_states,) + encoder_outputs[1:]
335
+
336
+ return BaseModelOutput(
337
+ last_hidden_state=hidden_states,
338
+ hidden_states=encoder_outputs.hidden_states,
339
+ attentions=encoder_outputs.attentions,
340
+ )
341
+
342
+
343
+ class ExHuBERT(HubertPreTrainedModel,PyTorchModelHubMixin):
344
+ def __init__(self, config):
345
+ super().__init__(config)
346
+ setattr(config, "num_labels", 6)
347
+ if hasattr(config, "add_adapter") and config.add_adapter:
348
+ raise ValueError(
349
+ "Sequence classification does not support the use of Hubert adapters (config.add_adapter=True)"
350
+ )
351
+ self.hubert = ExHuBERT_model_(config)
352
+ num_layers = config.num_hidden_layers + 1 # transformer layers + input embeddings
353
+ if config.use_weighted_layer_sum:
354
+ self.layer_weights = nn.Parameter(torch.ones(num_layers) / num_layers)
355
+ self.projector = nn.Linear(config.hidden_size, config.classifier_proj_size)
356
+ self.classifier = nn.Linear(config.classifier_proj_size, config.num_labels)
357
+
358
+ # Initialize weights and apply final processing
359
+ self.post_init()
360
+
361
+ def freeze_feature_encoder(self):
362
+ """
363
+ Calling this function will disable the gradient computation for the feature encoder so that its parameter will
364
+ not be updated during training.
365
+ """
366
+ self.hubert.feature_extractor._freeze_parameters()
367
+
368
+ def freeze_base_model(self):
369
+ """
370
+ Calling this function will disable the gradient computation for the base model so that its parameters will not
371
+ be updated during training. Only the classification head will be updated.
372
+ """
373
+ for param in self.hubert.parameters():
374
+ param.requires_grad = False
375
+
376
+ def forward(
377
+ self,
378
+ input_values: Optional[torch.Tensor],
379
+ attention_mask: Optional[torch.Tensor] = None,
380
+ output_attentions: Optional[bool] = None,
381
+ output_hidden_states: Optional[bool] = None,
382
+ return_dict: Optional[bool] = None,
383
+ labels: Optional[torch.Tensor] = None,
384
+ ) -> Union[Tuple, SpeechClassifierOutput]:
385
+ r"""
386
+ labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
387
+ Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
388
+ config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
389
+ `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
390
+ """
391
+
392
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
393
+ output_hidden_states = True if self.config.use_weighted_layer_sum else output_hidden_states
394
+
395
+ outputs = self.hubert(
396
+ input_values,
397
+ attention_mask=attention_mask,
398
+ output_attentions=output_attentions,
399
+ output_hidden_states=output_hidden_states,
400
+ return_dict=return_dict,
401
+ )
402
+
403
+ if self.config.use_weighted_layer_sum:
404
+ hidden_states = outputs[_HIDDEN_STATES_START_POSITION]
405
+ hidden_states = torch.stack(hidden_states, dim=1)
406
+ norm_weights = nn.functional.softmax(self.layer_weights, dim=-1)
407
+ hidden_states = (hidden_states * norm_weights.view(-1, 1, 1)).sum(dim=1)
408
+ else:
409
+ hidden_states = outputs[0]
410
+
411
+ hidden_states = self.projector(hidden_states)
412
+ if attention_mask is None:
413
+ pooled_output = hidden_states.mean(dim=1)
414
+ else:
415
+ padding_mask = self._get_feature_vector_attention_mask(hidden_states.shape[1], attention_mask)
416
+ hidden_states[~padding_mask] = 0.0
417
+ pooled_output = hidden_states.sum(dim=1) / padding_mask.sum(dim=1).view(-1, 1)
418
+
419
+ logits = self.classifier(pooled_output)
420
+
421
+ loss = None
422
+
423
+ if not return_dict:
424
+ output = (logits,) + outputs[_HIDDEN_STATES_START_POSITION:]
425
+ return ((loss,) + output) if loss is not None else output
426
+
427
+ return SpeechClassifierOutput(
428
+ loss=loss,
429
+ logits=logits,
430
+ hidden_states=outputs.hidden_states,
431
+ attentions=outputs.attentions,
432
+ )
433
+
434
+ def freeze_og_encoder(self):
435
+ for param in self.hubert.encoder.layers[::2].parameters():
436
+ param.requires_grad = False
437
+
438
+ def print_trainable_parameters(model):
439
+ '''
440
+ prints all trainable parameters of a model
441
+ '''
442
+ trainable_params = 0
443
+ all_param = 0
444
+ for _, param in model.named_parameters():
445
+ all_param += param.numel()
446
+ if param.requires_grad:
447
+ trainable_params += param.numel()
448
+ print(
449
+ f"trainable params: {trainable_params:,d} || all params: {all_param:,d} || trainable%: {100 * trainable_params / all_param:.2f}"
450
+ )
451
+
config.json CHANGED
@@ -6,6 +6,9 @@
6
  "ExHuBERT"
7
  ],
8
  "attention_dropout": 0.1,
 
 
 
9
  "bos_token_id": 1,
10
  "classifier_proj_size": 256,
11
  "conv_bias": true,
 
6
  "ExHuBERT"
7
  ],
8
  "attention_dropout": 0.1,
9
+ "auto_map": {
10
+ "AutoModelForAudioClassification": "ExHuBERT_model.ExHuBERT"
11
+ },
12
  "bos_token_id": 1,
13
  "classifier_proj_size": 256,
14
  "conv_bias": true,