x54-729
commited on
Commit
·
6b55bce
1
Parent(s):
dd09602
fix import error
Browse files- modeling_internlm2.py +18 -5
modeling_internlm2.py
CHANGED
@@ -51,6 +51,19 @@ logger = logging.get_logger(__name__)
|
|
51 |
|
52 |
_CONFIG_FOR_DOC = "InternLM2Config"
|
53 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
54 |
# Copied from transformers.models.llama.modeling_llama._get_unpad_data
|
55 |
def _get_unpad_data(attention_mask):
|
56 |
seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
|
@@ -502,13 +515,11 @@ class InternLM2FlashAttention2(InternLM2Attention):
|
|
502 |
softmax_scale (`float`, *optional*):
|
503 |
The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim)
|
504 |
"""
|
505 |
-
from flash_attn import flash_attn_func, flash_attn_varlen_func
|
506 |
-
from flash_attn.bert_padding import pad_input
|
507 |
# Contains at least one padding token in the sequence
|
508 |
causal = self.is_causal and query_length != 1
|
509 |
if attention_mask is not None:
|
510 |
batch_size = query_states.shape[0]
|
511 |
-
query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self.
|
512 |
query_states, key_states, value_states, attention_mask, query_length
|
513 |
)
|
514 |
|
@@ -536,8 +547,7 @@ class InternLM2FlashAttention2(InternLM2Attention):
|
|
536 |
|
537 |
return attn_output
|
538 |
|
539 |
-
def
|
540 |
-
from flash_attn.bert_padding import index_first_axis, unpad_input
|
541 |
indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask)
|
542 |
batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape
|
543 |
|
@@ -842,6 +852,9 @@ class InternLM2Model(InternLM2PreTrainedModel):
|
|
842 |
|
843 |
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
844 |
|
|
|
|
|
|
|
845 |
# retrieve input_ids and inputs_embeds
|
846 |
if input_ids is not None and inputs_embeds is not None:
|
847 |
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
|
|
|
51 |
|
52 |
_CONFIG_FOR_DOC = "InternLM2Config"
|
53 |
|
54 |
+
flash_attn_func, flash_attn_varlen_func = None, None
|
55 |
+
pad_input, index_first_axis, unpad_input = None, None, None
|
56 |
+
def _import_flash_attn():
|
57 |
+
global flash_attn_func, flash_attn_varlen_func
|
58 |
+
global pad_input, index_first_axis, unpad_input
|
59 |
+
try:
|
60 |
+
from flash_attn import flash_attn_func as _flash_attn_func, flash_attn_varlen_func as _flash_attn_varlen_func
|
61 |
+
from flash_attn.bert_padding import pad_input as _pad_input, index_first_axis as _index_first_axis, unpad_input as _unpad_input
|
62 |
+
flash_attn_func, flash_attn_varlen_func = _flash_attn_func, _flash_attn_varlen_func
|
63 |
+
pad_input, index_first_axis, unpad_input = _pad_input, _index_first_axis, _unpad_input
|
64 |
+
except ImportError:
|
65 |
+
raise ImportError("flash_attn is not installed.")
|
66 |
+
|
67 |
# Copied from transformers.models.llama.modeling_llama._get_unpad_data
|
68 |
def _get_unpad_data(attention_mask):
|
69 |
seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
|
|
|
515 |
softmax_scale (`float`, *optional*):
|
516 |
The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim)
|
517 |
"""
|
|
|
|
|
518 |
# Contains at least one padding token in the sequence
|
519 |
causal = self.is_causal and query_length != 1
|
520 |
if attention_mask is not None:
|
521 |
batch_size = query_states.shape[0]
|
522 |
+
query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._unpad_input(
|
523 |
query_states, key_states, value_states, attention_mask, query_length
|
524 |
)
|
525 |
|
|
|
547 |
|
548 |
return attn_output
|
549 |
|
550 |
+
def _unpad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length):
|
|
|
551 |
indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask)
|
552 |
batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape
|
553 |
|
|
|
852 |
|
853 |
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
854 |
|
855 |
+
if self.config.attn_implementation == "flash_attention_2":
|
856 |
+
_import_flash_attn()
|
857 |
+
|
858 |
# retrieve input_ids and inputs_embeds
|
859 |
if input_ids is not None and inputs_embeds is not None:
|
860 |
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
|