x54-729 commited on
Commit
6b55bce
·
1 Parent(s): dd09602

fix import error

Browse files
Files changed (1) hide show
  1. modeling_internlm2.py +18 -5
modeling_internlm2.py CHANGED
@@ -51,6 +51,19 @@ logger = logging.get_logger(__name__)
51
 
52
  _CONFIG_FOR_DOC = "InternLM2Config"
53
 
 
 
 
 
 
 
 
 
 
 
 
 
 
54
  # Copied from transformers.models.llama.modeling_llama._get_unpad_data
55
  def _get_unpad_data(attention_mask):
56
  seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
@@ -502,13 +515,11 @@ class InternLM2FlashAttention2(InternLM2Attention):
502
  softmax_scale (`float`, *optional*):
503
  The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim)
504
  """
505
- from flash_attn import flash_attn_func, flash_attn_varlen_func
506
- from flash_attn.bert_padding import pad_input
507
  # Contains at least one padding token in the sequence
508
  causal = self.is_causal and query_length != 1
509
  if attention_mask is not None:
510
  batch_size = query_states.shape[0]
511
- query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input(
512
  query_states, key_states, value_states, attention_mask, query_length
513
  )
514
 
@@ -536,8 +547,7 @@ class InternLM2FlashAttention2(InternLM2Attention):
536
 
537
  return attn_output
538
 
539
- def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length):
540
- from flash_attn.bert_padding import index_first_axis, unpad_input
541
  indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask)
542
  batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape
543
 
@@ -842,6 +852,9 @@ class InternLM2Model(InternLM2PreTrainedModel):
842
 
843
  return_dict = return_dict if return_dict is not None else self.config.use_return_dict
844
 
 
 
 
845
  # retrieve input_ids and inputs_embeds
846
  if input_ids is not None and inputs_embeds is not None:
847
  raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
 
51
 
52
  _CONFIG_FOR_DOC = "InternLM2Config"
53
 
54
+ flash_attn_func, flash_attn_varlen_func = None, None
55
+ pad_input, index_first_axis, unpad_input = None, None, None
56
+ def _import_flash_attn():
57
+ global flash_attn_func, flash_attn_varlen_func
58
+ global pad_input, index_first_axis, unpad_input
59
+ try:
60
+ from flash_attn import flash_attn_func as _flash_attn_func, flash_attn_varlen_func as _flash_attn_varlen_func
61
+ from flash_attn.bert_padding import pad_input as _pad_input, index_first_axis as _index_first_axis, unpad_input as _unpad_input
62
+ flash_attn_func, flash_attn_varlen_func = _flash_attn_func, _flash_attn_varlen_func
63
+ pad_input, index_first_axis, unpad_input = _pad_input, _index_first_axis, _unpad_input
64
+ except ImportError:
65
+ raise ImportError("flash_attn is not installed.")
66
+
67
  # Copied from transformers.models.llama.modeling_llama._get_unpad_data
68
  def _get_unpad_data(attention_mask):
69
  seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
 
515
  softmax_scale (`float`, *optional*):
516
  The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim)
517
  """
 
 
518
  # Contains at least one padding token in the sequence
519
  causal = self.is_causal and query_length != 1
520
  if attention_mask is not None:
521
  batch_size = query_states.shape[0]
522
+ query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._unpad_input(
523
  query_states, key_states, value_states, attention_mask, query_length
524
  )
525
 
 
547
 
548
  return attn_output
549
 
550
+ def _unpad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length):
 
551
  indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask)
552
  batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape
553
 
 
852
 
853
  return_dict = return_dict if return_dict is not None else self.config.use_return_dict
854
 
855
+ if self.config.attn_implementation == "flash_attention_2":
856
+ _import_flash_attn()
857
+
858
  # retrieve input_ids and inputs_embeds
859
  if input_ids is not None and inputs_embeds is not None:
860
  raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")