Fill-Mask
Transformers
PyTorch
Safetensors
English
nomic_bert
custom_code
zpn commited on
Commit
a5980af
·
verified ·
1 Parent(s): f173057

Update modeling_hf_nomic_bert.py

Browse files
Files changed (1) hide show
  1. modeling_hf_nomic_bert.py +5 -11
modeling_hf_nomic_bert.py CHANGED
@@ -315,9 +315,8 @@ class NomicBertPreTrainedModel(PreTrainedModel):
315
  ignore_mismatched_shapes = kwargs.pop("ignore_mismatched_sizes", False)
316
  num_labels = kwargs.pop("num_labels", None)
317
  rotary_scaling_factor = kwargs.pop("rotary_scaling_factor", None)
318
- if rotary_scaling_factor:
319
- config.rotary_scaling_factor = rotary_scaling_factor
320
-
321
  if config.n_positions <= 0 and config.rotary_emb_fraction > 0:
322
  config.n_positions = 2048
323
  if num_labels:
@@ -326,10 +325,7 @@ class NomicBertPreTrainedModel(PreTrainedModel):
326
  if "add_pooling_layer" in kwargs:
327
  model = cls(config, *inputs, add_pooling_layer=kwargs.pop("add_pooling_layer"))
328
  else:
329
- if cls == NomicBertModel:
330
- model = cls(config, *inputs, add_pooling_layer=False)
331
- else:
332
- model = cls(config, *inputs)
333
  # TODO: fix this
334
  # Assuming we know what we're doing when loading from disk
335
  # Prob a bad assumption but i'm tired and want to train this asap
@@ -348,9 +344,7 @@ class NomicBertPreTrainedModel(PreTrainedModel):
348
  load_return = model.load_state_dict(state_dict, strict=False)
349
  else:
350
  # TODO: can probably check config class and see if we need to remap from a bert model
351
- state_dict = state_dict_from_pretrained(
352
- model_name, safe_serialization=kwargs.get("safe_serialization", False)
353
- )
354
  state_dict = remap_bert_state_dict(
355
  state_dict,
356
  config,
@@ -361,7 +355,7 @@ class NomicBertPreTrainedModel(PreTrainedModel):
361
  if ignore_mismatched_shapes:
362
  state_dict = filter_shapes(state_dict, model)
363
 
364
- load_return = model.load_state_dict(state_dict, strict=True)
365
  logger.warning(load_return)
366
  return model
367
 
 
315
  ignore_mismatched_shapes = kwargs.pop("ignore_mismatched_sizes", False)
316
  num_labels = kwargs.pop("num_labels", None)
317
  rotary_scaling_factor = kwargs.pop("rotary_scaling_factor", None)
318
+ strict = kwargs.pop("strict", True)
319
+ config.rotary_scaling_factor = rotary_scaling_factor
 
320
  if config.n_positions <= 0 and config.rotary_emb_fraction > 0:
321
  config.n_positions = 2048
322
  if num_labels:
 
325
  if "add_pooling_layer" in kwargs:
326
  model = cls(config, *inputs, add_pooling_layer=kwargs.pop("add_pooling_layer"))
327
  else:
328
+ model = cls(config, *inputs)
 
 
 
329
  # TODO: fix this
330
  # Assuming we know what we're doing when loading from disk
331
  # Prob a bad assumption but i'm tired and want to train this asap
 
344
  load_return = model.load_state_dict(state_dict, strict=False)
345
  else:
346
  # TODO: can probably check config class and see if we need to remap from a bert model
347
+ state_dict = state_dict_from_pretrained(model_name)
 
 
348
  state_dict = remap_bert_state_dict(
349
  state_dict,
350
  config,
 
355
  if ignore_mismatched_shapes:
356
  state_dict = filter_shapes(state_dict, model)
357
 
358
+ load_return = model.load_state_dict(state_dict, strict=strict)
359
  logger.warning(load_return)
360
  return model
361