"""ExHuBERT: a HuBERT-based speech classifier with an extended encoder.

The encoder stack is fixed to 48 layers. Layers are processed in pairs: the
output of the first layer of each pair is added to the output of the second
one (see ``ExHuBERTEncoder.forward``). Every encoder layer ends with an extra
linear projection (``gate_bb_linear``).
"""
from dataclasses import dataclass
from typing import Optional, Tuple, Union

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint  # explicit: `checkpoint` is used by the encoder
from huggingface_hub import PyTorchModelHubMixin
from transformers import HubertForSequenceClassification
from transformers.activations import ACT2FN
from transformers.file_utils import ModelOutput
from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled
from transformers.modeling_outputs import BaseModelOutput
from transformers.models.hubert import HubertConfig
from transformers.models.hubert.modeling_hubert import (
    HubertAttention,
    HubertFeatureEncoder,
    HubertFeatureProjection,
    HubertPositionalConvEmbedding,
    HubertPreTrainedModel,
    _compute_mask_indices,
)

# Index of the hidden-states tuple inside the base model's output.
_HIDDEN_STATES_START_POSITION = 1

# General docstring
_CONFIG_FOR_DOC = "HubertConfig"

# Base docstring
_CHECKPOINT_FOR_DOC = "facebook/hubert-large-ls960-ft"
_EXPECTED_OUTPUT_SHAPE = [1, 292, 768]

# CTC docstring
_CTC_EXPECTED_OUTPUT = "'MISTER QUILTER IS THE APOSTLE OF THE MIDDLE CLASSES AND WE ARE GLAD TO WELCOME HIS GOSPEL'"
_CTC_EXPECTED_LOSS = 22.68

# Audio class docstring
_SEQ_CLASS_CHECKPOINT = "superb/hubert-base-superb-ks"
_SEQ_CLASS_EXPECTED_OUTPUT = "'_unknown_'"
_SEQ_CLASS_EXPECTED_LOSS = 8.53

HUBERT_PRETRAINED_MODEL_ARCHIVE_LIST = [
    "facebook/hubert-base-ls960",
    # See all Hubert models at https://huggingface.co/models?filter=hubert
]


# SwiGLU function
# From "GLU Variants Improve Transformer"
# https://doi.org/10.48550/arXiv.2002.05202
class SwiGLU(nn.Module):
    """SwiGLU activation: split the last dim in two halves and gate one with SiLU."""

    def forward(self, x):
        """Return ``silu(gate) * value`` where ``value, gate = x.chunk(2, dim=-1)``."""
        value, gate = x.chunk(2, dim=-1)
        return F.silu(gate) * value


@dataclass
class SpeechClassifierOutput(ModelOutput):
    """
    Speech Classifier Output dataclass
    """

    # Classification loss; only set when labels are passed to the model.
    loss: Optional[torch.FloatTensor] = None
    # Raw (pre-softmax) class scores produced by the classifier head.
    logits: Optional[torch.FloatTensor] = None
    # Per-layer hidden states, present when requested from the base model.
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    # Per-layer attention weights, present when requested from the base model.
    attentions: Optional[Tuple[torch.FloatTensor]] = None


class ExHuBERTFeedForward(nn.Module):
    """Position-wise feed-forward block: dense -> activation -> dropout -> dense -> dropout."""

    def __init__(self, config):
        super().__init__()
        self.intermediate_dropout = nn.Dropout(config.activation_dropout)

        self.intermediate_dense = nn.Linear(config.hidden_size, config.intermediate_size)
        # `hidden_act` may be either the name of a registered activation or a callable.
        if isinstance(config.hidden_act, str):
            self.intermediate_act_fn = ACT2FN[config.hidden_act]
        else:
            self.intermediate_act_fn = config.hidden_act

        self.output_dense = nn.Linear(config.intermediate_size, config.hidden_size)
        self.output_dropout = nn.Dropout(config.hidden_dropout)

    def forward(self, hidden_states):
        """Apply the feed-forward block and return a tensor with the input's last dim."""
        hidden_states = self.intermediate_dense(hidden_states)
        hidden_states = self.intermediate_act_fn(hidden_states)
        hidden_states = self.intermediate_dropout(hidden_states)

        hidden_states = self.output_dense(hidden_states)
        hidden_states = self.output_dropout(hidden_states)
        return hidden_states


# Adapted from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2EncoderLayer with Wav2Vec2->Hubert
# (this variant appends an extra `gate_bb_linear` projection).
class ExHuBERTEncoderLayer(nn.Module):
    """Transformer encoder layer with a trailing linear projection.

    Layer norm is applied before self-attention and before the feed-forward
    block; both sub-blocks are wrapped in residual connections. The result is
    finally passed through ``gate_bb_linear`` (hidden_size -> hidden_size).
    """

    def __init__(self, config):
        super().__init__()
        self.attention = HubertAttention(
            embed_dim=config.hidden_size,
            num_heads=config.num_attention_heads,
            dropout=config.attention_dropout,
            is_decoder=False,
        )
        self.dropout = nn.Dropout(config.hidden_dropout)
        self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.feed_forward = ExHuBERTFeedForward(config)
        self.final_layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.gate_bb_linear = nn.Linear(config.hidden_size, config.hidden_size)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        output_attentions: bool = False,
    ):
        """Run the layer.

        Returns:
            A tuple ``(hidden_states,)`` or ``(hidden_states, attn_weights)``
            when ``output_attentions`` is true.
        """
        attn_residual = hidden_states
        hidden_states = self.layer_norm(hidden_states)
        hidden_states, attn_weights, _ = self.attention(
            hidden_states, attention_mask=attention_mask, output_attentions=output_attentions
        )
        hidden_states = self.dropout(hidden_states)
        hidden_states = attn_residual + hidden_states
        hidden_states = hidden_states + self.feed_forward(self.final_layer_norm(hidden_states))
        hidden_states = self.gate_bb_linear(hidden_states)

        outputs = (hidden_states,)

        if output_attentions:
            outputs += (attn_weights,)

        return outputs


class ExHuBERTEncoder(nn.Module):
    """Stack of ``ExHuBERTEncoderLayer`` modules with pairwise skip connections."""

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.pos_conv_embed = HubertPositionalConvEmbedding(config)
        self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout)
        self.layers = nn.ModuleList(
            [ExHuBERTEncoderLayer(config) for _ in range(config.num_hidden_layers)]
        )
        self.gradient_checkpointing = False

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        output_attentions=False,
        output_hidden_states=False,
        return_dict=True,
    ):
        """Encode ``hidden_states``.

        Args:
            hidden_states: Projected features; indexed as (batch, time, hidden).
                Note that padded positions are zeroed in place when
                ``attention_mask`` is given.
            attention_mask: Optional boolean mask over (batch, time); true
                marks positions to keep.
            output_attentions: Collect the attention weights of every layer.
            output_hidden_states: Collect the input of every layer plus the
                final normalised output.
            return_dict: Return a ``BaseModelOutput`` instead of a plain tuple.

        Returns:
            ``BaseModelOutput`` or a tuple of the non-``None`` entries of
            ``(last_hidden_state, hidden_states, attentions)``.
        """
        all_hidden_states = () if output_hidden_states else None
        all_self_attentions = () if output_attentions else None

        if attention_mask is not None:
            # make sure padded tokens are not attended to
            expand_attention_mask = attention_mask.unsqueeze(-1).repeat(1, 1, hidden_states.shape[2])
            hidden_states[~expand_attention_mask] = 0

            # extend attention_mask to an additive (batch, 1, time, time) mask
            attention_mask = 1.0 - attention_mask[:, None, None, :].to(dtype=hidden_states.dtype)
            attention_mask = attention_mask * torch.finfo(hidden_states.dtype).min
            attention_mask = attention_mask.expand(
                attention_mask.shape[0], 1, attention_mask.shape[-1], attention_mask.shape[-1]
            )

        position_embeddings = self.pos_conv_embed(hidden_states)
        hidden_states = hidden_states + position_embeddings
        hidden_states = self.dropout(hidden_states)

        deepspeed_zero3_is_enabled = is_deepspeed_zero3_enabled()

        # Output of the first layer of the current pair, waiting to be added to
        # the output of the second layer; `None` means "no pair is open".
        pending_skip = None

        for layer in self.layers:
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)

            # LayerDrop (https://arxiv.org/abs/1909.11556) is disabled here:
            # no layer is ever skipped. The random draw is still performed so
            # the global RNG stream is the same as before this clean-up.
            torch.rand([])
            skip_the_layer = False

            if not skip_the_layer or deepspeed_zero3_is_enabled:
                # under deepspeed zero3 all gpus must run in sync
                if self.gradient_checkpointing and self.training:
                    # create gradient checkpointing function
                    def create_custom_forward(module):
                        def custom_forward(*inputs):
                            return module(*inputs, output_attentions)

                        return custom_forward

                    layer_outputs = torch.utils.checkpoint.checkpoint(
                        create_custom_forward(layer),
                        hidden_states,
                        attention_mask,
                    )
                else:
                    layer_outputs = layer(
                        hidden_states, attention_mask=attention_mask, output_attentions=output_attentions
                    )
                hidden_states = layer_outputs[0]

            if skip_the_layer:
                layer_outputs = (None, None)

            if output_attentions:
                all_self_attentions = all_self_attentions + (layer_outputs[1],)

            if pending_skip is not None:
                # second layer of a pair: add the first layer's output
                hidden_states = hidden_states + pending_skip
                pending_skip = None
            else:
                # first layer of a pair: remember its output
                pending_skip = hidden_states

        hidden_states = self.layer_norm(hidden_states)

        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        if not return_dict:
            return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None)
        return BaseModelOutput(
            last_hidden_state=hidden_states,
            hidden_states=all_hidden_states,
            attentions=all_self_attentions,
        )


class ExHuBERT_model_(HubertPreTrainedModel):
    """Base model: feature encoder -> feature projection -> ``ExHuBERTEncoder``.

    The constructor overwrites ``config.num_hidden_layers`` with 48.
    """

    def __init__(self, config: HubertConfig):
        super().__init__(config)
        # The encoder depth is fixed, whatever the incoming config says.
        config.num_hidden_layers = 48
        self.config = config
        self.feature_extractor = HubertFeatureEncoder(config)
        self.feature_projection = HubertFeatureProjection(config)

        if config.mask_time_prob > 0.0 or config.mask_feature_prob > 0.0:
            self.masked_spec_embed = nn.Parameter(torch.FloatTensor(config.hidden_size).uniform_())

        self.encoder = ExHuBERTEncoder(config)

        # Initialize weights and apply final processing
        self.post_init()

    # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2Model._mask_hidden_states
    def _mask_hidden_states(
        self,
        hidden_states: torch.FloatTensor,
        mask_time_indices: Optional[torch.FloatTensor] = None,
        attention_mask: Optional[torch.LongTensor] = None,
    ):
        """
        Masks extracted features along time axis and/or along feature axis according to
        [SpecAugment](https://arxiv.org/abs/1904.08779).

        The input tensor is modified in place and returned.
        """

        # `config.apply_spec_augment` can set masking to False
        if not getattr(self.config, "apply_spec_augment", True):
            return hidden_states

        # generate indices & apply SpecAugment along time axis
        batch_size, sequence_length, hidden_size = hidden_states.size()

        if mask_time_indices is not None:
            # apply SpecAugment along time axis with given mask_time_indices
            hidden_states[mask_time_indices] = self.masked_spec_embed.to(hidden_states.dtype)
        elif self.config.mask_time_prob > 0 and self.training:
            mask_time_indices = _compute_mask_indices(
                (batch_size, sequence_length),
                mask_prob=self.config.mask_time_prob,
                mask_length=self.config.mask_time_length,
                attention_mask=attention_mask,
                min_masks=self.config.mask_time_min_masks,
            )
            mask_time_indices = torch.tensor(mask_time_indices, device=hidden_states.device, dtype=torch.bool)
            hidden_states[mask_time_indices] = self.masked_spec_embed.to(hidden_states.dtype)

        if self.config.mask_feature_prob > 0 and self.training:
            # generate indices & apply SpecAugment along feature axis
            mask_feature_indices = _compute_mask_indices(
                (batch_size, hidden_size),
                mask_prob=self.config.mask_feature_prob,
                mask_length=self.config.mask_feature_length,
                min_masks=self.config.mask_feature_min_masks,
            )
            mask_feature_indices = torch.tensor(mask_feature_indices, device=hidden_states.device, dtype=torch.bool)
            mask_feature_indices = mask_feature_indices[:, None].expand(-1, sequence_length, -1)
            hidden_states[mask_feature_indices] = 0

        return hidden_states

    def forward(
        self,
        input_values: Optional[torch.Tensor],
        attention_mask: Optional[torch.Tensor] = None,
        mask_time_indices: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, BaseModelOutput]:
        """Extract features from ``input_values`` and run them through the encoder.

        ``output_attentions``, ``output_hidden_states`` and ``return_dict``
        fall back to the corresponding config values when left as ``None``.

        Returns:
            ``BaseModelOutput`` when ``return_dict`` is true, otherwise a tuple
            starting with the last hidden state.
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        extract_features = self.feature_extractor(input_values)
        extract_features = extract_features.transpose(1, 2)

        if attention_mask is not None:
            # compute reduced attention_mask corresponding to feature vectors
            attention_mask = self._get_feature_vector_attention_mask(extract_features.shape[1], attention_mask)

        hidden_states = self.feature_projection(extract_features)
        hidden_states = self._mask_hidden_states(hidden_states, mask_time_indices=mask_time_indices)

        encoder_outputs = self.encoder(
            hidden_states,
            attention_mask=attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        hidden_states = encoder_outputs[0]

        if not return_dict:
            return (hidden_states,) + encoder_outputs[1:]

        return BaseModelOutput(
            last_hidden_state=hidden_states,
            hidden_states=encoder_outputs.hidden_states,
            attentions=encoder_outputs.attentions,
        )


class ExHuBERT(HubertPreTrainedModel, PyTorchModelHubMixin):
    """Sequence classifier on top of ``ExHuBERT_model_``.

    The head is: linear projector -> (masked) mean pooling over time -> linear
    classifier. The constructor sets ``config.num_labels`` to 6.
    """

    def __init__(self, config):
        super().__init__(config)
        config.num_labels = 6
        if hasattr(config, "add_adapter") and config.add_adapter:
            raise ValueError(
                "Sequence classification does not support the use of Hubert adapters (config.add_adapter=True)"
            )
        self.hubert = ExHuBERT_model_(config)
        # Read after the base model is built, which fixes `num_hidden_layers`.
        num_layers = config.num_hidden_layers + 1  # transformer layers + input embeddings
        if config.use_weighted_layer_sum:
            self.layer_weights = nn.Parameter(torch.ones(num_layers) / num_layers)
        self.projector = nn.Linear(config.hidden_size, config.classifier_proj_size)
        self.classifier = nn.Linear(config.classifier_proj_size, config.num_labels)

        # Initialize weights and apply final processing
        self.post_init()

    def freeze_feature_encoder(self):
        """
        Calling this function will disable the gradient computation for the feature encoder so that its parameter will
        not be updated during training.
        """
        self.hubert.feature_extractor._freeze_parameters()

    def freeze_base_model(self):
        """
        Calling this function will disable the gradient computation for the base model so that its parameters will not
        be updated during training. Only the classification head will be updated.
        """
        for param in self.hubert.parameters():
            param.requires_grad = False

    def forward(
        self,
        input_values: Optional[torch.Tensor],
        attention_mask: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        labels: Optional[torch.Tensor] = None,
    ) -> Union[Tuple, SpeechClassifierOutput]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the sequence classification loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. When given, a Cross-Entropy loss is computed and returned; otherwise the
            returned loss is `None`.
        """

        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        # The weighted layer sum needs every layer's hidden state.
        output_hidden_states = True if self.config.use_weighted_layer_sum else output_hidden_states

        outputs = self.hubert(
            input_values,
            attention_mask=attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        if self.config.use_weighted_layer_sum:
            # softmax-weighted sum over the stacked per-layer hidden states
            hidden_states = outputs[_HIDDEN_STATES_START_POSITION]
            hidden_states = torch.stack(hidden_states, dim=1)
            norm_weights = nn.functional.softmax(self.layer_weights, dim=-1)
            hidden_states = (hidden_states * norm_weights.view(-1, 1, 1)).sum(dim=1)
        else:
            hidden_states = outputs[0]

        hidden_states = self.projector(hidden_states)
        if attention_mask is None:
            pooled_output = hidden_states.mean(dim=1)
        else:
            # mean over the non-padded frames only
            padding_mask = self._get_feature_vector_attention_mask(hidden_states.shape[1], attention_mask)
            hidden_states[~padding_mask] = 0.0
            pooled_output = hidden_states.sum(dim=1) / padding_mask.sum(dim=1).view(-1, 1)

        logits = self.classifier(pooled_output)

        loss = None
        if labels is not None:
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(logits.view(-1, self.config.num_labels), labels.view(-1))

        if not return_dict:
            output = (logits,) + outputs[_HIDDEN_STATES_START_POSITION:]
            return ((loss,) + output) if loss is not None else output

        return SpeechClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

    def freeze_og_encoder(self):
        """Disable gradients for every second encoder layer (indices 0, 2, 4, ...)."""
        for param in self.hubert.encoder.layers[::2].parameters():
            param.requires_grad = False

    def print_trainable_parameters(self):
        """Print the trainable/total parameter counts of this model."""
        # Resolves to the module-level function of the same name.
        print_trainable_parameters(self)


def print_trainable_parameters(model):
    """
    prints all trainable parameters of a model

    Prints the number of parameters with ``requires_grad`` set, the total
    number of parameters, and the trainable percentage (``0.00`` for a model
    without parameters).
    """
    trainable_params = 0
    all_param = 0
    for _, param in model.named_parameters():
        all_param += param.numel()
        if param.requires_grad:
            trainable_params += param.numel()
    # Guard against parameter-free modules to avoid a ZeroDivisionError.
    trainable_pct = 100 * trainable_params / all_param if all_param else 0.0
    print(
        f"trainable params: {trainable_params:,d} || all params: {all_param:,d} || trainable%: {trainable_pct:.2f}"
    )