suayptalha committed on
Commit
64d52e0
1 Parent(s): 94f36cd

Create modeling_minGRULM.py

Files changed (1)
  1. modeling_minGRULM.py +161 -0
modeling_minGRULM.py ADDED
@@ -0,0 +1,161 @@
import torch
from torch import nn
from transformers import PreTrainedModel, AutoModelForCausalLM, AutoTokenizer
from transformers.modeling_outputs import CausalLMOutputWithPast
from torch.nn import CrossEntropyLoss
from typing import Optional
import os
from .configuration_minGRULM import MinGRULMConfig
from minGRU_pytorch.minGRULM import minGRULM


class MinGRULMWrapped(nn.Module):
    def __init__(self, min_gru_model):
        super().__init__()
        self.min_gru_model = min_gru_model
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    def forward(self, *args, **kwargs):
        args = [arg.to(self.device) if isinstance(arg, torch.Tensor) else arg for arg in args]
        kwargs = {k: v.to(self.device) if isinstance(v, torch.Tensor) else v for k, v in kwargs.items()}
        return self.min_gru_model(*args, **kwargs)

    def to(self, device):
        self.device = device
        self.min_gru_model.to(device)
        return self


class MinGRULMPreTrainedModel(PreTrainedModel):
    config_class = MinGRULMConfig
    base_model_prefix = "model"

    def _init_weights(self, module):
        std = self.config.initializer_range
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)

        for name, param in module.named_parameters():
            if torch.isnan(param).any():
                print(f"NaN detected in parameter {name}. Replacing with a safe number.")
                param.data = torch.nan_to_num(param.data, nan=1e-6)


class MinGRULMForCausalLM(PreTrainedModel):
    config_class = MinGRULMConfig
    base_model_prefix = "model"

    def __init__(self, config: MinGRULMConfig):
        super().__init__(config)

        # Build the raw minGRULM backbone from the configuration.
        raw_min_gru = minGRULM(
            num_tokens=config.vocab_size,
            dim=config.d_model,
            depth=config.n_layer,
            ff_mult=config.ff_mult,
            min_gru_expansion=config.min_gru_expansion,
            enable_conv=config.enable_conv,
        )
        self.model = MinGRULMWrapped(raw_min_gru)

        self.lm_head = nn.Linear(config.d_model, config.vocab_size, bias=False)

        self.post_init()

    def post_init(self):
        super().post_init()
        self.tie_weights()

    def tie_weights(self):
        # Tie the output projection to the token embedding of the wrapped model.
        self.lm_head.weight = self.model.min_gru_model.token_emb.weight

    def get_input_embeddings(self):
        return self.model.min_gru_model.token_emb

    def set_input_embeddings(self, value):
        self.model.min_gru_model.token_emb = value

    def get_output_embeddings(self):
        return self.lm_head

    def prepare_inputs_for_generation(self, input_ids: torch.LongTensor, **kwargs):
        return {"input_ids": input_ids, "attention_mask": kwargs.get("attention_mask", None)}

    def forward(self, input_ids: torch.LongTensor, labels: Optional[torch.LongTensor] = None, return_dict: Optional[bool] = True, **kwargs):
        logits = self.model(input_ids)

        if torch.isnan(logits).any():
            print("NaN detected in logits! Replacing with a safe number.")
            logits = torch.nan_to_num(logits, nan=1e-6)

        loss = None
        if labels is not None:
            # Shift so that tokens < n predict token n.
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            loss_fct = CrossEntropyLoss()
            loss = loss_fct(
                shift_logits.view(-1, self.config.vocab_size),
                shift_labels.view(-1),
            )

            # The NaN guard only applies when a loss was actually computed.
            if torch.isnan(loss).any():
                print("NaN detected in loss! Replacing with a safe number.")
                loss = torch.nan_to_num(loss, nan=1e-6)

        if not return_dict:
            return (loss, logits) if loss is not None else (logits,)

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
        )

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
        """
        Load model from a pretrained checkpoint.
        """
        model = super().from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
        return model

    def save_pretrained(self, save_directory, safe_serialization: Optional[bool] = True, **kwargs):
        """
        Save the model and configuration to a directory.

        Args:
            save_directory (str): Directory to save the model.
            safe_serialization (bool, optional): Whether to use safe serialization. Defaults to True.
            kwargs: Additional arguments like max_shard_size (ignored in this implementation).
        """
        os.makedirs(save_directory, exist_ok=True)

        if safe_serialization:
            # Note: this branch still writes a plain torch checkpoint (pytorch_model.bin),
            # not a safetensors file.
            print("Saving with safe serialization.")

            state_dict = {}

            # Backbone parameters go under the "model." prefix.
            for name, param in self.model.min_gru_model.named_parameters():
                state_dict[f"model.{name}"] = param

            # The output projection is self.lm_head (there is no separate classifier module).
            for name, param in self.lm_head.named_parameters():
                state_dict[f"lm_head.{name}"] = param

            state_dict['config'] = self.config.__dict__
            torch.save(state_dict, os.path.join(save_directory, "pytorch_model.bin"))

            self.config.save_pretrained(save_directory)
        else:
            print("Saving without safe serialization.")
            super().save_pretrained(save_directory)
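
For a quick sanity check of the new module, a minimal usage sketch follows. It is only a sketch: it assumes MinGRULMConfig accepts the fields referenced in __init__ above (vocab_size, d_model, n_layer, ff_mult, min_gru_expansion, enable_conv) as keyword arguments, that the minGRU_pytorch package is installed, and that both files are importable as a package; the mingru_lm package name and all concrete values below are placeholders, not part of this commit.

# Hypothetical smoke test for MinGRULMForCausalLM; package name and config values are placeholders.
import torch

from mingru_lm.configuration_minGRULM import MinGRULMConfig
from mingru_lm.modeling_minGRULM import MinGRULMForCausalLM

config = MinGRULMConfig(
    vocab_size=32000,
    d_model=512,
    n_layer=6,
    ff_mult=4,
    min_gru_expansion=1.5,
    enable_conv=False,
)

# MinGRULMWrapped moves inputs to CUDA when it is available, so keep the
# parameters on the same device as the inputs.
device = "cuda" if torch.cuda.is_available() else "cpu"
model = MinGRULMForCausalLM(config).to(device)

# Dummy batch: 2 sequences of 16 token ids, reused as labels for the LM loss.
input_ids = torch.randint(0, config.vocab_size, (2, 16), device=device)
outputs = model(input_ids=input_ids, labels=input_ids)

print(outputs.loss)          # scalar cross-entropy loss
print(outputs.logits.shape)  # torch.Size([2, 16, 32000])

Because tie_weights points lm_head.weight at the backbone's token_emb.weight, the input embedding and output projection share a single parameter matrix in this sketch.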