Matt committed on
Commit
8342e98
1 Parent(s): eb30971

Preparations for transition to library

Browse files
Files changed (4) hide show
  1. README.md +3 -3
  2. config.json +9 -1
  3. configuration_RW.py +90 -22
  4. modelling_RW.py → modeling_RW.py +387 -225
README.md CHANGED
@@ -22,6 +22,8 @@ license: apache-2.0
22
  * **It features an architecture optimized for inference**, with FlashAttention ([Dao et al., 2022](https://arxiv.org/abs/2205.14135)) and multiquery ([Shazeer et al., 2019](https://arxiv.org/abs/1911.02150)).
23
  * **It is made available under a permissive Apache 2.0 license allowing for commercial use**, without any royalties or restrictions.
24
 
 
 
25
  ⚠️ **This is a raw, pretrained model, which should be further finetuned for most usecases.** If you are looking for a version better suited to taking generic instructions in a chat format, we recommend taking a look at [Falcon-7B-Instruct](https://huggingface.co/tiiuae/falcon-7b-instruct).
26
 
27
  🔥 **Looking for an even more powerful model?** [Falcon-40B](https://huggingface.co/tiiuae/falcon-40b) is Falcon-7B's big brother!
@@ -39,7 +41,6 @@ pipeline = transformers.pipeline(
39
  model=model,
40
  tokenizer=tokenizer,
41
  torch_dtype=torch.bfloat16,
42
- trust_remote_code=True,
43
  device_map="auto",
44
  )
45
  sequences = pipeline(
@@ -110,7 +111,6 @@ pipeline = transformers.pipeline(
110
  model=model,
111
  tokenizer=tokenizer,
112
  torch_dtype=torch.bfloat16,
113
- trust_remote_code=True,
114
  device_map="auto",
115
  )
116
  sequences = pipeline(
@@ -233,4 +233,4 @@ To learn more about the pretraining dataset, see the 📓 [RefinedWeb paper](htt
233
  Falcon-7B is made available under the Apache 2.0 license.
234
 
235
  ## Contact
236
 
22
  * **It features an architecture optimized for inference**, with FlashAttention ([Dao et al., 2022](https://arxiv.org/abs/2205.14135)) and multiquery ([Shazeer et al., 2019](https://arxiv.org/abs/1911.02150)).
23
  * **It is made available under a permissive Apache 2.0 license allowing for commercial use**, without any royalties or restrictions.
24
 
25
+ ⚠️ Falcon is now available as a core model in the `transformers` library! To use the in-library version, please install the latest version of `transformers` with `pip install git+https://github.com/huggingface/transformers.git`, then simply remove the `trust_remote_code=True` argument from `from_pretrained()`.
26
+
27
  ⚠️ **This is a raw, pretrained model, which should be further finetuned for most usecases.** If you are looking for a version better suited to taking generic instructions in a chat format, we recommend taking a look at [Falcon-7B-Instruct](https://huggingface.co/tiiuae/falcon-7b-instruct).
28
 
29
  🔥 **Looking for an even more powerful model?** [Falcon-40B](https://huggingface.co/tiiuae/falcon-40b) is Falcon-7B's big brother!
 
41
  model=model,
42
  tokenizer=tokenizer,
43
  torch_dtype=torch.bfloat16,
 
44
  device_map="auto",
45
  )
46
  sequences = pipeline(
 
111
  model=model,
112
  tokenizer=tokenizer,
113
  torch_dtype=torch.bfloat16,
 
114
  device_map="auto",
115
  )
116
  sequences = pipeline(
 
233
  Falcon-7B is made available under the Apache 2.0 license.
234
 
235
  ## Contact
236
config.json CHANGED
@@ -5,6 +5,14 @@
5
  "FalconForCausalLM"
6
  ],
7
  "attention_dropout": 0.0,
 
 
 
 
 
 
 
 
8
  "bias": false,
9
  "bos_token_id": 11,
10
  "eos_token_id": 11,
@@ -22,4 +30,4 @@
22
  "transformers_version": "4.27.4",
23
  "use_cache": true,
24
  "vocab_size": 65024
25
- }
 
5
  "FalconForCausalLM"
6
  ],
7
  "attention_dropout": 0.0,
8
+ "auto_map": {
9
+ "AutoConfig": "configuration_RW.RWConfig",
10
+ "AutoModel": "modeling_RW.RWModel",
11
+ "AutoModelForSequenceClassification": "modeling_RW.RWForSequenceClassification",
12
+ "AutoModelForTokenClassification": "modeling_RW.RWForTokenClassification",
13
+ "AutoModelForQuestionAnswering": "modeling_RW.RWForQuestionAnswering",
14
+ "AutoModelForCausalLM": "modeling_RW.RWForCausalLM"
15
+ },
16
  "bias": false,
17
  "bos_token_id": 11,
18
  "eos_token_id": 11,
 
30
  "transformers_version": "4.27.4",
31
  "use_cache": true,
32
  "vocab_size": 65024
33
+ }
configuration_RW.py CHANGED
@@ -1,5 +1,5 @@
1
  # coding=utf-8
2
- # Copyright 2022 the Big Science Workshop and HuggingFace Inc. team. All rights reserved.
3
  #
4
  # Licensed under the Apache License, Version 2.0 (the "License");
5
  # you may not use this file except in compliance with the License.
@@ -12,67 +12,135 @@
12
  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
  # See the License for the specific language governing permissions and
14
  # limitations under the License.
15
- """ Bloom configuration"""
16
  from transformers.configuration_utils import PretrainedConfig
17
  from transformers.utils import logging
18
 
19
 
20
  logger = logging.get_logger(__name__)
21
 
 
 
 
 
 
22
 
23
  class RWConfig(PretrainedConfig):
24
- model_type = "RefinedWebModel"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
25
  keys_to_ignore_at_inference = ["past_key_values"]
26
- attribute_map = {
27
- "num_hidden_layers": "n_layer",
28
- "num_attention_heads": "n_head",
29
- }
30
 
31
  def __init__(
32
  self,
33
- vocab_size=250880,
34
- hidden_size=64,
35
- n_layer=2,
36
- n_head=8,
37
  layer_norm_epsilon=1e-5,
38
  initializer_range=0.02,
39
  use_cache=True,
40
- bos_token_id=1,
41
- eos_token_id=2,
42
- apply_residual_connection_post_layernorm=False,
43
  hidden_dropout=0.0,
44
  attention_dropout=0.0,
45
- multi_query=False,
46
  alibi=False,
 
 
 
47
  bias=False,
48
- parallel_attn=False,
 
49
  **kwargs,
50
  ):
51
  self.vocab_size = vocab_size
52
  # Backward compatibility with n_embed kwarg
53
  n_embed = kwargs.pop("n_embed", None)
54
  self.hidden_size = hidden_size if n_embed is None else n_embed
55
- self.n_layer = n_layer
56
- self.n_head = n_head
57
  self.layer_norm_epsilon = layer_norm_epsilon
58
  self.initializer_range = initializer_range
59
  self.use_cache = use_cache
60
- self.apply_residual_connection_post_layernorm = apply_residual_connection_post_layernorm
61
  self.hidden_dropout = hidden_dropout
62
  self.attention_dropout = attention_dropout
63
 
64
  self.bos_token_id = bos_token_id
65
  self.eos_token_id = eos_token_id
66
- self.multi_query = multi_query
67
  self.alibi = alibi
68
- self.bias = bias
 
69
  self.parallel_attn = parallel_attn
 
70
 
71
  super().__init__(bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs)
72
 
73
  @property
74
  def head_dim(self):
75
- return self.hidden_size // self.n_head
76
 
77
  @property
78
  def rotary(self):
 
1
  # coding=utf-8
2
+ # Copyright 2023 the Falcon authors and HuggingFace Inc. team. All rights reserved.
3
  #
4
  # Licensed under the Apache License, Version 2.0 (the "License");
5
  # you may not use this file except in compliance with the License.
 
12
  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
  # See the License for the specific language governing permissions and
14
  # limitations under the License.
15
+ """ Falcon configuration"""
16
  from transformers.configuration_utils import PretrainedConfig
17
  from transformers.utils import logging
18
 
19
 
20
  logger = logging.get_logger(__name__)
21
 
22
+ FALCON_PRETRAINED_CONFIG_ARCHIVE_MAP = {
23
+ "tiiuae/falcon-40b": "https://huggingface.co/tiiuae/falcon-40b/resolve/main/config.json",
24
+ "tiiuae/falcon-7b": "https://huggingface.co/tiiuae/falcon-7b/resolve/main/config.json",
25
+ }
26
+
27
 
28
  class RWConfig(PretrainedConfig):
29
+ r"""
30
+ This is the configuration class to store the configuration of a [`FalconModel`]. It is used to instantiate a Falcon
31
+ model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
32
+ defaults will yield a similar configuration to that of the
33
+ [tiiuae/falcon-7b](https://huggingface.co/tiiuae/falcon-7b) architecture.
34
+
35
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
36
+ documentation from [`PretrainedConfig`] for more information.
37
+
38
+
39
+ Args:
40
+ vocab_size (`int`, *optional*, defaults to 65024):
41
+ Vocabulary size of the Falcon model. Defines the number of different tokens that can be represented by the
42
+ `input_ids` passed when calling [`FalconModel`].
43
+ hidden_size (`int`, *optional*, defaults to 4544):
44
+ Dimension of the hidden representations.
45
+ num_hidden_layers (`int`, *optional*, defaults to 32):
46
+ Number of hidden layers in the Transformer decoder.
47
+ num_attention_heads (`int`, *optional*, defaults to 71):
48
+ Number of attention heads for each attention layer in the Transformer decoder.
49
+ initializer_range (`float`, *optional*, defaults to 0.02):
50
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
51
+ use_cache (`bool`, *optional*, defaults to `True`):
52
+ Whether the model should return the last key/values attentions (not used by all models). Only relevant if
53
+ `config.is_decoder=True`.
54
+ layer_norm_epsilon (`float`, *optional*, defaults to 1e-5):
55
+ The epsilon used by the layer normalization layers.
56
+ hidden_dropout (`float`, *optional*, defaults to 0.0):
57
+ The dropout probability for MLP layers.
58
+ attention_dropout (`float`, *optional*, defaults to 0.0):
59
+ The dropout probability for attention layers.
60
+ num_kv_heads (`int`, *optional*):
61
+ Number of key-value heads to use per attention layer. If unset, defaults to the same value as
62
+ `num_attention_heads`.
63
+ alibi (`bool`, *optional*, defaults to `False`):
64
+ Whether to use ALiBi positional biases during self-attention.
65
+ new_decoder_architecture (`bool`, *optional*, defaults to `False`):
66
+ Whether to use the new (Falcon-40B) decoder architecture. If `True`, the `multi_query` and `parallel_attn`
67
+ arguments are ignored, as the new decoder always uses parallel attention.
68
+ multi_query (`bool`, *optional*, defaults to `True`):
69
+ Whether to use multi-query attention in the decoder. Ignored when `new_decoder_architecture` is `True`.
70
+ parallel_attn (`bool`, *optional*, defaults to `True`):
71
+ Whether to compute attention in parallel with the feedforward layer. If False, they are consecutive
72
+ instead, as in the original Transformer architecture. Ignored when `new_decoder_architecture` is `True`.
73
+ bias (`bool`, *optional*, defaults to `False`):
74
+ Whether to use bias on Linear layers.
75
+ bos_token_id (`int`, *optional*, defaults to 11):
76
+ The id of the "beginning-of-sequence" token.
77
+ eos_token_id (`int`, *optional*, defaults to 11):
78
+ The id of the "end-of-sequence" token.
79
+
80
+ Example:
81
+
82
+ ```python
83
+ >>> from transformers import FalconModel, RWConfig
84
+
85
+ >>> # Initializing a small (2-layer) Falcon configuration
86
+ >>> configuration = RWConfig(num_hidden_layers=2)
87
+
88
+ >>> # Initializing a model from the small configuration
89
+ >>> model = FalconModel(configuration)
90
+
91
+ >>> # Accessing the model configuration
92
+ >>> configuration = model.config
93
+ ```"""
94
+ model_type = "falcon"
95
  keys_to_ignore_at_inference = ["past_key_values"]
 
 
 
 
96
 
97
  def __init__(
98
  self,
99
+ vocab_size=65024,
100
+ hidden_size=4544,
101
+ num_hidden_layers=32,
102
+ num_attention_heads=71,
103
  layer_norm_epsilon=1e-5,
104
  initializer_range=0.02,
105
  use_cache=True,
 
 
 
106
  hidden_dropout=0.0,
107
  attention_dropout=0.0,
108
+ num_kv_heads=None,
109
  alibi=False,
110
+ new_decoder_architecture=False,
111
+ multi_query=True,
112
+ parallel_attn=True,
113
  bias=False,
114
+ bos_token_id=11,
115
+ eos_token_id=11,
116
  **kwargs,
117
  ):
118
  self.vocab_size = vocab_size
119
  # Backward compatibility with n_embed kwarg
120
  n_embed = kwargs.pop("n_embed", None)
121
  self.hidden_size = hidden_size if n_embed is None else n_embed
122
+ self.num_hidden_layers = num_hidden_layers
123
+ self.num_attention_heads = num_attention_heads
124
  self.layer_norm_epsilon = layer_norm_epsilon
125
  self.initializer_range = initializer_range
126
  self.use_cache = use_cache
 
127
  self.hidden_dropout = hidden_dropout
128
  self.attention_dropout = attention_dropout
129
 
130
  self.bos_token_id = bos_token_id
131
  self.eos_token_id = eos_token_id
132
+ self.num_kv_heads = num_attention_heads if num_kv_heads is None else num_kv_heads
133
  self.alibi = alibi
134
+ self.new_decoder_architecture = new_decoder_architecture
135
+ self.multi_query = multi_query # Ignored when new_decoder_architecture is True
136
  self.parallel_attn = parallel_attn
137
+ self.bias = bias
138
 
139
  super().__init__(bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs)
140
 
141
  @property
142
  def head_dim(self):
143
+ return self.hidden_size // self.num_attention_heads
144
 
145
  @property
146
  def rotary(self):
modelling_RW.py → modeling_RW.py RENAMED
@@ -1,9 +1,20 @@
1
- # port of models described in RW
2
- # We use the bloom model as a starting point for these model.
3
- # Please refer to the bloom models for usage instructions.
 
 
 
 
 
 
 
 
 
 
 
 
4
 
5
  import math
6
- import warnings
7
  from typing import Optional, Tuple, Union
8
 
9
  import torch
@@ -20,59 +31,60 @@ from transformers.modeling_outputs import (
20
  TokenClassifierOutput,
21
  )
22
  from transformers.modeling_utils import PreTrainedModel
23
- from transformers.utils import logging
24
  from .configuration_RW import RWConfig
25
 
 
26
  logger = logging.get_logger(__name__)
27
 
 
 
 
 
 
 
 
 
 
 
 
 
28
  # NOTE(Hesslow): Unfortunately we did not fuse matmul and bias during training, this means that there's one additional quantization to bfloat16 between the operations.
29
  # In order not to degrade the quality of our HF-port, we keep these characteristics in the final model.
30
- class Linear(nn.Linear):
31
  def forward(self, input: torch.Tensor) -> torch.Tensor:
32
- ret = input @ self.weight.T
33
  if self.bias is None:
34
- return ret
35
- else:
36
- return ret + self.bias
37
-
38
 
39
- from einops import rearrange
40
 
41
  # rotary pos emb helpers (torch.jit.script does not seem to support staticmethod...)
42
  def rotate_half(x):
43
  x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2 :]
44
- return torch.cat((-x2, x1), dim=x1.ndim - 1) # dim=-1 triggers a bug in torch < 1.8.0
45
 
46
 
47
- class RotaryEmbedding(torch.nn.Module):
48
  """Implementation of RotaryEmbedding from GPT-NeoX.
49
- This implementation is design to operate on queries and keys that are compatible with
50
- [batch_size, n_heads_per_partition, seq_len, head_dim] (e.g. MinGPTAttention format).
51
  """
52
 
53
- def __init__(
54
- self,
55
- head_dim: int,
56
- base=10000,
57
- ):
58
  super().__init__()
59
  inv_freq = 1.0 / (base ** (torch.arange(0, head_dim, 2).float() / head_dim))
60
  self.register_buffer("inv_freq", inv_freq, persistent=False)
61
  self.head_dim = head_dim
62
- self.seq_len_cached = None
63
- self.batch_size_cached = None
64
  self.cos_cached: torch.Tensor | None = None
65
  self.sin_cached: torch.Tensor | None = None
66
 
67
- def cos_sin(
68
- self,
69
- seq_len: int,
70
- device="cuda",
71
- dtype=torch.bfloat16,
72
- ) -> torch.Tensor:
73
- if seq_len != self.seq_len_cached:
74
- self.seq_len_cached = seq_len
75
- t = torch.arange(seq_len, device=device).type_as(self.inv_freq)
76
  freqs = torch.einsum("i,j->ij", t, self.inv_freq)
77
  emb = torch.cat((freqs, freqs), dim=-1).to(device)
78
 
@@ -85,36 +97,46 @@ class RotaryEmbedding(torch.nn.Module):
85
  self.cos_cached = self.cos_cached.type(dtype)
86
  self.sin_cached = self.sin_cached.type(dtype)
87
 
88
- return self.cos_cached, self.sin_cached
 
 
 
89
 
90
- def forward(self, q, k):
91
- batch, seq_len, head_dim = q.shape
92
- cos, sin = self.cos_sin(seq_len, q.device, q.dtype)
93
- return (q * cos) + (rotate_half(q) * sin), (k * cos) + (rotate_half(k) * sin)
94
 
95
 
96
  def _make_causal_mask(
97
  input_ids_shape: torch.Size, device: torch.device, past_key_values_length: int
98
  ) -> torch.BoolTensor:
 
 
 
 
 
99
  batch_size, target_length = input_ids_shape
100
- mask = torch.empty((target_length, target_length + past_key_values_length), dtype=torch.bool, device=device)
101
- # ONNX doesn't support `torch.Tensor.triu` properly, thus we use this workaround
102
- seq_ids = torch.arange(target_length, device=device)
103
- mask[:, past_key_values_length:] = seq_ids[:, None] < seq_ids[None, :]
104
-
105
- if past_key_values_length > 0:
106
- mask[:, :past_key_values_length] = False
107
 
 
 
 
 
 
 
108
  expanded_mask = mask[None, None, :, :].expand(batch_size, 1, target_length, target_length + past_key_values_length)
109
  return expanded_mask
110
 
111
 
112
- def _expand_mask(mask: torch.Tensor, tgt_length: int) -> torch.BoolTensor:
113
- batch_size, src_length = mask.shape
114
- tgt_length = tgt_length if tgt_length is not None else src_length
 
 
 
115
 
116
  expanded_mask = ~(mask[:, None, None, :].to(torch.bool))
117
- return expanded_mask.expand(batch_size, 1, tgt_length, src_length)
118
 
119
 
120
  def build_alibi_tensor(attention_mask: torch.Tensor, num_heads: int, dtype: torch.dtype) -> torch.Tensor:
@@ -145,18 +167,32 @@ def build_alibi_tensor(attention_mask: torch.Tensor, num_heads: int, dtype: torc
145
  return alibi.reshape(batch_size * num_heads, 1, seq_length).to(dtype)
146
 
147
 
 
148
  def dropout_add(x: torch.Tensor, residual: torch.Tensor, prob: float, training: bool) -> torch.Tensor:
 
 
 
 
 
 
 
 
 
 
 
 
 
149
  out = F.dropout(x, p=prob, training=training)
150
  out = residual + out
151
  return out
152
 
153
 
154
- class Attention(nn.Module):
155
  def __init__(self, config: RWConfig):
156
  super().__init__()
157
 
158
  self.hidden_size = config.hidden_size
159
- self.num_heads = config.n_head
160
  self.head_dim = self.hidden_size // self.num_heads
161
  self.split_size = self.hidden_size
162
  self.hidden_dropout = config.hidden_dropout
@@ -167,26 +203,27 @@ class Attention(nn.Module):
167
  f" {self.num_heads})."
168
  )
169
 
170
- self.maybe_rotary = RotaryEmbedding(config.head_dim) if config.rotary else lambda q, k: (q, k)
171
 
172
  # Layer-wise attention scaling
173
  self.inv_norm_factor = 1.0 / math.sqrt(self.head_dim)
174
  self.beta = self.inv_norm_factor
175
-
176
- self.query_key_value = Linear(
177
- self.hidden_size,
178
- 3 * self.hidden_size if not config.multi_query else (self.hidden_size + 2 * self.head_dim),
179
- bias=config.bias,
180
- )
 
 
181
  self.multi_query = config.multi_query
182
- self.dense = Linear(self.hidden_size, self.hidden_size, bias=config.bias)
183
  self.attention_dropout = nn.Dropout(config.attention_dropout)
184
- self.num_kv = config.n_head if not self.multi_query else 1
185
 
186
  def _split_heads(self, fused_qkv: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
187
  """
188
- Split the last dimension into (num_heads, head_dim) without making any copies, results share same memory
189
- storage as `fused_qkv`
190
 
191
  Args:
192
  fused_qkv (`torch.tensor`, *required*): [batch_size, seq_length, num_heads * 3 * head_dim]
@@ -195,7 +232,18 @@ class Attention(nn.Module):
195
  query: [batch_size, seq_length, num_heads, head_dim] key: [batch_size, seq_length, num_heads, head_dim]
196
  value: [batch_size, seq_length, num_heads, head_dim]
197
  """
198
- if not self.multi_query:
 
 
 
 
 
 
 
 
 
 
 
199
  batch_size, seq_length, three_times_hidden_size = fused_qkv.shape
200
  fused_qkv = fused_qkv.view(batch_size, seq_length, self.num_heads, 3, self.head_dim)
201
  return fused_qkv[..., 0, :], fused_qkv[..., 1, :], fused_qkv[..., 2, :]
@@ -204,12 +252,13 @@ class Attention(nn.Module):
204
  fused_qkv = fused_qkv.view(batch_size, seq_length, self.num_heads + 2, self.head_dim)
205
  return fused_qkv[..., :-2, :], fused_qkv[..., [-2], :], fused_qkv[..., [-1], :]
206
 
 
207
  def _merge_heads(self, x: torch.Tensor) -> torch.Tensor:
208
  """
209
  Merge heads together over the last dimension
210
 
211
  Args:
212
- x: (`torch.tensor`, *required*): [batch_size * num_heads, seq_length, head_dim]
213
 
214
  Returns:
215
  torch.tensor: [batch_size, seq_length, num_heads * head_dim]
@@ -232,7 +281,7 @@ class Attention(nn.Module):
232
  def forward(
233
  self,
234
  hidden_states: torch.Tensor,
235
- alibi: torch.Tensor,
236
  attention_mask: torch.Tensor,
237
  layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
238
  head_mask: Optional[torch.Tensor] = None,
@@ -240,105 +289,120 @@ class Attention(nn.Module):
240
  output_attentions: bool = False,
241
  ):
242
  fused_qkv = self.query_key_value(hidden_states) # [batch_size, seq_length, 3 x hidden_size]
243
-
244
  # 3 x [batch_size, seq_length, num_heads, head_dim]
245
  (query_layer, key_layer, value_layer) = self._split_heads(fused_qkv)
246
 
247
- batch_size, q_length, _, _ = query_layer.shape
248
 
249
- query_layer = query_layer.transpose(1, 2).reshape(batch_size * self.num_heads, q_length, self.head_dim)
250
  key_layer = key_layer.transpose(1, 2).reshape(
251
- batch_size * self.num_kv,
252
- q_length,
253
  self.head_dim,
254
  )
255
- value_layer = value_layer.transpose(1, 2).reshape(batch_size * self.num_kv, q_length, self.head_dim)
256
 
257
- query_layer, key_layer = self.maybe_rotary(query_layer, key_layer)
 
258
 
259
  if layer_past is not None:
260
  past_key, past_value = layer_past
261
  # concatenate along seq_length dimension:
262
- # - key: [batch_size * self.num_heads, head_dim, kv_length]
263
  # - value: [batch_size * self.num_heads, kv_length, head_dim]
264
  key_layer = torch.cat((past_key, key_layer), dim=1)
265
  value_layer = torch.cat((past_value, value_layer), dim=1)
266
 
267
  _, kv_length, _ = key_layer.shape
268
-
269
- if use_cache is True:
270
  present = (key_layer, value_layer)
271
  else:
272
  present = None
273
 
 
 
 
 
 
 
274
  if alibi is None:
275
- query_layer_ = query_layer.reshape(batch_size, self.num_heads, -1, self.head_dim)
276
- key_layer_ = key_layer.reshape(batch_size, self.num_kv, -1, self.head_dim)
277
- value_layer_ = value_layer.reshape(batch_size, self.num_kv, -1, self.head_dim)
 
 
278
 
279
- attn_output = F.scaled_dot_product_attention(
280
- query_layer_, key_layer_, value_layer_, None, 0.0, is_causal=True
281
- )
 
 
 
 
 
 
282
 
283
- x = attn_output.view(batch_size, self.num_heads, q_length, self.head_dim)
284
- x = x.permute(0, 2, 1, 3)
285
- attn_output = x.reshape(batch_size, q_length, self.num_heads * self.head_dim)
286
 
287
  output_tensor = self.dense(attn_output)
288
 
289
- outputs = (output_tensor, present)
290
- assert not output_attentions # not supported.
291
- return outputs
 
 
292
  else:
293
- attention_mask_float = (attention_mask * 1.0).masked_fill(attention_mask, -1e9).to(torch.bfloat16)
294
- matmul_result = query_layer @ key_layer.transpose(-1, -2)
295
 
296
  # change view to [batch_size, num_heads, q_length, kv_length]
297
- attention_scores = matmul_result.view(batch_size, self.num_heads, q_length, kv_length)
298
 
299
  # cast attention scores to fp32, compute scaled softmax and cast back to initial dtype - [batch_size, num_heads, q_length, kv_length]
300
  input_dtype = attention_scores.dtype
301
  # `float16` has a minimum value of -65504.0, whereas `bfloat16` and `float32` have a minimum value of `-3.4e+38`
302
  if input_dtype == torch.float16 or input_dtype == torch.bfloat16:
303
  attention_scores = attention_scores.to(torch.float32)
304
- # attn_weights = torch.masked_fill(attention_scores, attention_mask, torch.finfo(attention_scores.dtype).min)
305
- attention_probs = F.softmax(
306
- (attention_scores + alibi.view(batch_size, self.num_heads, 1, -1)) * self.inv_norm_factor + attention_mask_float,
307
- dim=-1,
308
- dtype=hidden_states.dtype,
309
- )
 
310
  # [batch_size, num_heads, q_length, kv_length]
311
  attention_probs = self.attention_dropout(attention_probs)
312
 
313
  if head_mask is not None:
314
  attention_probs = attention_probs * head_mask
315
 
316
- # change view [batch_size x num_heads, q_length, kv_length]
317
- attention_probs_reshaped = attention_probs.view(batch_size * self.num_heads, q_length, kv_length)
318
 
319
  # matmul: [batch_size * num_heads, q_length, head_dim]
320
- context_layer = attention_probs_reshaped @ value_layer
321
 
322
  # change view [batch_size, num_heads, q_length, head_dim]
323
  context_layer = self._merge_heads(context_layer)
324
 
325
  output_tensor = self.dense(context_layer)
326
 
327
- outputs = (output_tensor, present)
328
  if output_attentions:
329
- outputs += (attention_probs,)
330
-
331
- return outputs
332
 
333
 
334
- class MLP(nn.Module):
335
  def __init__(self, config: RWConfig):
336
  super().__init__()
337
  hidden_size = config.hidden_size
338
 
339
- self.dense_h_to_4h = Linear(hidden_size, 4 * hidden_size, bias=config.bias)
340
  self.act = nn.GELU()
341
- self.dense_4h_to_h = Linear(4 * hidden_size, hidden_size, bias=config.bias)
342
  self.hidden_dropout = config.hidden_dropout
343
 
344
  def forward(self, x: torch.Tensor) -> torch.Tensor:
@@ -347,43 +411,47 @@ class MLP(nn.Module):
347
  return x
348
 
349
 
350
- class DecoderLayer(nn.Module):
351
  def __init__(self, config: RWConfig):
352
  super().__init__()
353
  hidden_size = config.hidden_size
354
-
355
- self.input_layernorm = LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
356
- self.num_heads = config.n_head
357
- self.self_attention = Attention(config)
358
-
359
- if not config.parallel_attn:
360
- # unused if parallel attn
361
- self.post_attention_layernorm = LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
362
-
363
- self.mlp = MLP(config)
364
-
365
- self.apply_residual_connection_post_layernorm = config.apply_residual_connection_post_layernorm
366
  self.hidden_dropout = config.hidden_dropout
367
-
368
  self.config = config
369
 
 
 
 
 
 
 
 
 
 
 
370
  def forward(
371
  self,
372
  hidden_states: torch.Tensor,
373
- alibi: torch.Tensor,
374
  attention_mask: torch.Tensor,
375
  layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
376
  head_mask: Optional[torch.Tensor] = None,
377
  use_cache: bool = False,
378
  output_attentions: bool = False,
379
  ):
380
-
381
- layernorm_output = self.input_layernorm(hidden_states)
382
  residual = hidden_states
383
 
 
 
 
 
 
 
384
  # Self attention.
385
  attn_outputs = self.self_attention(
386
- layernorm_output,
387
  layer_past=layer_past,
388
  attention_mask=attention_mask,
389
  alibi=alibi,
@@ -394,16 +462,21 @@ class DecoderLayer(nn.Module):
394
 
395
  attention_output = attn_outputs[0]
396
 
397
- if not self.config.parallel_attn:
398
- residual = dropout_add(attention_output, residual, self.config.attention_dropout, training=self.training)
399
- layernorm_output = self.post_attention_layernorm(residual)
 
 
 
 
 
400
 
401
  outputs = attn_outputs[1:]
402
 
403
  # MLP.
404
- mlp_output = self.mlp(layernorm_output)
405
 
406
- if self.config.parallel_attn:
407
  mlp_output += attention_output
408
 
409
  output = dropout_add(mlp_output, residual, self.config.hidden_dropout, training=self.training)
@@ -416,8 +489,77 @@ class DecoderLayer(nn.Module):
416
  return outputs # hidden_states, present, attentions
417
 
418
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
419
  class RWPreTrainedModel(PreTrainedModel):
420
- _keys_to_ignore_on_load_missing = [r"h.*.self_attention.scale_mask_softmax.causal_mask", r"lm_head.weight"]
421
  """
422
  An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
423
  models.
@@ -426,14 +568,14 @@ class RWPreTrainedModel(PreTrainedModel):
426
  config_class = RWConfig
427
  base_model_prefix = "transformer"
428
  supports_gradient_checkpointing = True
429
- _no_split_modules = ["DecoderLayer"]
430
 
431
  def __init__(self, *inputs, **kwargs):
432
  super().__init__(*inputs, **kwargs)
433
 
434
  def _init_weights(self, module: nn.Module):
435
  """Initialize the weights."""
436
- if isinstance(module, nn.Linear) or isinstance(module, Linear):
437
  # Slightly different from the TF version which uses truncated_normal for initialization
438
  # cf https://github.com/pytorch/pytorch/pull/5617
439
  module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
@@ -447,26 +589,28 @@ class RWPreTrainedModel(PreTrainedModel):
447
  module.bias.data.zero_()
448
  module.weight.data.fill_(1.0)
449
 
 
450
  def _set_gradient_checkpointing(self, module: nn.Module, value: bool = False):
451
  if isinstance(module, RWModel):
452
  module.gradient_checkpointing = value
453
 
454
  @staticmethod
455
- def _convert_to_standard_cache(
456
  past_key_value: Tuple[Tuple[torch.Tensor, torch.Tensor]], batch_size: int
457
  ) -> Tuple[Tuple[torch.Tensor, torch.Tensor]]:
458
  """
459
  Standardizes the format of the cache so as to match most implementations, i.e. to tuple(tuple([batch_size,
460
  num_heads, ...]))
461
  """
462
- batch_size_times_num_heads, head_dim, seq_length = past_key_value[0][0].shape
 
 
 
463
  num_heads = batch_size_times_num_heads // batch_size
464
- # key: [batch_size * num_heads, head_dim, seq_length] -> [batch_size, num_heads, head_dim, seq_length]
465
- # value: [batch_size * num_heads, seq_length, head_dim] -> [batch_size, num_heads, seq_length, head_dim]
466
  return tuple(
467
  (
468
- layer_past[0].view(batch_size, num_heads, head_dim, seq_length),
469
- layer_past[1].view(batch_size, num_heads, seq_length, head_dim),
470
  )
471
  for layer_past in past_key_value
472
  )
@@ -475,32 +619,35 @@ class RWPreTrainedModel(PreTrainedModel):
475
  def _convert_to_rw_cache(
476
  past_key_value: Tuple[Tuple[torch.Tensor, torch.Tensor]]
477
  ) -> Tuple[Tuple[torch.Tensor, torch.Tensor]]:
478
- batch_size, num_heads, head_dim, seq_length = past_key_value[0][0].shape
479
  batch_size_times_num_heads = batch_size * num_heads
480
- # key: [batch_size, num_heads, head_dim, seq_length] -> [batch_size * num_heads, head_dim, seq_length]
481
- # value: [batch_size, num_heads, seq_length, head_dim] -> [batch_size * num_heads, seq_length, head_dim]
482
  return tuple(
483
  (
484
- layer_past[0].view(batch_size_times_num_heads, head_dim, seq_length),
485
- layer_past[1].view(batch_size_times_num_heads, seq_length, head_dim),
486
  )
487
  for layer_past in past_key_value
488
  )
489
 
490
 
 
 
 
 
491
  class RWModel(RWPreTrainedModel):
492
  def __init__(self, config: RWConfig):
493
  super().__init__(config)
494
 
495
  self.embed_dim = config.hidden_size
496
- self.num_heads = config.n_head
497
- self.alibi = config.alibi
498
 
499
  # Embedding + LN Embedding
500
  self.word_embeddings = nn.Embedding(config.vocab_size, self.embed_dim)
501
 
502
  # Transformer blocks
503
- self.h = nn.ModuleList([DecoderLayer(config) for _ in range(config.num_hidden_layers)])
504
 
505
  # Final Layer Norm
506
  self.ln_f = LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)
@@ -513,22 +660,31 @@ class RWModel(RWPreTrainedModel):
513
  def get_input_embeddings(self):
514
  return self.word_embeddings
515
 
 
516
  def _prepare_attn_mask(
517
- self, attention_mask: torch.Tensor, input_shape: Tuple[int, int], past_key_values_length: int
518
  ) -> torch.BoolTensor:
519
- # create causal mask
520
- # [batch_size, seq_length] -> [batch_size, 1, tgt_length, src_length]
 
 
 
 
 
 
 
 
521
  combined_attention_mask = None
522
  device = attention_mask.device
523
- _, src_length = input_shape
524
 
525
- if src_length > 1:
526
  combined_attention_mask = _make_causal_mask(
527
  input_shape, device=device, past_key_values_length=past_key_values_length
528
  )
529
 
530
- # [batch_size, seq_length] -> [batch_size, 1, tgt_length, src_length]
531
- expanded_attn_mask = _expand_mask(attention_mask, tgt_length=src_length)
532
  combined_attention_mask = (
533
  expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask | combined_attention_mask
534
  )
@@ -538,6 +694,12 @@ class RWModel(RWPreTrainedModel):
538
  def set_input_embeddings(self, new_embeddings: torch.Tensor):
539
  self.word_embeddings = new_embeddings
540
 
 
 
 
 
 
 
541
  def forward(
542
  self,
543
  input_ids: Optional[torch.LongTensor] = None,
@@ -549,18 +711,7 @@ class RWModel(RWPreTrainedModel):
549
  output_attentions: Optional[bool] = None,
550
  output_hidden_states: Optional[bool] = None,
551
  return_dict: Optional[bool] = None,
552
- **deprecated_arguments,
553
  ) -> Union[Tuple[torch.Tensor, ...], BaseModelOutputWithPastAndCrossAttentions]:
554
- if deprecated_arguments.pop("position_ids", False) is not False:
555
- # `position_ids` could have been `torch.Tensor` or `None` so defaulting pop to `False` allows to detect if users were passing explicitly `None`
556
- warnings.warn(
557
- "`position_ids` have no functionality in BLOOM and will be removed in v5.0.0. You can safely ignore"
558
- " passing `position_ids`.",
559
- FutureWarning,
560
- )
561
- if len(deprecated_arguments) > 0:
562
- raise ValueError(f"Got unexpected arguments: {deprecated_arguments}")
563
-
564
  output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
565
  output_hidden_states = (
566
  output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
@@ -579,12 +730,14 @@ class RWModel(RWPreTrainedModel):
579
 
580
  if past_key_values is None:
581
  past_key_values = tuple([None] * len(self.h))
 
 
582
 
583
  # Prepare head mask if needed
584
  # 1.0 in head_mask indicates we keep the head
585
  # attention_probs has shape batch_size x num_heads x N x N
586
  # head_mask has shape n_layer x batch x num_heads x N x N
587
- head_mask = self.get_head_mask(head_mask, self.config.n_layer)
588
 
589
  if inputs_embeds is None:
590
  inputs_embeds = self.word_embeddings(input_ids)
@@ -596,17 +749,15 @@ class RWModel(RWPreTrainedModel):
596
  all_hidden_states = () if output_hidden_states else None
597
 
598
  # Compute alibi tensor: check build_alibi_tensor documentation
599
- seq_length_with_past = seq_length
600
  past_key_values_length = 0
601
  if past_key_values[0] is not None:
602
- past_key_values_length = past_key_values[0][0].shape[2]
603
- seq_length_with_past = seq_length_with_past + past_key_values_length
604
  if attention_mask is None:
605
- attention_mask = torch.ones((batch_size, seq_length_with_past), device=hidden_states.device)
606
  else:
607
  attention_mask = attention_mask.to(hidden_states.device)
608
 
609
- if self.alibi:
610
  alibi = build_alibi_tensor(attention_mask, self.num_heads, dtype=hidden_states.dtype)
611
  else:
612
  alibi = None
@@ -618,12 +769,10 @@ class RWModel(RWPreTrainedModel):
618
  )
619
 
620
  for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)):
621
-
622
  if output_hidden_states:
623
  all_hidden_states = all_hidden_states + (hidden_states,)
624
 
625
  if self.gradient_checkpointing and self.training:
626
-
627
  if use_cache:
628
  logger.warning(
629
  "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
@@ -668,6 +817,9 @@ class RWModel(RWPreTrainedModel):
668
  if output_hidden_states:
669
  all_hidden_states = all_hidden_states + (hidden_states,)
670
 
 
 
 
671
  if not return_dict:
672
  return tuple(v for v in [hidden_states, presents, all_hidden_states, all_self_attentions] if v is not None)
673
 
@@ -679,8 +831,12 @@ class RWModel(RWPreTrainedModel):
679
  )
680
 
681
 
 
 
 
 
682
  class RWForCausalLM(RWPreTrainedModel):
683
- _keys_to_ignore_on_load_missing = [r"h.*.self_attention.scale_mask_softmax.causal_mask", r"lm_head.weight"]
684
 
685
  def __init__(self, config: RWConfig):
686
  super().__init__(config)
@@ -699,25 +855,26 @@ class RWForCausalLM(RWPreTrainedModel):
699
  def prepare_inputs_for_generation(
700
  self,
701
  input_ids: torch.LongTensor,
702
- past: Optional[torch.Tensor] = None,
703
  attention_mask: Optional[torch.Tensor] = None,
704
  **kwargs,
705
  ) -> dict:
706
- # only last token for input_ids if past is not None
707
- if past:
708
- input_ids = input_ids[:, -1].unsqueeze(-1)
709
-
710
- # the cache may be in the standard format (e.g. in contrastive search), convert to our format if needed
711
- if past[0][0].shape[0] == input_ids.shape[0]:
712
- past = self._convert_to_rw_cache(past)
713
 
714
  return {
715
  "input_ids": input_ids,
716
- "past_key_values": past,
717
  "use_cache": kwargs.get("use_cache"),
718
  "attention_mask": attention_mask,
719
  }
720
 
 
 
 
 
 
 
721
  def forward(
722
  self,
723
  input_ids: Optional[torch.LongTensor] = None,
@@ -730,7 +887,6 @@ class RWForCausalLM(RWPreTrainedModel):
730
  output_attentions: Optional[bool] = None,
731
  output_hidden_states: Optional[bool] = None,
732
  return_dict: Optional[bool] = None,
733
- **deprecated_arguments,
734
  ) -> Union[Tuple[torch.Tensor], CausalLMOutputWithCrossAttentions]:
735
  r"""
736
  labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
@@ -738,15 +894,6 @@ class RWForCausalLM(RWPreTrainedModel):
738
  `labels = input_ids` Indices are selected in `[-100, 0, ..., config.vocab_size]` All labels set to `-100`
739
  are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]`
740
  """
741
- if deprecated_arguments.pop("position_ids", False) is not False:
742
- # `position_ids` could have been `torch.Tensor` or `None` so defaulting pop to `False` allows to detect if users were passing explicitly `None`
743
- warnings.warn(
744
- "`position_ids` have no functionality in BLOOM and will be removed in v5.0.0. You can safely ignore"
745
- " passing `position_ids`.",
746
- FutureWarning,
747
- )
748
- if len(deprecated_arguments) > 0:
749
- raise ValueError(f"Got unexpected arguments: {deprecated_arguments}")
750
 
751
  return_dict = return_dict if return_dict is not None else self.config.use_return_dict
752
 
@@ -799,7 +946,6 @@ class RWForCausalLM(RWPreTrainedModel):
799
 
800
  Output shares the same memory storage as `past`.
801
  """
802
- standardized_past = self._convert_to_standard_cache(past, batch_size=len(beam_idx))
803
 
804
  # Get a copy of `beam_idx` on all the devices where we need those indices.
805
  device_to_beam_idx = {
@@ -810,14 +956,27 @@ class RWForCausalLM(RWPreTrainedModel):
810
  layer_past[0].index_select(0, device_to_beam_idx[layer_past[0].device]),
811
  layer_past[1].index_select(0, device_to_beam_idx[layer_past[0].device]),
812
  )
813
- for layer_past in standardized_past
814
  )
815
- return self._convert_to_rw_cache(reordered_past)
816
 
817
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
818
  class RWForSequenceClassification(RWPreTrainedModel):
819
- _keys_to_ignore_on_load_missing = [r"h.*.self_attention.scale_mask_softmax.causal_mask", r"lm_head.weight"]
820
-
821
  def __init__(self, config: RWConfig):
822
  super().__init__(config)
823
  self.num_labels = config.num_labels
@@ -827,6 +986,12 @@ class RWForSequenceClassification(RWPreTrainedModel):
827
  # Initialize weights and apply final processing
828
  self.post_init()
829
 
 
 
 
 
 
 
830
  def forward(
831
  self,
832
  input_ids: Optional[torch.LongTensor] = None,
@@ -839,7 +1004,6 @@ class RWForSequenceClassification(RWPreTrainedModel):
839
  output_attentions: Optional[bool] = None,
840
  output_hidden_states: Optional[bool] = None,
841
  return_dict: Optional[bool] = None,
842
- **deprecated_arguments,
843
  ) -> Union[Tuple[torch.Tensor], SequenceClassifierOutputWithPast]:
844
  r"""
845
  labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
@@ -847,15 +1011,6 @@ class RWForSequenceClassification(RWPreTrainedModel):
847
  config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
848
  `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
849
  """
850
- if deprecated_arguments.pop("position_ids", False) is not False:
851
- # `position_ids` could have been `torch.Tensor` or `None` so defaulting pop to `False` allows to detect if users were passing explicitly `None`
852
- warnings.warn(
853
- "`position_ids` have no functionality in BLOOM and will be removed in v5.0.0. You can safely ignore"
854
- " passing `position_ids`.",
855
- FutureWarning,
856
- )
857
- if len(deprecated_arguments) > 0:
858
- raise ValueError(f"Got unexpected arguments: {deprecated_arguments}")
859
 
860
  return_dict = return_dict if return_dict is not None else self.config.use_return_dict
861
 
@@ -930,17 +1085,22 @@ class RWForSequenceClassification(RWPreTrainedModel):
930
  )
931
 
932
 
 
 
 
 
 
 
 
933
  class RWForTokenClassification(RWPreTrainedModel):
934
- _keys_to_ignore_on_load_missing = [r"h.*.self_attention.scale_mask_softmax.causal_mask", r"lm_head.weight"]
935
-
936
  def __init__(self, config: RWConfig):
937
  super().__init__(config)
938
  self.num_labels = config.num_labels
939
 
940
  self.transformer = RWModel(config)
941
- if hasattr(config, "classifier_dropout") and config.classifier_dropout is not None:
942
  classifier_dropout = config.classifier_dropout
943
- elif hasattr(config, "hidden_dropout") and config.hidden_dropout is not None:
944
  classifier_dropout = config.hidden_dropout
945
  else:
946
  classifier_dropout = 0.1
@@ -950,6 +1110,12 @@ class RWForTokenClassification(RWPreTrainedModel):
950
  # Initialize weights and apply final processing
951
  self.post_init()
952
 
 
 
 
 
 
 
953
  def forward(
954
  self,
955
  input_ids: Optional[torch.LongTensor] = None,
@@ -962,7 +1128,6 @@ class RWForTokenClassification(RWPreTrainedModel):
962
  output_attentions: Optional[bool] = None,
963
  output_hidden_states: Optional[bool] = None,
964
  return_dict: Optional[bool] = None,
965
- **deprecated_arguments,
966
  ) -> Union[Tuple[torch.Tensor], TokenClassifierOutput]:
967
  r"""
968
  labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
@@ -970,15 +1135,6 @@ class RWForTokenClassification(RWPreTrainedModel):
970
  config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
971
  `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
972
  """
973
- if deprecated_arguments.pop("position_ids", False) is not False:
974
- # `position_ids` could have been `torch.Tensor` or `None` so defaulting pop to `False` allows to detect if users were passing explicitly `None`
975
- warnings.warn(
976
- "`position_ids` have no functionality in BLOOM and will be removed in v5.0.0. You can safely ignore"
977
- " passing `position_ids`.",
978
- FutureWarning,
979
- )
980
- if len(deprecated_arguments) > 0:
981
- raise ValueError(f"Got unexpected arguments: {deprecated_arguments}")
982
 
983
  return_dict = return_dict if return_dict is not None else self.config.use_return_dict
984
 
@@ -1002,7 +1158,9 @@ class RWForTokenClassification(RWPreTrainedModel):
1002
  if labels is not None:
1003
  batch_size, seq_length = labels.shape
1004
  loss_fct = CrossEntropyLoss()
1005
- loss = loss_fct(logits.view(batch_size * seq_length, self.num_labels), labels.view(batch_size * seq_length))
 
 
1006
 
1007
  if not return_dict:
1008
  output = (logits,) + transformer_outputs[2:]
@@ -1016,9 +1174,14 @@ class RWForTokenClassification(RWPreTrainedModel):
1016
  )
1017
 
1018
 
 
 
 
 
 
 
 
1019
  class RWForQuestionAnswering(RWPreTrainedModel):
1020
- _keys_to_ignore_on_load_missing = [r"h.*.self_attention.scale_mask_softmax.causal_mask", r"lm_head.weight"]
1021
-
1022
  def __init__(self, config):
1023
  super().__init__(config)
1024
  self.transformer = RWModel(config)
@@ -1027,11 +1190,11 @@ class RWForQuestionAnswering(RWPreTrainedModel):
1027
  # Initialize weights and apply final processing
1028
  self.post_init()
1029
 
 
1030
  def forward(
1031
  self,
1032
  input_ids: Optional[torch.LongTensor] = None,
1033
  attention_mask: Optional[torch.FloatTensor] = None,
1034
- position_ids: Optional[torch.LongTensor] = None,
1035
  head_mask: Optional[torch.FloatTensor] = None,
1036
  inputs_embeds: Optional[torch.FloatTensor] = None,
1037
  start_positions: Optional[torch.LongTensor] = None,
@@ -1055,7 +1218,6 @@ class RWForQuestionAnswering(RWPreTrainedModel):
1055
  outputs = self.transformer(
1056
  input_ids,
1057
  attention_mask=attention_mask,
1058
- position_ids=position_ids,
1059
  head_mask=head_mask,
1060
  inputs_embeds=inputs_embeds,
1061
  output_attentions=output_attentions,
 
1
+ # coding=utf-8
2
+ # Copyright 2023 the Falcon authors and HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """PyTorch Falcon model."""
16
 
17
  import math
 
18
  from typing import Optional, Tuple, Union
19
 
20
  import torch
 
31
  TokenClassifierOutput,
32
  )
33
  from transformers.modeling_utils import PreTrainedModel
34
+ from transformers.utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward, logging
35
  from .configuration_RW import RWConfig
36
 
37
+
38
  logger = logging.get_logger(__name__)
39
 
40
+ FALCON_PRETRAINED_MODEL_ARCHIVE_LIST = [
41
+ "tiiuae/falcon-40b",
42
+ "tiiuae/falcon-40b-instruct",
43
+ "tiiuae/falcon-7b",
44
+ "tiiuae/falcon-7b-instruct",
45
+ "tiiuae/falcon-rw-7b",
46
+ "tiiuae/falcon-rw-1b",
47
+ ]
48
+ _CHECKPOINT_FOR_DOC = "Rocketknight1/falcon-rw-1b"
49
+ _CONFIG_FOR_DOC = "RWConfig"
50
+
51
+
52
  # NOTE(Hesslow): Unfortunately we did not fuse matmul and bias during training, this means that there's one additional quantization to bfloat16 between the operations.
53
  # In order not to degrade the quality of our HF-port, we keep these characteristics in the final model.
54
+ class FalconLinear(nn.Linear):
55
  def forward(self, input: torch.Tensor) -> torch.Tensor:
56
+ hidden_states = input @ self.weight.T
57
  if self.bias is None:
58
+ return hidden_states
59
+ return hidden_states + self.bias
 
 
60
 
 
61
 
62
  # rotary pos emb helpers (torch.jit.script does not seem to support staticmethod...)
63
  def rotate_half(x):
64
  x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2 :]
65
+ return torch.cat((-x2, x1), dim=-1)
66
 
67
 
68
+ class FalconRotaryEmbedding(nn.Module):
69
  """Implementation of RotaryEmbedding from GPT-NeoX.
70
+ This implementation is designed to operate on queries and keys that are compatible with `[batch_size,
71
+ n_heads_per_partition, seq_len, head_dim]` (e.g. MinGPTAttention format).
72
  """
73
 
74
+ def __init__(self, head_dim: int, base=10000):
 
 
 
 
75
  super().__init__()
76
  inv_freq = 1.0 / (base ** (torch.arange(0, head_dim, 2).float() / head_dim))
77
  self.register_buffer("inv_freq", inv_freq, persistent=False)
78
  self.head_dim = head_dim
79
+ self.seq_len_cached = -1
 
80
  self.cos_cached: torch.Tensor | None = None
81
  self.sin_cached: torch.Tensor | None = None
82
 
83
+ def cos_sin(self, seq_len: int, past_key_values_length: int, device="cpu", dtype=torch.bfloat16) -> torch.Tensor:
84
+ total_length = seq_len + past_key_values_length
85
+ if total_length > self.seq_len_cached:
86
+ self.seq_len_cached = total_length
87
+ t = torch.arange(total_length, device=device, dtype=self.inv_freq.dtype)
 
 
 
 
88
  freqs = torch.einsum("i,j->ij", t, self.inv_freq)
89
  emb = torch.cat((freqs, freqs), dim=-1).to(device)
90
 
 
97
  self.cos_cached = self.cos_cached.type(dtype)
98
  self.sin_cached = self.sin_cached.type(dtype)
99
 
100
+ return (
101
+ self.cos_cached[:, past_key_values_length : seq_len + past_key_values_length],
102
+ self.sin_cached[:, past_key_values_length : seq_len + past_key_values_length],
103
+ )
104
 
105
+ def forward(self, query, key, past_key_values_length=0):
106
+ batch, seq_len, head_dim = query.shape
107
+ cos, sin = self.cos_sin(seq_len, past_key_values_length, query.device, query.dtype)
108
+ return (query * cos) + (rotate_half(query) * sin), (key * cos) + (rotate_half(key) * sin)
109
 
110
 
111
  def _make_causal_mask(
112
  input_ids_shape: torch.Size, device: torch.device, past_key_values_length: int
113
  ) -> torch.BoolTensor:
114
+ """
115
+ Make causal mask used for self-attention. This mask does not take the existing attention mask into account - it
116
+ just blocks tokens from attending forwards in the sequence. The output shape will be `[batch_size, 1,
117
+ target_length, target_length+past_key_values_length]`.
118
+ """
119
  batch_size, target_length = input_ids_shape
 
 
 
 
 
 
 
120
 
121
+ mask = torch.triu(torch.ones((target_length, target_length), dtype=torch.bool, device=device), diagonal=1)
122
+ # If past_key_values_length is 0 this is an empty tensor and the concatenation is a no-op.
123
+ # This code style is an unfortunate consequence of getting your TF engineer to port models; doing it this
124
+ # way avoids a data-dependent conditional, which will help me when I have to port this to XLA later.
125
+ past_mask = torch.zeros((target_length, past_key_values_length), dtype=torch.bool, device=device)
126
+ mask = torch.cat([past_mask, mask], dim=-1)
127
  expanded_mask = mask[None, None, :, :].expand(batch_size, 1, target_length, target_length + past_key_values_length)
128
  return expanded_mask
129
 
130
 
131
+ def _expand_mask(mask: torch.Tensor, past_key_values_length: int) -> torch.BoolTensor:
132
+ """
133
+ Expands attention_mask from `[batch_size, seq_length + past_length]` to `[batch_size, 1, seq_length, seq_length + past_length]`.
134
+ """
135
+ batch_size, total_length = mask.shape
136
+ seq_length = total_length - past_key_values_length if past_key_values_length is not None else total_length
137
 
138
  expanded_mask = ~(mask[:, None, None, :].to(torch.bool))
139
+ return expanded_mask.expand(batch_size, 1, seq_length, total_length)
140
 
141
 
142
  def build_alibi_tensor(attention_mask: torch.Tensor, num_heads: int, dtype: torch.dtype) -> torch.Tensor:
 
167
  return alibi.reshape(batch_size * num_heads, 1, seq_length).to(dtype)
168
 
169
 
170
+ # Copied from transformers.models.bloom.modeling_bloom.dropout_add
171
  def dropout_add(x: torch.Tensor, residual: torch.Tensor, prob: float, training: bool) -> torch.Tensor:
172
+ """
173
+ Dropout add function
174
+
175
+ Args:
176
+ x (`torch.tensor`, *required*):
177
+ input tensor
178
+ residual (`torch.tensor`, *required*):
179
+ residual tensor
180
+ prob (`float`, *required*):
181
+ dropout probability
182
+ training (`bool`, *required*):
183
+ training mode
184
+ """
185
  out = F.dropout(x, p=prob, training=training)
186
  out = residual + out
187
  return out
188
 
189
 
190
+ class FalconAttention(nn.Module):
191
  def __init__(self, config: RWConfig):
192
  super().__init__()
193
 
194
  self.hidden_size = config.hidden_size
195
+ self.num_heads = config.num_attention_heads
196
  self.head_dim = self.hidden_size // self.num_heads
197
  self.split_size = self.hidden_size
198
  self.hidden_dropout = config.hidden_dropout
 
203
  f" {self.num_heads})."
204
  )
205
 
206
+ self.maybe_rotary = FalconRotaryEmbedding(config.head_dim) if config.rotary else lambda q, k, t: (q, k)
207
 
208
  # Layer-wise attention scaling
209
  self.inv_norm_factor = 1.0 / math.sqrt(self.head_dim)
210
  self.beta = self.inv_norm_factor
211
+ if config.new_decoder_architecture:
212
+ qkv_out_dim = (config.num_kv_heads * 2 + config.num_attention_heads) * self.head_dim
213
+ elif config.multi_query:
214
+ qkv_out_dim = self.hidden_size + 2 * self.head_dim
215
+ else:
216
+ qkv_out_dim = 3 * self.hidden_size
217
+ self.query_key_value = FalconLinear(self.hidden_size, qkv_out_dim, bias=config.bias)
218
+ self.new_decoder_architecture = config.new_decoder_architecture
219
  self.multi_query = config.multi_query
220
+ self.dense = FalconLinear(self.hidden_size, self.hidden_size, bias=config.bias)
221
  self.attention_dropout = nn.Dropout(config.attention_dropout)
222
+ self.num_kv_heads = config.num_kv_heads if (self.new_decoder_architecture or not self.multi_query) else 1
223
 
224
  def _split_heads(self, fused_qkv: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
225
  """
226
+ Split the last dimension into (num_heads, head_dim), results share same memory storage as `fused_qkv`
 
227
 
228
  Args:
229
  fused_qkv (`torch.tensor`, *required*): [batch_size, seq_length, num_heads * 3 * head_dim]
 
232
  query: [batch_size, seq_length, num_heads, head_dim] key: [batch_size, seq_length, num_heads, head_dim]
233
  value: [batch_size, seq_length, num_heads, head_dim]
234
  """
235
+ if self.new_decoder_architecture:
236
+ batch, seq_len, _ = fused_qkv.shape
237
+ qkv = fused_qkv.view(batch, seq_len, -1, self.num_heads // self.num_kv_heads + 2, self.head_dim)
238
+ query = qkv[:, :, :, :-2]
239
+ key = qkv[:, :, :, [-2]]
240
+ value = qkv[:, :, :, [-1]]
241
+ key = torch.broadcast_to(key, query.shape)
242
+ value = torch.broadcast_to(value, query.shape)
243
+
244
+ query, key, value = [x.flatten(2, 3) for x in (query, key, value)]
245
+ return query, key, value
246
+ elif not self.multi_query:
247
  batch_size, seq_length, three_times_hidden_size = fused_qkv.shape
248
  fused_qkv = fused_qkv.view(batch_size, seq_length, self.num_heads, 3, self.head_dim)
249
  return fused_qkv[..., 0, :], fused_qkv[..., 1, :], fused_qkv[..., 2, :]
 
252
  fused_qkv = fused_qkv.view(batch_size, seq_length, self.num_heads + 2, self.head_dim)
253
  return fused_qkv[..., :-2, :], fused_qkv[..., [-2], :], fused_qkv[..., [-1], :]
254
 
255
+ # Copied from transformers.models.bloom.modeling_bloom.BloomAttention._merge_heads
256
  def _merge_heads(self, x: torch.Tensor) -> torch.Tensor:
257
  """
258
  Merge heads together over the last dimension
259
 
260
  Args:
261
+ x (`torch.tensor`, *required*): [batch_size * num_heads, seq_length, head_dim]
262
 
263
  Returns:
264
  torch.tensor: [batch_size, seq_length, num_heads * head_dim]
 
281
  def forward(
282
  self,
283
  hidden_states: torch.Tensor,
284
+ alibi: Optional[torch.Tensor],
285
  attention_mask: torch.Tensor,
286
  layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
287
  head_mask: Optional[torch.Tensor] = None,
 
289
  output_attentions: bool = False,
290
  ):
291
  fused_qkv = self.query_key_value(hidden_states) # [batch_size, seq_length, 3 x hidden_size]
292
+ num_kv_heads = self.num_heads if self.new_decoder_architecture else self.num_kv_heads
293
  # 3 x [batch_size, seq_length, num_heads, head_dim]
294
  (query_layer, key_layer, value_layer) = self._split_heads(fused_qkv)
295
 
296
+ batch_size, query_length, _, _ = query_layer.shape
297
 
298
+ query_layer = query_layer.transpose(1, 2).reshape(batch_size * self.num_heads, query_length, self.head_dim)
299
  key_layer = key_layer.transpose(1, 2).reshape(
300
+ batch_size * num_kv_heads,
301
+ query_length,
302
  self.head_dim,
303
  )
304
+ value_layer = value_layer.transpose(1, 2).reshape(batch_size * num_kv_heads, query_length, self.head_dim)
305
 
306
+ past_kv_length = 0 if layer_past is None else layer_past[0].shape[1]
307
+ query_layer, key_layer = self.maybe_rotary(query_layer, key_layer, past_kv_length)
308
 
309
  if layer_past is not None:
310
  past_key, past_value = layer_past
311
  # concatenate along seq_length dimension:
312
+ # - key: [batch_size * self.num_heads, kv_length, head_dim]
313
  # - value: [batch_size * self.num_heads, kv_length, head_dim]
314
  key_layer = torch.cat((past_key, key_layer), dim=1)
315
  value_layer = torch.cat((past_value, value_layer), dim=1)
316
 
317
  _, kv_length, _ = key_layer.shape
318
+ if use_cache:
 
319
  present = (key_layer, value_layer)
320
  else:
321
  present = None
322
 
323
+ attention_mask_float = (attention_mask * 1.0).masked_fill(attention_mask, float("-1e9")).to(query_layer.dtype)
324
+
325
+ query_layer_ = query_layer.reshape(batch_size, self.num_heads, -1, self.head_dim)
326
+ key_layer_ = key_layer.reshape(batch_size, num_kv_heads, -1, self.head_dim)
327
+ value_layer_ = value_layer.reshape(batch_size, num_kv_heads, -1, self.head_dim)
328
+
329
  if alibi is None:
330
+ if output_attentions:
331
+ # F.scaled_dot_product_attention doesn't return the attention weights, so we have
332
+ # to do it by hand if we want them
333
+ attention_scores = query_layer_ @ key_layer_.transpose(-1, -2)
334
+ attention_scores /= math.sqrt(self.head_dim)
335
 
336
+ attention_scores = F.softmax(
337
+ attention_scores + attention_mask_float, dim=-1, dtype=hidden_states.dtype
338
+ )
339
+ attn_output = attention_scores @ value_layer_
340
+ else:
341
+ attn_output = F.scaled_dot_product_attention(
342
+ query_layer_, key_layer_, value_layer_, attention_mask_float, 0.0, is_causal=False
343
+ )
344
+ attention_scores = None
345
 
346
+ attn_output = attn_output.view(batch_size, self.num_heads, query_length, self.head_dim)
347
+ attn_output = attn_output.permute(0, 2, 1, 3)
348
+ attn_output = attn_output.reshape(batch_size, query_length, self.num_heads * self.head_dim)
349
 
350
  output_tensor = self.dense(attn_output)
351
 
352
+ if output_attentions:
353
+ return output_tensor, present, attention_scores
354
+ else:
355
+ return output_tensor, present
356
+
357
  else:
358
+ matmul_result = query_layer_ @ key_layer_.transpose(-1, -2)
 
359
 
360
  # change view to [batch_size, num_heads, q_length, kv_length]
361
+ attention_scores = matmul_result.view(batch_size, self.num_heads, query_length, kv_length)
362
 
363
  # cast attention scores to fp32, compute scaled softmax and cast back to initial dtype - [batch_size, num_heads, q_length, kv_length]
364
  input_dtype = attention_scores.dtype
365
  # `float16` has a minimum value of -65504.0, whereas `bfloat16` and `float32` have a minimum value of `-3.4e+38`
366
  if input_dtype == torch.float16 or input_dtype == torch.bfloat16:
367
  attention_scores = attention_scores.to(torch.float32)
368
+ # Matt (HF) note: We could possibly use F.scaled_dot_product_attention here too, by
369
+ # adding (alibi * self.inv_norm_factor) to attention_mask_float. I think this would be mathematically
370
+ # equivalent and more performant, but there might be a numerical difference. If you're reading this
371
+ # and you'd like to experiment and maybe file a PR, feel free!
372
+ attention_logits = attention_scores + alibi.view(batch_size, self.num_heads, 1, -1)
373
+ attention_logits *= self.inv_norm_factor
374
+ attention_probs = F.softmax(attention_logits + attention_mask_float, dim=-1, dtype=hidden_states.dtype)
375
  # [batch_size, num_heads, q_length, kv_length]
376
  attention_probs = self.attention_dropout(attention_probs)
377
 
378
  if head_mask is not None:
379
  attention_probs = attention_probs * head_mask
380
 
381
+ # change view [batch_size, num_heads, q_length, kv_length]
382
+ attention_probs_reshaped = attention_probs.view(batch_size, self.num_heads, query_length, kv_length)
383
 
384
  # matmul: [batch_size * num_heads, q_length, head_dim]
385
+ context_layer = (attention_probs_reshaped @ value_layer_).flatten(0, 1)
386
 
387
  # change view [batch_size, num_heads, q_length, head_dim]
388
  context_layer = self._merge_heads(context_layer)
389
 
390
  output_tensor = self.dense(context_layer)
391
 
 
392
  if output_attentions:
393
+ return output_tensor, present, attention_probs
394
+ else:
395
+ return output_tensor, present
396
 
397
 
398
+ class FalconMLP(nn.Module):
399
  def __init__(self, config: RWConfig):
400
  super().__init__()
401
  hidden_size = config.hidden_size
402
 
403
+ self.dense_h_to_4h = FalconLinear(hidden_size, 4 * hidden_size, bias=config.bias)
404
  self.act = nn.GELU()
405
+ self.dense_4h_to_h = FalconLinear(4 * hidden_size, hidden_size, bias=config.bias)
406
  self.hidden_dropout = config.hidden_dropout
407
 
408
  def forward(self, x: torch.Tensor) -> torch.Tensor:
 
411
  return x
412
 
413
 
414
+ class FalconDecoderLayer(nn.Module):
415
  def __init__(self, config: RWConfig):
416
  super().__init__()
417
  hidden_size = config.hidden_size
418
+ self.num_heads = config.num_attention_heads
419
+ self.self_attention = FalconAttention(config)
420
+ self.mlp = FalconMLP(config)
 
 
 
 
 
 
 
 
 
421
  self.hidden_dropout = config.hidden_dropout
 
422
  self.config = config
423
 
424
+ if config.new_decoder_architecture:
425
+ # The layer norm before self-attention
426
+ self.ln_attn = LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
427
+ # The layer norm before the MLP
428
+ self.ln_mlp = LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
429
+ else:
430
+ self.input_layernorm = LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
431
+ if not config.parallel_attn:
432
+ self.post_attention_layernorm = LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
433
+
434
  def forward(
435
  self,
436
  hidden_states: torch.Tensor,
437
+ alibi: Optional[torch.Tensor],
438
  attention_mask: torch.Tensor,
439
  layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
440
  head_mask: Optional[torch.Tensor] = None,
441
  use_cache: bool = False,
442
  output_attentions: bool = False,
443
  ):
 
 
444
  residual = hidden_states
445
 
446
+ if self.config.new_decoder_architecture:
447
+ attention_layernorm_out = self.ln_attn(hidden_states)
448
+ mlp_layernorm_out = self.ln_mlp(hidden_states)
449
+ else:
450
+ attention_layernorm_out = self.input_layernorm(hidden_states)
451
+
452
  # Self attention.
453
  attn_outputs = self.self_attention(
454
+ attention_layernorm_out,
455
  layer_past=layer_past,
456
  attention_mask=attention_mask,
457
  alibi=alibi,
 
462
 
463
  attention_output = attn_outputs[0]
464
 
465
+ if not self.config.new_decoder_architecture:
466
+ if self.config.parallel_attn:
467
+ mlp_layernorm_out = attention_layernorm_out
468
+ else:
469
+ residual = dropout_add(
470
+ attention_output, residual, self.config.attention_dropout, training=self.training
471
+ )
472
+ mlp_layernorm_out = self.post_attention_layernorm(residual)
473
 
474
  outputs = attn_outputs[1:]
475
 
476
  # MLP.
477
+ mlp_output = self.mlp(mlp_layernorm_out)
478
 
479
+ if self.config.new_decoder_architecture or self.config.parallel_attn:
480
  mlp_output += attention_output
481
 
482
  output = dropout_add(mlp_output, residual, self.config.hidden_dropout, training=self.training)
 
489
  return outputs # hidden_states, present, attentions
490
 
491
 
492
+ FALCON_START_DOCSTRING = r"""
493
+
494
+ This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
495
+ library implements for all its models (such as downloading or saving, resizing the input embeddings etc.)
496
+
497
+ This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
498
+ Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to general usage
499
+ and behavior.
500
+
501
+ Parameters:
502
+ config ([`RWConfig`]): Model configuration class with all the parameters of the model.
503
+ Initializing with a config file does not load the weights associated with the model, only the
504
+ configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
505
+ """
506
+
507
+ FALCON_INPUTS_DOCSTRING = r"""
508
+ Args:
509
+ input_ids (`torch.LongTensor` of shape `(batch_size, input_ids_length)`):
510
+ `input_ids_length` = `sequence_length` if `past_key_values` is `None` else `past_key_values[0][0].shape[2]`
511
+ (`sequence_length` of input past key value states). Indices of input sequence tokens in the vocabulary.
512
+
513
+ If `past_key_values` is used, only `input_ids` that do not have their past calculated should be passed as
514
+ `input_ids`.
515
+
516
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
517
+ [`PreTrainedTokenizer.__call__`] for details.
518
+
519
+ [What are input IDs?](../glossary#input-ids)
520
+ past_key_values (`Tuple[Tuple[torch.Tensor]]` of length `config.num_hidden_layers`):
521
+ Contains precomputed hidden-states (key and values in the attention blocks) as computed by the model (see
522
+ `past_key_values` output below). Can be used to speed up sequential decoding. The `input_ids` which have
523
+ their past given to this model should not be passed as `input_ids` as they have already been computed.
524
+
525
+ Each element of `past_key_values` is a tuple (past_key, past_value):
526
+ - past_key: [batch_size, num_heads, kv_length, head_dim]
527
+ - past_value: [batch_size, num_heads, kv_length, head_dim]
528
+ attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*):
529
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
530
+
531
+ - 1 for tokens that are **not masked**,
532
+ - 0 for tokens that are **masked**.
533
+
534
+ [What are attention masks?](../glossary#attention-mask)
535
+ head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
536
+ Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:
537
+
538
+ - 1 indicates the head is **not masked**,
539
+ - 0 indicates the head is **masked**.
540
+
541
+ inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
542
+ Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
543
+ is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
544
+ model's internal embedding lookup matrix.
545
+
546
+ If `past_key_values` is used, optionally only the last `inputs_embeds` have to be input (see
547
+ `past_key_values`).
548
+ use_cache (`bool`, *optional*):
549
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
550
+ `past_key_values`).
551
+ output_attentions (`bool`, *optional*):
552
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
553
+ tensors for more detail.
554
+ output_hidden_states (`bool`, *optional*):
555
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
556
+ more detail.
557
+ return_dict (`bool`, *optional*):
558
+ Whether or not to return a [`~file_utils.ModelOutput`] instead of a plain tuple.
559
+ """
560
+
561
+
562
  class RWPreTrainedModel(PreTrainedModel):
 
563
  """
564
  An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
565
  models.
 
568
  config_class = RWConfig
569
  base_model_prefix = "transformer"
570
  supports_gradient_checkpointing = True
571
+ _no_split_modules = ["FalconDecoderLayer"]
572
 
573
  def __init__(self, *inputs, **kwargs):
574
  # Pure pass-through: every positional and keyword argument is forwarded unchanged to the parent constructor.
  super().__init__(*inputs, **kwargs)
575
 
576
  def _init_weights(self, module: nn.Module):
577
  """Initialize the weights."""
578
+ if isinstance(module, nn.Linear) or isinstance(module, FalconLinear):
579
  # Slightly different from the TF version which uses truncated_normal for initialization
580
  # cf https://github.com/pytorch/pytorch/pull/5617
581
  module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
 
589
  module.bias.data.zero_()
590
  module.weight.data.fill_(1.0)
591
 
592
+ # Copied from transformers.models.bloom.modeling_bloom.BloomPreTrainedModel._set_gradient_checkpointing with BloomModel->RWModel
593
  def _set_gradient_checkpointing(self, module: nn.Module, value: bool = False):
594
  # Store `value` in `module.gradient_checkpointing`, but only when `module` is an `RWModel`;
  # any other module type is left untouched.
  if isinstance(module, RWModel):
595
  module.gradient_checkpointing = value
596
 
597
  @staticmethod
598
+ def _convert_cache_to_standard_format(
599
  past_key_value: Tuple[Tuple[torch.Tensor, torch.Tensor]], batch_size: int
600
  ) -> Tuple[Tuple[torch.Tensor, torch.Tensor]]:
601
  """
602
  Standardizes the format of the cache so as to match most implementations, i.e. to tuple(tuple([batch_size,
603
  num_heads, ...]))
604
  """
605
+ batch_size_times_num_heads, kv_length, head_dim = past_key_value[0][0].shape
606
+ # [batch_size * self.num_heads, kv_length, head_dim] -> [batch_size, num_heads, kv_length, head_dim]
607
+ # Note that don't want to use self.num_attention_heads because the number of heads may vary depending
608
+ # on whether we use multi_query attention.
609
  num_heads = batch_size_times_num_heads // batch_size
 
 
610
  return tuple(
611
  (
612
+ layer_past[0].view(batch_size, num_heads, kv_length, head_dim),
613
+ layer_past[1].view(batch_size, num_heads, kv_length, head_dim),
614
  )
615
  for layer_past in past_key_value
616
  )
 
619
  def _convert_to_rw_cache(
620
  past_key_value: Tuple[Tuple[torch.Tensor, torch.Tensor]]
621
  ) -> Tuple[Tuple[torch.Tensor, torch.Tensor]]:
622
+ batch_size, num_heads, kv_length, head_dim = past_key_value[0][0].shape
623
  batch_size_times_num_heads = batch_size * num_heads
624
+ # [batch_size, num_heads, kv_length, head_dim] -> [batch_size * num_heads, kv_length, head_dim]
 
625
  return tuple(
626
  (
627
+ layer_past[0].view(batch_size_times_num_heads, kv_length, head_dim),
628
+ layer_past[1].view(batch_size_times_num_heads, kv_length, head_dim),
629
  )
630
  for layer_past in past_key_value
631
  )
632
 
633
 
634
+ @add_start_docstrings(
635
+ "The bare Falcon Model transformer outputting raw hidden-states without any specific head on top.",
636
+ FALCON_START_DOCSTRING,
637
+ )
638
  class RWModel(RWPreTrainedModel):
639
  def __init__(self, config: RWConfig):
640
  super().__init__(config)
641
 
642
  self.embed_dim = config.hidden_size
643
+ self.num_heads = config.num_attention_heads
644
+ self.use_alibi = config.alibi
645
 
646
  # Embedding + LN Embedding
647
  self.word_embeddings = nn.Embedding(config.vocab_size, self.embed_dim)
648
 
649
  # Transformer blocks
650
+ self.h = nn.ModuleList([FalconDecoderLayer(config) for _ in range(config.num_hidden_layers)])
651
 
652
  # Final Layer Norm
653
  self.ln_f = LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)
 
660
  def get_input_embeddings(self):
661
  # Return the object currently stored in `self.word_embeddings`.
  return self.word_embeddings
662
 
663
+ @staticmethod
664
  def _prepare_attn_mask(
665
+ attention_mask: torch.Tensor, input_shape: Tuple[int, int], past_key_values_length: int
666
  ) -> torch.BoolTensor:
667
+ # Create a causal mask
668
+ # The attention mask we receive as input should cover the whole extended sequence, including any past
669
+ # cache, so its shape should be [batch_size, seq_length + past_key_values_length]
670
+ # The output shape will be [batch_size, 1, seq_length, seq_length + past_key_values_length]
671
+ if input_shape[1] + past_key_values_length != attention_mask.shape[1]:
672
+ raise ValueError(
673
+ "Attention mask shape should be (batch_size, seq_length + past_key_values_length)"
674
+ f" but is {attention_mask.shape} with input_ids shape {input_shape} and past length"
675
+ f" {past_key_values_length}."
676
+ )
677
  combined_attention_mask = None
678
  device = attention_mask.device
679
+ _, seq_length = input_shape
680
 
681
+ if seq_length > 1:
682
  combined_attention_mask = _make_causal_mask(
683
  input_shape, device=device, past_key_values_length=past_key_values_length
684
  )
685
 
686
+ # [batch_size, seq_length + past_key_values_length] -> [batch_size, 1, seq_length, seq_length + past_key_values_length]
687
+ expanded_attn_mask = _expand_mask(attention_mask, past_key_values_length=past_key_values_length)
688
  combined_attention_mask = (
689
  expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask | combined_attention_mask
690
  )
 
694
  def set_input_embeddings(self, new_embeddings: torch.Tensor):
695
  # Replace `self.word_embeddings` with `new_embeddings`.
  # NOTE(review): the parameter is annotated `torch.Tensor`, but this attribute is created as an
  # `nn.Embedding` and is called like a module in `forward` -- the annotation probably should be a
  # module type; confirm before changing.
  self.word_embeddings = new_embeddings
696
 
697
+ @add_start_docstrings_to_model_forward(FALCON_INPUTS_DOCSTRING)
698
+ @add_code_sample_docstrings(
699
+ checkpoint=_CHECKPOINT_FOR_DOC,
700
+ output_type=BaseModelOutputWithPastAndCrossAttentions,
701
+ config_class=_CONFIG_FOR_DOC,
702
+ )
703
  def forward(
704
  self,
705
  input_ids: Optional[torch.LongTensor] = None,
 
711
  output_attentions: Optional[bool] = None,
712
  output_hidden_states: Optional[bool] = None,
713
  return_dict: Optional[bool] = None,
 
714
  ) -> Union[Tuple[torch.Tensor, ...], BaseModelOutputWithPastAndCrossAttentions]:
 
 
 
 
 
 
 
 
 
 
715
  output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
716
  output_hidden_states = (
717
  output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
 
730
 
731
  if past_key_values is None:
732
  past_key_values = tuple([None] * len(self.h))
733
+ else:
734
+ past_key_values = self._convert_to_rw_cache(past_key_values)
735
 
736
  # Prepare head mask if needed
737
  # 1.0 in head_mask indicate we keep the head
738
  # attention_probs has shape batch_size x num_heads x N x N
739
  # head_mask has shape n_layer x batch x num_heads x N x N
740
+ head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
741
 
742
  if inputs_embeds is None:
743
  inputs_embeds = self.word_embeddings(input_ids)
 
749
  all_hidden_states = () if output_hidden_states else None
750
 
751
  # Compute alibi tensor: check build_alibi_tensor documentation
 
752
  past_key_values_length = 0
753
  if past_key_values[0] is not None:
754
+ past_key_values_length = past_key_values[0][0].shape[1] # 1 because RW-cache, not standard format
 
755
  if attention_mask is None:
756
+ attention_mask = torch.ones((batch_size, seq_length + past_key_values_length), device=hidden_states.device)
757
  else:
758
  attention_mask = attention_mask.to(hidden_states.device)
759
 
760
+ if self.use_alibi:
761
  alibi = build_alibi_tensor(attention_mask, self.num_heads, dtype=hidden_states.dtype)
762
  else:
763
  alibi = None
 
769
  )
770
 
771
  for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)):
 
772
  if output_hidden_states:
773
  all_hidden_states = all_hidden_states + (hidden_states,)
774
 
775
  if self.gradient_checkpointing and self.training:
 
776
  if use_cache:
777
  logger.warning(
778
  "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
 
817
  if output_hidden_states:
818
  all_hidden_states = all_hidden_states + (hidden_states,)
819
 
820
+ if presents is not None:
821
+ presents = self._convert_cache_to_standard_format(presents, batch_size)
822
+
823
  if not return_dict:
824
  return tuple(v for v in [hidden_states, presents, all_hidden_states, all_self_attentions] if v is not None)
825
 
 
831
  )
832
 
833
 
834
+ @add_start_docstrings(
835
+ "The Falcon Model transformer with a language modeling head on top (linear layer with weights tied to the input embeddings).",
836
+ FALCON_START_DOCSTRING,
837
+ )
838
  class RWForCausalLM(RWPreTrainedModel):
839
+ _tied_weights_keys = ["lm_head.weight"]
840
 
841
  def __init__(self, config: RWConfig):
842
  super().__init__(config)
 
855
  def prepare_inputs_for_generation(
856
  self,
857
  input_ids: torch.LongTensor,
858
+ past_key_values: Optional[torch.Tensor] = None,
859
  attention_mask: Optional[torch.Tensor] = None,
860
  **kwargs,
861
  ) -> dict:
862
+ if past_key_values is not None:
863
+ input_ids = input_ids[:, -1:]
 
 
 
 
 
864
 
865
  return {
866
  "input_ids": input_ids,
867
+ "past_key_values": past_key_values,
868
  "use_cache": kwargs.get("use_cache"),
869
  "attention_mask": attention_mask,
870
  }
871
 
872
+ @add_start_docstrings_to_model_forward(FALCON_INPUTS_DOCSTRING)
873
+ @add_code_sample_docstrings(
874
+ checkpoint=_CHECKPOINT_FOR_DOC,
875
+ output_type=CausalLMOutputWithCrossAttentions,
876
+ config_class=_CONFIG_FOR_DOC,
877
+ )
878
  def forward(
879
  self,
880
  input_ids: Optional[torch.LongTensor] = None,
 
887
  output_attentions: Optional[bool] = None,
888
  output_hidden_states: Optional[bool] = None,
889
  return_dict: Optional[bool] = None,
 
890
  ) -> Union[Tuple[torch.Tensor], CausalLMOutputWithCrossAttentions]:
891
  r"""
892
  labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
 
894
  `labels = input_ids` Indices are selected in `[-100, 0, ..., config.vocab_size]` All labels set to `-100`
895
  are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]`
896
  """
 
 
 
 
 
 
 
 
 
897
 
898
  return_dict = return_dict if return_dict is not None else self.config.use_return_dict
899
 
 
946
 
947
  Output shares the same memory storage as `past`.
948
  """
 
949
 
950
  # Get a copy of `beam_idx` on all the devices where we need those indices.
951
  device_to_beam_idx = {
 
956
  layer_past[0].index_select(0, device_to_beam_idx[layer_past[0].device]),
957
  layer_past[1].index_select(0, device_to_beam_idx[layer_past[0].device]),
958
  )
959
+ for layer_past in past
960
  )
961
+ return reordered_past
962
 
963
 
964
+ @add_start_docstrings(
965
+ """
966
+ The Falcon Model transformer with a sequence classification head on top (linear layer).
967
+
968
+ [`RWForSequenceClassification`] uses the last token in order to do the classification, as other causal models
969
+ (e.g. GPT-1) do.
970
+
971
+ Since it does classification on the last token, it requires to know the position of the last token. If a
972
+ `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If
973
+ no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the
974
+ padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in
975
+ each row of the batch).
976
+ """,
977
+ FALCON_START_DOCSTRING,
978
+ )
979
  class RWForSequenceClassification(RWPreTrainedModel):
 
 
980
  def __init__(self, config: RWConfig):
981
  super().__init__(config)
982
  self.num_labels = config.num_labels
 
986
  # Initialize weights and apply final processing
987
  self.post_init()
988
 
989
+ @add_start_docstrings_to_model_forward(FALCON_INPUTS_DOCSTRING)
990
+ @add_code_sample_docstrings(
991
+ checkpoint=_CHECKPOINT_FOR_DOC,
992
+ output_type=SequenceClassifierOutputWithPast,
993
+ config_class=_CONFIG_FOR_DOC,
994
+ )
995
  def forward(
996
  self,
997
  input_ids: Optional[torch.LongTensor] = None,
 
1004
  output_attentions: Optional[bool] = None,
1005
  output_hidden_states: Optional[bool] = None,
1006
  return_dict: Optional[bool] = None,
 
1007
  ) -> Union[Tuple[torch.Tensor], SequenceClassifierOutputWithPast]:
1008
  r"""
1009
  labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
 
1011
  config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
1012
  `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
1013
  """
 
 
 
 
 
 
 
 
 
1014
 
1015
  return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1016
 
 
1085
  )
1086
 
1087
 
1088
+ @add_start_docstrings(
1089
+ """
1090
+ Falcon Model with a token classification head on top (a linear layer on top of the hidden-states output) e.g. for
1091
+ Named-Entity-Recognition (NER) tasks.
1092
+ """,
1093
+ FALCON_START_DOCSTRING,
1094
+ )
1095
  class RWForTokenClassification(RWPreTrainedModel):
 
 
1096
  def __init__(self, config: RWConfig):
1097
  super().__init__(config)
1098
  self.num_labels = config.num_labels
1099
 
1100
  self.transformer = RWModel(config)
1101
+ if getattr(config, "classifier_dropout", None) is not None:
1102
  classifier_dropout = config.classifier_dropout
1103
+ elif getattr(config, "hidden_dropout", None) is not None:
1104
  classifier_dropout = config.hidden_dropout
1105
  else:
1106
  classifier_dropout = 0.1
 
1110
  # Initialize weights and apply final processing
1111
  self.post_init()
1112
 
1113
+ @add_start_docstrings_to_model_forward(FALCON_INPUTS_DOCSTRING)
1114
+ @add_code_sample_docstrings(
1115
+ checkpoint=_CHECKPOINT_FOR_DOC,
1116
+ output_type=TokenClassifierOutput,
1117
+ config_class=_CONFIG_FOR_DOC,
1118
+ )
1119
  def forward(
1120
  self,
1121
  input_ids: Optional[torch.LongTensor] = None,
 
1128
  output_attentions: Optional[bool] = None,
1129
  output_hidden_states: Optional[bool] = None,
1130
  return_dict: Optional[bool] = None,
 
1131
  ) -> Union[Tuple[torch.Tensor], TokenClassifierOutput]:
1132
  r"""
1133
  labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
 
1135
  config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
1136
  `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
1137
  """
 
 
 
 
 
 
 
 
 
1138
 
1139
  return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1140
 
 
1158
  if labels is not None:
1159
  batch_size, seq_length = labels.shape
1160
  loss_fct = CrossEntropyLoss()
1161
+ loss = loss_fct(
1162
+ logits.view(batch_size * seq_length, self.num_labels), labels.view(batch_size * seq_length)
1163
+ )
1164
 
1165
  if not return_dict:
1166
  output = (logits,) + transformer_outputs[2:]
 
1174
  )
1175
 
1176
 
1177
+ @add_start_docstrings(
1178
+ """
1179
+ The Falcon Model transformer with a span classification head on top for extractive question-answering tasks like
1180
+ SQuAD (a linear layer on top of the hidden-states output to compute `span start logits` and `span end logits`).
1181
+ """,
1182
+ FALCON_START_DOCSTRING,
1183
+ )
1184
  class RWForQuestionAnswering(RWPreTrainedModel):
 
 
1185
  def __init__(self, config):
1186
  super().__init__(config)
1187
  self.transformer = RWModel(config)
 
1190
  # Initialize weights and apply final processing
1191
  self.post_init()
1192
 
1193
+ @add_start_docstrings_to_model_forward(FALCON_INPUTS_DOCSTRING)
1194
  def forward(
1195
  self,
1196
  input_ids: Optional[torch.LongTensor] = None,
1197
  attention_mask: Optional[torch.FloatTensor] = None,
 
1198
  head_mask: Optional[torch.FloatTensor] = None,
1199
  inputs_embeds: Optional[torch.FloatTensor] = None,
1200
  start_positions: Optional[torch.LongTensor] = None,
 
1218
  outputs = self.transformer(
1219
  input_ids,
1220
  attention_mask=attention_mask,
 
1221
  head_mask=head_mask,
1222
  inputs_embeds=inputs_embeds,
1223
  output_attentions=output_attentions,