RuntimeError: FlashAttention only support fp16 and bf16 data type during fine-tuning
The hyperparameters I am using:
training_config = {
"bf16": True,
"do_eval": False,
"learning_rate": 0.00001,
"lr_scheduler_type": "cosine",
"log_level": "info",
"logging_steps": 30,
"logging_strategy": "steps",
"num_train_epochs": 5,
"max_steps": -1,
"output_dir": "./workspace/checkpoint_dir",
"overwrite_output_dir": True,
"per_device_eval_batch_size": 4,
"remove_unused_columns": True,
"save_steps": 100,
"save_total_limit": 1,
"seed": 0,
"gradient_checkpointing": True,
"gradient_checkpointing_kwargs":{"use_reentrant": False},
"gradient_accumulation_steps": 1,
"warmup_ratio": 0.2,
}
I am loading the model with:
from transformers import AutoModelForCausalLM
checkpoint_path = "microsoft/Phi-3-small-8k-instruct"
model_kwargs = dict(
use_cache=False,
trust_remote_code=True,
attn_implementation="flash_attention_2",
torch_dtype="auto",
device_map=None,
)
model = AutoModelForCausalLM.from_pretrained(checkpoint_path, **model_kwargs)
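And I start training with:
from transformers import TrainingArguments
from trl import SFTTrainer
train_conf = TrainingArguments(**training_config)  # assumed: the post does not show how train_conf is built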
trainer = SFTTrainer(
model=model,
args=train_conf,
train_dataset=processed_dataset,
max_seq_length=8192,
dataset_text_field="text",
tokenizer=tokenizer,
packing=True
)
train_result = trainer.train()
I am getting the following error:
    # if out.isnan().any() or softmax_lse.isnan().any():
    #     breakpoint()
    return out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state
RuntimeError: FlashAttention only support fp16 and bf16 data type
I used the exact same config to fine-tune Phi-3-mini-128k without any issues. Is anyone else facing the same issue?
Hi!
FlashAttention, as well as the block-sparse attention kernel, requires the model to be trained in fp16 or bf16. Is there a reason bfloat16 might not work for your use case?
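A quick way to rule out the weights is to list every floating-point parameter that is not fp16/bf16 after loading. The sketch below uses a hypothetical helper, find_non_half_params, which is not part of transformers or flash-attn; it is shown on a toy module, but it should behave the same way on the model returned by from_pretrained. Note that it only inspects the weights: intermediate tensors such as q and k can still end up in fp32 even when every weight is bf16.

```python
import torch
from torch import nn

HALF_DTYPES = {torch.float16, torch.bfloat16}


def find_non_half_params(model: nn.Module) -> list[tuple[str, torch.dtype]]:
    """Return (name, dtype) for every floating-point parameter that is not fp16/bf16.

    Hypothetical debugging helper; not part of transformers or flash-attn.
    """
    return [
        (name, p.dtype)
        for name, p in model.named_parameters()
        if p.is_floating_point() and p.dtype not in HALF_DTYPES
    ]


# Toy stand-in for the loaded Phi-3 model.
toy = nn.Sequential(nn.Linear(8, 8), nn.LayerNorm(8))
print(find_non_half_params(toy))  # all four parameters are still fp32

toy = toy.to(torch.bfloat16)
print(find_non_half_params(toy))  # [] once everything is bf16
```

On the real model you could call find_non_half_params(model) right after from_pretrained; an empty list means the weights themselves are already in half precision.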
Hi,
I'm getting the same error, even with bf16=True in the training arguments.
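A likely reason bf16=True in the training arguments is not enough: as far as I understand, that flag turns on mixed precision through torch.autocast, and autocast only lowers precision for a fixed list of ops (matmul, linear and similar). Elementwise ops such as mul follow ordinary type promotion, so a bf16 activation multiplied by an fp32 buffer comes out as fp32. The sketch below shows this with CPU autocast; the CUDA autocast op lists treat mul the same way. The fp32 cos tensor stands in for a cached rotary table and is an assumption, not taken from the Phi-3-small code.

```python
import torch

q = torch.randn(2, 4, dtype=torch.bfloat16)   # bf16 activation
cos = torch.randn(2, 4, dtype=torch.float32)  # fp32 buffer (assumed, e.g. a rotary cache)
w = torch.randn(4, 4)                         # fp32 weight

with torch.autocast(device_type="cpu", dtype=torch.bfloat16):
    projected = q @ w  # matmul is on the autocast list, so it runs in bf16
    rotated = q * cos  # mul is not: bf16 * fp32 promotes to fp32

print(projected.dtype)  # torch.bfloat16
print(rotated.dtype)    # torch.float32
```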
Not sure if it's the correct fix, but here is how I made it work. In
https://huggingface.co/microsoft/Phi-3-small-8k-instruct/blob/f5527db8a43fc9a4bf17c5b754251e1efe1d4ad3/positional_embedding.py#L269
cast the rotated q and k back to q's dtype after applying the rotary embedding:
# Cast the rotated q and k back to the input dtype so FlashAttention receives fp16/bf16 tensors
return (
apply_rotary_pos_emb(
q, cos_cached[seqlen_offset:seq_len], sin_cached[seqlen_offset:seq_len], seq_dimension=seq_dimension
).to(q.dtype),
apply_rotary_pos_emb(
k, cos_cached[seqlen_offset:seq_len], sin_cached[seqlen_offset:seq_len], seq_dimension=seq_dimension
).to(q.dtype),
)
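To see why the extra .to(q.dtype) helps, here is a minimal, hypothetical rotary embedding using the common "rotate-half" formulation. It is not the apply_rotary_pos_emb from positional_embedding.py; it only assumes, as the fix suggests, that the cos/sin caches are kept in fp32. Multiplying a bf16 q by those fp32 tables promotes the result to fp32, which is the dtype FlashAttention rejects, and casting back restores bf16.

```python
import torch


def rotate_half(x: torch.Tensor) -> torch.Tensor:
    # Split the last dim into two halves and rotate: (x1, x2) -> (-x2, x1).
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)


def apply_rotary_sketch(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
    # Generic RoPE; hypothetical stand-in, not Phi-3-small's implementation.
    return x * cos + rotate_half(x) * sin


seq_len, head_dim = 16, 8

# Assumption: cos/sin caches are built and stored in fp32.
inv_freq = 1.0 / (10000 ** (torch.arange(0, head_dim, 2).float() / head_dim))
angles = torch.outer(torch.arange(seq_len).float(), inv_freq)  # (seq_len, head_dim // 2)
emb = torch.cat((angles, angles), dim=-1)                      # (seq_len, head_dim)
cos_cached, sin_cached = emb.cos(), emb.sin()

q = torch.randn(1, 2, seq_len, head_dim, dtype=torch.bfloat16)  # (batch, heads, seq, dim)

q_rot = apply_rotary_sketch(q, cos_cached, sin_cached)
print(q_rot.dtype)         # torch.float32: bf16 * fp32 promotes
q_rot = q_rot.to(q.dtype)  # the workaround: cast back before calling FlashAttention
print(q_rot.dtype)         # torch.bfloat16
```

Casting back after the rotation, rather than storing the caches in bf16, keeps the sin/cos values computed at fp32 precision. Because q and k share a dtype, casting k with q.dtype (as in the fix above) is equivalent to using k.dtype.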
Thanks for the answer.
In my case it happens when device_map="auto" (or anything other than None). It might be a problem related to FlashAttention and multi-GPU training. If you have a fix, please share it.
I will be doing another round of training over the weekend and will try out @ecocytus11's solution. Thanks!
I'm facing the same issue with both the 8k and 128k small models.
@santyzenith
The solution suggested above seems to work. We will update the code.
https://huggingface.co/microsoft/Phi-3-small-8k-instruct/discussions/11#6659a17e8b11da966e8e510c
I'm still running into the same error with microsoft/Phi-3.5-MoE-instruct. Do you know if that model is subject to the same type-casting issue?