Deci
/

itay-levy commited on
Commit
7a2815a
·
1 Parent(s): 1aed8fa

Upload modeling_decicoder.py with huggingface_hub (#4)

Browse files

- Upload modeling_decicoder.py with huggingface_hub (2ac67e2c0ac46c68d36924967112c645ab367a2b)

Files changed (1) hide show
  1. modeling_decicoder.py +246 -0
modeling_decicoder.py ADDED
@@ -0,0 +1,246 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright and license here
3
+ """ PyTorch DeciCoder model."""
4
+ import math
5
+ from typing import Optional, Tuple
6
+
7
+ import torch
8
+ import torch.nn.functional as F
9
+ import torch.utils.checkpoint
10
+ from torch import nn
11
+ from transformers.models.llama.modeling_llama import LlamaMLP, LlamaRMSNorm, LlamaAttention, apply_rotary_pos_emb, \
12
+ repeat_kv, LlamaPreTrainedModel, LLAMA_START_DOCSTRING, LlamaDecoderLayer, LlamaForCausalLM, LlamaModel
13
+ from transformers.utils import add_start_docstrings
14
+
15
+ from .configuration_decicoder import DeciCoderConfig
16
+
17
+ _CONFIG_FOR_DOC = "DeciCoderConfig"
18
+
19
+
20
+ class DeciCoderAttention(LlamaAttention):
21
+ """Multi-headed attention from 'Attention Is All You Need' paper"""
22
+
23
+ def __init__(self, config: DeciCoderConfig):
24
+ nn.Module.__init__(self)
25
+ self.config = config
26
+ self.hidden_size = config.hidden_size
27
+ self.num_heads = config.num_attention_heads
28
+ self.head_dim = self.hidden_size // self.num_heads
29
+ self.num_key_value_heads = config.num_key_value_heads
30
+ self.num_key_value_groups = self.num_heads // self.num_key_value_heads
31
+ self.pretraining_tp = config.pretraining_tp
32
+ self.max_position_embeddings = config.max_position_embeddings
33
+
34
+ if (self.head_dim * self.num_heads) != self.hidden_size:
35
+ raise ValueError(
36
+ f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
37
+ f" and `num_heads`: {self.num_heads})."
38
+ )
39
+ self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
40
+ self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
41
+ self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
42
+ self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
43
+
44
+ self.naive_attention_prefill = config.naive_attention_prefill
45
+ self.naive_attention_decode_batched = config.naive_attention_decode_batched
46
+ self.naive_attention_decode_single = config.naive_attention_decode_single
47
+ self._init_rope()
48
+
49
+ def forward(
50
+ self,
51
+ hidden_states: torch.Tensor,
52
+ attention_mask: Optional[torch.Tensor] = None,
53
+ position_ids: Optional[torch.LongTensor] = None,
54
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
55
+ output_attentions: bool = False,
56
+ use_cache: bool = False,
57
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
58
+ bsz, q_len, _ = hidden_states.size()
59
+ if past_key_value is None:
60
+ is_decode = False
61
+ else:
62
+ is_decode = True
63
+ if self.pretraining_tp > 1:
64
+ key_value_slicing = (self.num_key_value_heads * self.head_dim) // self.pretraining_tp
65
+ query_slices = self.q_proj.weight.split((self.num_heads * self.head_dim) // self.pretraining_tp, dim=0)
66
+ key_slices = self.k_proj.weight.split(key_value_slicing, dim=0)
67
+ value_slices = self.v_proj.weight.split(key_value_slicing, dim=0)
68
+
69
+ query_states = [F.linear(hidden_states, query_slices[i]) for i in range(self.pretraining_tp)]
70
+ query_states = torch.cat(query_states, dim=-1)
71
+
72
+ key_states = [F.linear(hidden_states, key_slices[i]) for i in range(self.pretraining_tp)]
73
+ key_states = torch.cat(key_states, dim=-1)
74
+
75
+ value_states = [F.linear(hidden_states, value_slices[i]) for i in range(self.pretraining_tp)]
76
+ value_states = torch.cat(value_states, dim=-1)
77
+
78
+ else:
79
+ query_states = self.q_proj(hidden_states)
80
+ key_states = self.k_proj(hidden_states)
81
+ value_states = self.v_proj(hidden_states)
82
+
83
+ query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
84
+ key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
85
+ value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
86
+
87
+ kv_seq_len = key_states.shape[-2]
88
+ if past_key_value is not None:
89
+ kv_seq_len += past_key_value[0].shape[-2]
90
+ cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
91
+
92
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
93
+
94
+ if past_key_value is not None:
95
+ # reuse k, v, self_attention
96
+ key_states = torch.cat([past_key_value[0], key_states], dim=2)
97
+ value_states = torch.cat([past_key_value[1], value_states], dim=2)
98
+
99
+ past_key_value = (key_states, value_states) if use_cache else None
100
+
101
+ # repeat k/v heads if n_kv_heads < n_heads
102
+ if is_decode:
103
+ query_states = query_states.view(bsz, self.num_key_value_heads, self.num_key_value_groups, self.head_dim)
104
+ if self.naive_attention_decode_batched and bsz > 1 or self.naive_attention_decode_single and bsz == 1:
105
+ attn_weights = (query_states @ key_states.transpose(-2, -1)) / math.sqrt(key_states.size(-1))
106
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
107
+ if attention_mask is not None:
108
+ if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
109
+ raise ValueError(
110
+ f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
111
+ )
112
+ attn_weights = attn_weights + attention_mask
113
+
114
+ attn_output = torch.matmul(attn_weights, value_states)
115
+ else:
116
+ attn_output = F.scaled_dot_product_attention(query_states, key_states, value_states, is_causal=False,
117
+ dropout_p=0.0)
118
+ attn_output = attn_output.contiguous().view(bsz, q_len, self.hidden_size)
119
+
120
+ else:
121
+ key_states = repeat_kv(key_states, self.num_key_value_groups)
122
+ value_states = repeat_kv(value_states, self.num_key_value_groups)
123
+
124
+ if not self.naive_attention_prefill:
125
+ attn_output = F.scaled_dot_product_attention(query_states, key_states, value_states, is_causal=True,
126
+ dropout_p=0.0)
127
+ else:
128
+ attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
129
+ # attn_weights = (query_states @ key_states.transpose(-2, -1)) / math.sqrt(key_states.size(-1))
130
+ if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
131
+ raise ValueError(
132
+ f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
133
+ f" {attn_weights.size()}"
134
+ )
135
+
136
+ if attention_mask is not None:
137
+ if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
138
+ raise ValueError(
139
+ f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
140
+ )
141
+ attn_weights = attn_weights + attention_mask
142
+
143
+ # upcast attention to fp32
144
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
145
+ attn_output = torch.matmul(attn_weights, value_states)
146
+
147
+ if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
148
+ raise ValueError(
149
+ f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
150
+ f" {attn_output.size()}"
151
+ )
152
+
153
+ attn_output = attn_output.transpose(1, 2).contiguous().view(bsz, q_len, self.hidden_size)
154
+ # attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
155
+
156
+ if self.pretraining_tp > 1:
157
+ attn_output = attn_output.split(self.hidden_size // self.pretraining_tp, dim=2)
158
+ o_proj_slices = self.o_proj.weight.split(self.hidden_size // self.pretraining_tp, dim=1)
159
+ attn_output = sum([F.linear(attn_output[i], o_proj_slices[i]) for i in range(self.pretraining_tp)])
160
+ else:
161
+ attn_output = self.o_proj(attn_output)
162
+
163
+ if not output_attentions:
164
+ attn_weights = None
165
+
166
+ return attn_output, attn_weights, past_key_value
167
+
168
+
169
+ class DeciCoderDecoderLayer(LlamaDecoderLayer):
170
+ def __init__(self, config: DeciCoderConfig):
171
+ nn.Module.__init__(self)
172
+ self.hidden_size = config.hidden_size
173
+ self.self_attn = DeciCoderAttention(config=config)
174
+ self.mlp = LlamaMLP(config)
175
+ self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
176
+ self.post_attention_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
177
+
178
+
179
+ @add_start_docstrings(
180
+ "The bare DeciCoder Model outputting raw hidden-states without any specific head on top.",
181
+ LLAMA_START_DOCSTRING,
182
+ )
183
+ class DeciCoderPreTrainedModel(LlamaPreTrainedModel):
184
+ config_class = DeciCoderConfig
185
+ _no_split_modules = ["DeciCoderDecoderLayer"]
186
+ _keys_to_ignore_on_load_missing = ["self_attn.rotary_emb.inv_freq"]
187
+
188
+
189
+ @add_start_docstrings(
190
+ "The bare DeciCoder Model outputting raw hidden-states without any specific head on top.",
191
+ LLAMA_START_DOCSTRING,
192
+ )
193
+ class DeciCoderModel(LlamaModel, DeciCoderPreTrainedModel):
194
+ """
195
+ Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`DeciCoderDecoderLayer`]
196
+
197
+ Args:
198
+ config: DeciCoderConfig
199
+ """
200
+
201
+ def __init__(self, config: DeciCoderConfig):
202
+ DeciCoderPreTrainedModel.__init__(self, config)
203
+ self.padding_idx = config.pad_token_id
204
+ self.vocab_size = config.vocab_size
205
+
206
+ self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
207
+ self.layers = nn.ModuleList([DeciCoderDecoderLayer(config) for _ in range(config.num_hidden_layers)])
208
+ self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
209
+
210
+ self.gradient_checkpointing = False
211
+ # Initialize weights and apply final processing
212
+ self.post_init()
213
+
214
+ def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_embeds, past_key_values_length):
215
+ self._validate_config_supports_attention_mask(attention_mask, input_shape, past_key_values_length)
216
+ return LlamaModel._prepare_decoder_attention_mask(
217
+ self, attention_mask, input_shape, inputs_embeds, past_key_values_length)
218
+
219
+ def _validate_config_supports_attention_mask(self, attention_mask, input_shape, past_key_values_length):
220
+ is_decode = past_key_values_length > 0
221
+ if not torch.all(torch.eq(attention_mask, 1)).item():
222
+ if is_decode:
223
+ if input_shape[0] == 1 and not self.config.naive_attention_decode_single:
224
+ raise ValueError(
225
+ "For support of custom attention masks please set naive_attention_decode_single to True in the "
226
+ "config")
227
+ elif input_shape[0] > 1 and not self.config.naive_attention_decode_batched:
228
+ raise ValueError(
229
+ "For support of custom attention masks please set naive_attention_decode_batched to True in the"
230
+ "config")
231
+ else:
232
+ if not self.config.naive_attention_prefill:
233
+ raise ValueError("For support of custom attention masks please set naive_attention_prefill to "
234
+ "True in the config")
235
+
236
+
237
+ class DeciCoderForCausalLM(LlamaForCausalLM, DeciCoderPreTrainedModel):
238
+ def __init__(self, config):
239
+ DeciCoderPreTrainedModel.__init__(self, config)
240
+ self.model = DeciCoderModel(config)
241
+ self.pretraining_tp = config.pretraining_tp
242
+ self.vocab_size = config.vocab_size
243
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
244
+
245
+ # Initialize weights and apply final processing
246
+ self.post_init()