Duration to generate a caption for 1 image

#9
by apapiu - opened

Thanks for the awesome model! I am running the 4-bit model on a T4 in Colab and one generation takes about 7 seconds. Even if I batch the images with outputs = pipe(imgs, prompt=prompt, generate_kwargs={"max_new_tokens": 200}), I am not getting much of a speedup (for 4 images it takes ~28 seconds). Are there any other ways to speed up text generation?
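A quick way to check whether batching helps is to compare one batched call against the per-image time. The sketch below uses two hypothetical helpers, time_call and batch_speedup (neither is part of transformers). With the numbers above (7 s per image, ~28 s for 4 images) the speedup is 1.0, meaning batching brought no benefit.

```python
import time
from typing import Any, Callable, Tuple


def time_call(fn: Callable[..., Any], *args: Any, **kwargs: Any) -> Tuple[Any, float]:
    """Run fn once and return (result, elapsed wall-clock seconds)."""
    start = time.perf_counter()
    result = fn(*args, **kwargs)
    return result, time.perf_counter() - start


def batch_speedup(seconds_per_image: float, n_images: int, batch_seconds: float) -> float:
    """How much faster one batched call is than n sequential single-image calls.

    1.0 means batching bought nothing; n_images means the whole batch
    took as long as a single image.
    """
    return (seconds_per_image * n_images) / batch_seconds


# Numbers from the post: ~7 s for one image, ~28 s for a "batch" of 4.
print(batch_speedup(7.0, 4, 28.0))  # 1.0 -> images effectively processed one by one
```

With the pipeline from this thread, time_call(pipe, imgs, prompt=prompt, generate_kwargs={"max_new_tokens": 200}) would return the outputs together with the elapsed seconds. Keep max_new_tokens fixed across runs, since longer captions take longer regardless of batching.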

Llava Hugging Face org

Hi @apapiu
Thanks!
Make sure to pass bnb_4bit_compute_dtype=torch.float16 to the quantization config:

import torch
from transformers import LlavaForConditionalGeneration, BitsAndBytesConfig

quantization_config = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_compute_dtype=torch.float16)
model = LlavaForConditionalGeneration.from_pretrained("llava-hf/llava-1.5-7b-hf", quantization_config=quantization_config)

If you have access to an A100 GPU, you can also use Flash-Attention-2. I need to double-check whether it works out of the box for Llava. You can also pass attn_implementation="sdpa" to from_pretrained, which should lead to faster generation; let me get back to you on this.
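Which attn_implementation makes sense depends on the GPU. A minimal sketch with a hypothetical helper, pick_attn_implementation: Flash-Attention-2 requires an Ampere-or-newer GPU (compute capability 8.0 or higher, such as the A100) plus the separate flash-attn package, so older cards like the T4 (compute capability 7.5) fall back to "sdpa". Whether Llava accepts "sdpa" depends on your transformers version.

```python
from typing import Tuple


def pick_attn_implementation(compute_capability: Tuple[int, int], flash_attn_installed: bool) -> str:
    """Hypothetical helper: choose an attn_implementation value for from_pretrained.

    Flash-Attention-2 needs compute capability >= 8.0 (Ampere or newer) and the
    flash-attn package; anything else falls back to PyTorch SDPA.
    """
    major, _minor = compute_capability
    if major >= 8 and flash_attn_installed:
        return "flash_attention_2"
    return "sdpa"


print(pick_attn_implementation((7, 5), flash_attn_installed=True))  # T4 -> sdpa
print(pick_attn_implementation((8, 0), flash_attn_installed=True))  # A100 -> flash_attention_2
```

On a real machine you could feed it torch.cuda.get_device_capability() and importlib.util.find_spec("flash_attn") is not None, then pass the result as attn_implementation=... to from_pretrained.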

Llava Hugging Face org

Hi @apapiu
I just opened https://github.com/huggingface/transformers/pull/28107, which should add SDPA support to Llava and make generation faster. Once it is merged, you will be able to enable it with:

import torch
from transformers import pipeline, BitsAndBytesConfig
from PIL import Image    
import requests

model_id = "llava-hf/llava-1.5-7b-hf"
quantization_config = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_compute_dtype=torch.float16)

# quantization_config already enables 4-bit loading, so load_in_4bit is not passed separately
pipe = pipeline("image-to-text", model=model_id, model_kwargs={"quantization_config": quantization_config, "attn_implementation": "sdpa"})
url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/tasks/ai2d-demo.jpg"

image = Image.open(requests.get(url, stream=True).raw)
prompt = "USER: <image>\nWhat does the label 15 represent? (1) lava (2) core (3) tunnel (4) ash cloud\nASSISTANT:"

outputs = pipe(image, prompt=prompt, generate_kwargs={"max_new_tokens": 200})
print(outputs)
Llava Hugging Face org
edited Dec 18, 2023

Once that PR gets merged, you can force SDPA to dispatch to the flash-attention kernel with the following script (this also works on a T4):

import torch
from transformers import pipeline, BitsAndBytesConfig
from PIL import Image    
import requests

model_id = "llava-hf/llava-1.5-7b-hf"
quantization_config = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_compute_dtype=torch.float16)

# quantization_config already enables 4-bit loading, so load_in_4bit is not passed separately
pipe = pipeline("image-to-text", model=model_id, model_kwargs={"quantization_config": quantization_config})
url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/tasks/ai2d-demo.jpg"

image = Image.open(requests.get(url, stream=True).raw)
prompt = "USER: <image>\nWhat does the label 15 represent? (1) lava (2) core (3) tunnel (4) ash cloud\nASSISTANT:"

# Restrict SDPA to the flash-attention kernel for this call
# (newer PyTorch releases deprecate sdp_kernel in favor of torch.nn.attention.sdpa_kernel)
with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=False, enable_mem_efficient=False):
    outputs = pipe(image, prompt=prompt, generate_kwargs={"max_new_tokens": 200})
print(outputs)
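For context, the sdp_kernel flags choose among PyTorch's SDPA backends (flash, memory-efficient, and the reference "math" path). All of them compute softmax(QK^T / sqrt(d)) V, so switching kernels changes speed and memory use, not the result (beyond floating-point differences). A plain-Python sketch of that computation, unmasked, single head and unbatched; it is illustrative only and not PyTorch's code:

```python
import math
from typing import List

Matrix = List[List[float]]


def softmax(row: List[float]) -> List[float]:
    # Subtract the max before exponentiating for numerical stability.
    m = max(row)
    exps = [math.exp(x - m) for x in row]
    total = sum(exps)
    return [e / total for e in exps]


def scaled_dot_product_attention(q: Matrix, k: Matrix, v: Matrix) -> Matrix:
    """Reference attention softmax(Q K^T / sqrt(d)) V, one query row at a time."""
    d = len(q[0])
    scale = 1.0 / math.sqrt(d)
    out: Matrix = []
    for q_row in q:
        # Score this query against every key.
        scores = [scale * sum(a * b for a, b in zip(q_row, k_row)) for k_row in k]
        weights = softmax(scores)
        # Weighted sum of the value rows.
        out.append([sum(w * v_row[j] for w, v_row in zip(weights, v)) for j in range(len(v[0]))])
    return out


# Two identical keys -> equal weights -> output is the average of the value rows.
print(scaled_dot_product_attention([[1.0, 0.0]], [[1.0, 0.0], [1.0, 0.0]], [[1.0, 2.0], [3.0, 4.0]]))
# [[2.0, 3.0]]
```

The fused flash kernel avoids materializing the full score matrix, which is where its speed and memory savings come from.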

@ybelkada thanks for the reply! Unfortunately, I got the same generation time with and without the torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=False, enable_mem_efficient=False) context (using transformers 4.37.0.dev0).

Another issue I found: when using the pipeline, pipe(imgs, prompt=prompt, generate_kwargs={"max_new_tokens": 200}), the images are processed sequentially even if I pass in a list of PIL images (judging by GPU memory usage). Using the model directly with model.generate solves the problem, since there I can pass in batches:


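import torch
from transformers import AutoProcessor, LlavaForConditionalGeneration

model_id = "llava-hf/llava-1.5-7b-hf"

# fp16 weights with Flash-Attention-2 (needs the flash-attn package and an Ampere+ GPU such as an A100)
model = LlavaForConditionalGeneration.from_pretrained(
    model_id,
    torch_dtype=torch.float16,
    low_cpu_mem_usage=True,
    attn_implementation="flash_attention_2"
).to('cuda')
processor = AutoProcessor.from_pretrained(model_id)

# get a batch of images... (prompt, to_pil, batch and bs come from the rest of my notebook, not shown)

imgs = [to_pil(batch[0][i]) for i in range(bs)]
# one prompt per image; tensors go to GPU 0 and float tensors are cast to fp16
inputs = processor(text=[prompt]*bs, images=imgs, return_tensors='pt').to(0, torch.float16)
output = model.generate(**inputs, max_new_tokens=200, do_sample=False)
captions = processor.batch_decode(output, skip_special_tokens=True)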

Note that I am not using 4-bit quantization here (I did have to move to an A100 for that, though). The above seems to work and is pretty fast (about 6 images per second).
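If the full list of images does not fit on the GPU at once, the generate call can be wrapped in a loop over fixed-size chunks. A minimal sketch with two hypothetical helpers, batched and caption_all; generate_batch is an assumed stand-in for the processor + model.generate + processor.batch_decode steps above, not a transformers API:

```python
from typing import Callable, Iterator, List, Sequence, TypeVar

T = TypeVar("T")


def batched(items: Sequence[T], batch_size: int) -> Iterator[List[T]]:
    """Yield consecutive chunks of at most batch_size items (the last may be shorter)."""
    if batch_size < 1:
        raise ValueError("batch_size must be >= 1")
    for start in range(0, len(items), batch_size):
        yield list(items[start:start + batch_size])


def caption_all(
    images: Sequence[T],
    prompt: str,
    generate_batch: Callable[[List[str], List[T]], List[str]],
    batch_size: int = 8,
) -> List[str]:
    """Caption every image, sending batch_size images per generate_batch call.

    generate_batch is hypothetical: it receives one prompt per image and
    returns one decoded caption per image.
    """
    captions: List[str] = []
    for chunk in batched(images, batch_size):
        # One prompt per image in this chunk, so a short final chunk still lines up.
        captions.extend(generate_batch([prompt] * len(chunk), chunk))
    return captions
```

Building the prompt list as [prompt] * len(chunk) rather than [prompt] * bs keeps the last, possibly shorter, chunk consistent with its number of images. In practice generate_batch would wrap the processor(text=..., images=...), model.generate(...) and processor.batch_decode(...) calls from the post above.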

My two remaining questions are:

  1. How can I make the pipe accept batched inputs?
  2. Is it expected for 4bit models to be significantly slower than fp16 models?
Llava Hugging Face org

Hi @apapiu
Thanks a lot for testing out!
1- I think the pipeline does indeed run the inputs sequentially; I am not sure how to perform batched generation with pipelines.
2- Yes, it is expected! Check out this extensive blog post benchmarking quantization schemes: https://huggingface.co/blog/overview-quantization-transformers or this relevant doc page: https://huggingface.co/docs/transformers/quantization
I am planning on adding AWQ support for Llava: https://github.com/casper-hansen/AutoAWQ/pull/250. Once that PR gets merged, we'll be able to use Llava + AWQ + fused modules, which should be faster than fp16. Read more about it here: https://huggingface.co/docs/transformers/quantization#make-use-of-fused-modules
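On point 2: roughly speaking, bitsandbytes 4-bit loading trades speed for memory, since weights are stored in 4 bits and converted back to the compute dtype (fp16 here) during the forward pass. A rough weight-only estimate shows the memory side of the trade-off. It assumes about 7 billion parameters for llava-1.5-7b-hf and ignores activations, the KV cache and quantization constants, so real usage is higher.

```python
def weight_memory_gb(n_params: int, bits_per_param: int) -> float:
    """Rough memory for the weights alone, in GB (10**9 bytes).

    Ignores activations, the KV cache and quantization constants.
    """
    return n_params * bits_per_param / 8 / 1e9


# Assumption: ~7 billion parameters for llava-1.5-7b-hf.
for label, bits in [("fp16", 16), ("4-bit", 4)]:
    print(label, weight_memory_gb(7_000_000_000, bits), "GB")
# fp16 14.0 GB
# 4-bit 3.5 GB
```

fp16 weights alone come to about 14 GB, close to a T4's 16 GB, which is consistent with needing a larger GPU for the fp16 run.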

Got it, no worries: the plain transformers API works just fine. These are great resources, thanks for sharing (and for writing some of them :)

Llava Hugging Face org

Awesome, thanks very much @apapiu !
