Code for quantization of FLUX1.dev via optimum quanto

#2
by timmerscher - opened

Heyho, first of all, great work!
I really like the quantized model and it has proved really useful for me!

For my current project, I would like to fine-tune the model for a specific task (via DreamBooth). I think the already-quantized Flux1.dev-qint8 will perform quite badly, considering that fine-tuning normally works on f32 or f16 weights and only slightly changes them. With qint8 the weights have already been rounded a lot, so I'm not sure fine-tuning is possible.
Long story short, do you still have the script that you used for quantization?

Then I could fine-tune Flux1.dev myself and quantize it afterwards with optimum quanto.
It would save me some time and I would be really thankful!

Cheers

import json
import torch
import diffusers
from optimum import quanto
from optimum.quanto import quantization_map
from safetensors.torch import save_file

dtype = torch.bfloat16
pipe = diffusers.AutoPipelineForText2Image.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=dtype)

# DiT: quantize the transformer weights to int8, then freeze to replace
# the float weights with their quantized counterparts.
quanto.quantize(pipe.transformer, weights=quanto.qint8)
quanto.freeze(pipe.transformer)

# Save the quantized weights plus the quantization map needed to reload them later.
save_file(pipe.transformer.state_dict(), '/mnt/DataSSD/AI/models/flux/transformer/diffusion_pytorch_model.safetensors')
with open('/mnt/DataSSD/AI/models/flux/transformer/quantization_map.json', "w") as f:
    json.dump(quantization_map(pipe.transformer), f)


# T5: quantize and freeze the second text encoder the same way.
quanto.quantize(pipe.text_encoder_2, weights=quanto.qint8)
quanto.freeze(pipe.text_encoder_2)

pipe.text_encoder_2.save_pretrained("/mnt/DataSSD/AI/models/flux/text_encoder_2")
with open('/mnt/DataSSD/AI/models/flux/text_encoder_2/quantization_map.json', "w") as f:
    json.dump(quantization_map(pipe.text_encoder_2), f)

Thanks mate! You're the MVP here.

timmerscher changed discussion status to closed
