regisss HF staff commited on
Commit
d811de4
·
verified ·
1 Parent(s): eb48137

Define CustomLlamaConfig

Browse files
Files changed (1) hide show
  1. modeling_llama.py +5 -1
modeling_llama.py CHANGED
@@ -58,6 +58,10 @@ logger = logging.get_logger(__name__)
58
  _CONFIG_FOR_DOC = "LlamaConfig"
59
 
60
 
 
 
 
 
61
  def _get_unpad_data(attention_mask):
62
  seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
63
  indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
@@ -929,7 +933,7 @@ LLAMA_START_DOCSTRING = r"""
929
  LLAMA_START_DOCSTRING,
930
  )
931
  class LlamaPreTrainedModel(PreTrainedModel):
932
- config_class = LlamaConfig
933
  base_model_prefix = "model"
934
  supports_gradient_checkpointing = True
935
  _no_split_modules = ["LlamaDecoderLayer"]
 
58
  _CONFIG_FOR_DOC = "LlamaConfig"
59
 
60
 
61
+ CustomLlamaConfig(LlamaConfig):
62
+ model_type = "custom_llama"
63
+
64
+
65
  def _get_unpad_data(attention_mask):
66
  seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
67
  indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
 
933
  LLAMA_START_DOCSTRING,
934
  )
935
  class LlamaPreTrainedModel(PreTrainedModel):
936
+ config_class = CustomLlamaConfig
937
  base_model_prefix = "model"
938
  supports_gradient_checkpointing = True
939
  _no_split_modules = ["LlamaDecoderLayer"]