import torch
import torch.nn as nn
import transformers

from gptq import *
from modelutils import *
from quant import *

from transformers import BloomForCausalLM as LM


class SakuraForCausalLM(LM):
    def __init__(self, *args, **kwargs):
        # Disable weight initialization: the real weights are loaded from a
        # checkpoint afterwards, so random init is wasted work.
        def noop(*args, **kwargs):
            pass

        torch.nn.init.kaiming_uniform_ = noop
        torch.nn.init.uniform_ = noop
        torch.nn.init.normal_ = noop

        # Construct the model in half precision, then restore the default
        # dtype once the module tree exists.
        torch.set_default_dtype(torch.half)
        transformers.modeling_utils._init_weights = False
        super().__init__(*args, **kwargs)
        torch.set_default_dtype(torch.float)

        self.eval()

        # Swap every linear layer except the output head for its quantized
        # counterpart (8-bit weights, group size 128).
        layers = find_layers(self)
        for name in ['lm_head']:
            if name in layers:
                del layers[name]
        make_quant(self, layers, 8, 128)
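

# Minimal usage sketch (not from the original file). Assumptions: the model
# name and checkpoint path below are placeholders, and the checkpoint was
# produced by a GPTQ-style quantizer with matching wbits=8 / groupsize=128.
if __name__ == "__main__":
    from transformers import AutoConfig

    config = AutoConfig.from_pretrained("bigscience/bloom-560m")
    model = SakuraForCausalLM(config)
    model.load_state_dict(torch.load("sakura-8bit-128g.pt"))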