ValueError: Cross attention layer can't find neither `cross_attn_states` nor cached values for key/values!

#29 opened by jhn9803

I downloaded the Hugging Face model meta-llama/Llama-3.2-90B-Vision-Instruct and ran the example inference code, but I ran into an issue.

output = model.generate(**inputs, max_new_tokens=30)

When running the code above, the following error occurred:

ValueError: Cross attention layer can't find neither 'cross_attn_states' nor cached values for key/values!

The inputs contained the following keys:

['input_ids', 'attention_mask', 'pixel_values', 'aspect_ratio_ids', 'aspect_ratio_mask', 'cross_attention_mask']
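
For context, these inputs came from the example inference code; below is a minimal sketch of that kind of setup, assuming a model-card style example (the image URL and prompt are placeholders, and the exact snippet may differ slightly from what I ran):

```python
# Minimal sketch of the setup that produces the inputs listed above
# (image URL and prompt are placeholders; loosely adapted from the model-card example).
import requests
import torch
from PIL import Image
from transformers import AutoProcessor, MllamaForConditionalGeneration

model_id = "meta-llama/Llama-3.2-90B-Vision-Instruct"

model = MllamaForConditionalGeneration.from_pretrained(
    model_id,
    torch_dtype=torch.bfloat16,
    device_map="auto",
)
processor = AutoProcessor.from_pretrained(model_id)

image = Image.open(requests.get("https://example.com/image.jpg", stream=True).raw)

messages = [
    {"role": "user", "content": [
        {"type": "image"},
        {"type": "text", "text": "Describe this image."},
    ]}
]
prompt = processor.apply_chat_template(messages, add_generation_prompt=True)

# This should return the keys listed above: input_ids, attention_mask, pixel_values,
# aspect_ratio_ids, aspect_ratio_mask, cross_attention_mask.
inputs = processor(images=image, text=prompt, return_tensors="pt").to(model.device)

output = model.generate(**inputs, max_new_tokens=30)
print(processor.decode(output[0]))
```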

Looking at the source code of the Transformers library, it seems that model.generate should be able to proceed without cross_attn_states in the inputs: the model appears to compute cross_attention_states internally (from pixel_values) as generation runs. So hitting this error is quite confusing.
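
For reference, here is a simplified, self-contained paraphrase of the check that raises, based on the MllamaTextCrossSdpaAttention frame at the bottom of the traceback below (a sketch for illustration, not the actual library implementation): the cross-attention layer expects either fresh cross_attention_states, or key/value states already sitting in the cache from an earlier pass, and raises when neither is available.

```python
# Self-contained paraphrase of the failing branch (simplified; the real code
# works on tensors and a cache object, not strings).
def resolve_cross_attention_kv(cross_attention_states, past_key_value, layer_idx, cache_position):
    if cross_attention_states is not None:
        # Pass that carries vision features: pixel_values were encoded into
        # cross_attention_states upstream, so fresh key/value projections are
        # computed here and written into the cache for later steps.
        return "fresh key/values projected from cross_attention_states"
    elif cache_position[0] != 0:
        # Later decoding step: pixel_values are no longer passed, so the layer
        # reuses the key/value states cached on the earlier pass, e.g.
        # past_key_value.key_cache[layer_idx] / past_key_value.value_cache[layer_idx].
        return "cached key/values from past_key_value"
    else:
        # No vision features and cache_position still at 0, so nothing has been
        # cached either -- this is the error from the traceback.
        raise ValueError(
            "Cross attention layer can't find neither `cross_attn_states` "
            "nor cached values for key/values!"
        )

resolve_cross_attention_kv("vision features", None, 0, [0])  # OK: fresh states available
resolve_cross_attention_kv(None, None, 0, [0])               # raises the ValueError above
```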

I would be really grateful if anyone could provide a solution to this issue.

```
---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
Cell In[12], line 1
----> 1 output = model.generate(**inputs, max_new_tokens=30, do_sample=False)

File ~/miniconda3/envs/llm/lib/python3.10/site-packages/torch/utils/_contextlib.py:115, in context_decorator.<locals>.decorate_context(*args, **kwargs)
    112 @functools.wraps(func)
    113 def decorate_context(*args, **kwargs):
    114     with ctx_factory():
--> 115         return func(*args, **kwargs)

File ~/miniconda3/envs/llm/lib/python3.10/site-packages/transformers/generation/utils.py:2252, in GenerationMixin.generate(self, inputs, generation_config, logits_processor, stopping_criteria, prefix_allowed_tokens_fn, synced_gpus, assistant_model, streamer, negative_prompt_ids, negative_prompt_attention_mask, **kwargs)
   2244     input_ids, model_kwargs = self._expand_inputs_for_generation(
   2245         input_ids=input_ids,
   2246         expand_size=generation_config.num_return_sequences,
   2247         is_encoder_decoder=self.config.is_encoder_decoder,
   2248         **model_kwargs,
   2249     )
   2251     # 12. run sample (it degenerates to greedy search when `generation_config.do_sample=False`)
-> 2252     result = self._sample(
   2253         input_ids,
   2254         logits_processor=prepared_logits_processor,
   2255         stopping_criteria=prepared_stopping_criteria,
   2256         generation_config=generation_config,
   2257         synced_gpus=synced_gpus,
   2258         streamer=streamer,
   2259         **model_kwargs,
   2260     )
   2262 elif generation_mode in (GenerationMode.BEAM_SAMPLE, GenerationMode.BEAM_SEARCH):
   2263     # 11. prepare beam search scorer
   2264     beam_scorer = BeamSearchScorer(
   2265         batch_size=batch_size,
   2266         num_beams=generation_config.num_beams,
   (...)
   2271         max_length=generation_config.max_length,
   2272     )

File ~/miniconda3/envs/llm/lib/python3.10/site-packages/transformers/generation/utils.py:3254, in GenerationMixin._sample(self, input_ids, logits_processor, stopping_criteria, generation_config, synced_gpus, streamer, **model_kwargs)
   3252     is_prefill = False
   3253 else:
-> 3254     outputs = model_forward(**model_inputs, return_dict=True)
   3256 # synced_gpus: don't waste resources running the code we don't need; kwargs must be updated before skipping
   3257 model_kwargs = self._update_model_kwargs_for_generation(
   3258     outputs,
   3259     model_kwargs,
   3260     is_encoder_decoder=self.config.is_encoder_decoder,
   3261 )

File ~/miniconda3/envs/llm/lib/python3.10/site-packages/torch/nn/modules/module.py:1511, in Module._wrapped_call_impl(self, *args, **kwargs)
   1509     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1510 else:
-> 1511     return self._call_impl(*args, **kwargs)

File ~/miniconda3/envs/llm/lib/python3.10/site-packages/torch/nn/modules/module.py:1520, in Module._call_impl(self, *args, **kwargs)
   1515 # If we don't have any hooks, we want to skip the rest of the logic in
   1516 # this function, and just call forward.
   1517 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1518         or _global_backward_pre_hooks or _global_backward_hooks
   1519         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1520     return forward_call(*args, **kwargs)
   1522 try:
   1523     result = None

File ~/miniconda3/envs/llm/lib/python3.10/site-packages/accelerate/hooks.py:170, in add_hook_to_module.<locals>.new_forward(module, *args, **kwargs)
    168         output = module._old_forward(*args, **kwargs)
    169 else:
--> 170     output = module._old_forward(*args, **kwargs)
    171 return module._hf_hook.post_forward(module, output)

File ~/miniconda3/envs/llm/lib/python3.10/site-packages/transformers/models/mllama/modeling_mllama.py:2127, in MllamaForConditionalGeneration.forward(self, input_ids, pixel_values, aspect_ratio_mask, aspect_ratio_ids, attention_mask, cross_attention_mask, cross_attention_states, position_ids, past_key_values, inputs_embeds, labels, use_cache, output_attentions, output_hidden_states, return_dict, cache_position, num_logits_to_keep)
   2124     cross_attention_mask = cross_attention_mask[:, :, cache_position]
   2125     full_text_row_masked_out_mask = full_text_row_masked_out_mask[:, :, cache_position]
-> 2127 outputs = self.language_model(
   2128     input_ids=input_ids,
   2129     attention_mask=attention_mask,
   2130     position_ids=position_ids,
   2131     cross_attention_states=cross_attention_states,
   2132     cross_attention_mask=cross_attention_mask,
   2133     full_text_row_masked_out_mask=full_text_row_masked_out_mask,
   2134     past_key_values=past_key_values,
   2135     use_cache=use_cache,
   2136     inputs_embeds=inputs_embeds,
   2137     labels=labels,
   2138     output_hidden_states=output_hidden_states,
   2139     output_attentions=output_attentions,
   2140     return_dict=return_dict,
   2141     cache_position=cache_position,
   2142     num_logits_to_keep=num_logits_to_keep,
   2143 )
   2145 return outputs

File ~/miniconda3/envs/llm/lib/python3.10/site-packages/torch/nn/modules/module.py:1511, in Module._wrapped_call_impl(self, *args, **kwargs)
   1509     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1510 else:
-> 1511     return self._call_impl(*args, **kwargs)

File ~/miniconda3/envs/llm/lib/python3.10/site-packages/torch/nn/modules/module.py:1520, in Module._call_impl(self, *args, **kwargs)
   1515 # If we don't have any hooks, we want to skip the rest of the logic in
   1516 # this function, and just call forward.
   1517 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1518         or _global_backward_pre_hooks or _global_backward_hooks
   1519         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1520     return forward_call(*args, **kwargs)
   1522 try:
   1523     result = None

File ~/miniconda3/envs/llm/lib/python3.10/site-packages/transformers/models/mllama/modeling_mllama.py:1935, in MllamaForCausalLM.forward(self, input_ids, attention_mask, position_ids, cross_attention_states, cross_attention_mask, full_text_row_masked_out_mask, past_key_values, inputs_embeds, labels, use_cache, output_attentions, output_hidden_states, return_dict, cache_position, num_logits_to_keep, **loss_kwargs)
   1932 return_dict = return_dict if return_dict is not None else self.config.use_return_dict
   1934 # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
-> 1935 outputs = self.model(
   1936     input_ids=input_ids,
   1937     cross_attention_states=cross_attention_states,
   1938     attention_mask=attention_mask,
   1939     position_ids=position_ids,
   1940     cross_attention_mask=cross_attention_mask,
   1941     full_text_row_masked_out_mask=full_text_row_masked_out_mask,
   1942     past_key_values=past_key_values,
   1943     inputs_embeds=inputs_embeds,
   1944     use_cache=use_cache,
   1945     output_attentions=output_attentions,
   1946     output_hidden_states=output_hidden_states,
   1947     return_dict=return_dict,
   1948     cache_position=cache_position,
   1949 )
   1951 hidden_states = outputs[0]
   1952 logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :]).float()

File ~/miniconda3/envs/llm/lib/python3.10/site-packages/torch/nn/modules/module.py:1511, in Module._wrapped_call_impl(self, *args, **kwargs)
   1509     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1510 else:
-> 1511     return self._call_impl(*args, **kwargs)

File ~/miniconda3/envs/llm/lib/python3.10/site-packages/torch/nn/modules/module.py:1520, in Module._call_impl(self, *args, **kwargs)
   1515 # If we don't have any hooks, we want to skip the rest of the logic in
   1516 # this function, and just call forward.
   1517 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1518         or _global_backward_pre_hooks or _global_backward_hooks
   1519         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1520     return forward_call(*args, **kwargs)
   1522 try:
   1523     result = None

File ~/miniconda3/envs/llm/lib/python3.10/site-packages/transformers/models/mllama/modeling_mllama.py:1797, in MllamaTextModel.forward(self, input_ids, attention_mask, position_ids, cross_attention_states, cross_attention_mask, full_text_row_masked_out_mask, past_key_values, inputs_embeds, use_cache, output_attentions, output_hidden_states, return_dict, cache_position)
   1782     layer_outputs = self._gradient_checkpointing_func(
   1783         decoder_layer.__call__,
   1784         hidden_states,
   (...)
   1794         position_embeddings,
   1795     )
   1796 else:
-> 1797     layer_outputs = decoder_layer(
   1798         hidden_states,
   1799         cross_attention_states=cross_attention_states,
   1800         cross_attention_mask=cross_attention_mask,
   1801         attention_mask=causal_mask,
   1802         full_text_row_masked_out_mask=full_text_row_masked_out_mask,
   1803         position_ids=position_ids,
   1804         past_key_value=past_key_values,
   1805         output_attentions=output_attentions,
   1806         use_cache=use_cache,
   1807         cache_position=cache_position,
   1808         position_embeddings=position_embeddings,
   1809     )
   1811 hidden_states = layer_outputs[0]
   1813 if use_cache:

File ~/miniconda3/envs/llm/lib/python3.10/site-packages/torch/nn/modules/module.py:1511, in Module._wrapped_call_impl(self, *args, **kwargs)
   1509     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1510 else:
-> 1511     return self._call_impl(*args, **kwargs)

File ~/miniconda3/envs/llm/lib/python3.10/site-packages/torch/nn/modules/module.py:1520, in Module._call_impl(self, *args, **kwargs)
   1515 # If we don't have any hooks, we want to skip the rest of the logic in
   1516 # this function, and just call forward.
   1517 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1518         or _global_backward_pre_hooks or _global_backward_hooks
   1519         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1520     return forward_call(*args, **kwargs)
   1522 try:
   1523     result = None

File ~/miniconda3/envs/llm/lib/python3.10/site-packages/accelerate/hooks.py:170, in add_hook_to_module.<locals>.new_forward(module, *args, **kwargs)
    168         output = module._old_forward(*args, **kwargs)
    169 else:
--> 170     output = module._old_forward(*args, **kwargs)
    171 return module._hf_hook.post_forward(module, output)

File ~/miniconda3/envs/llm/lib/python3.10/site-packages/transformers/models/mllama/modeling_mllama.py:951, in MllamaCrossAttentionDecoderLayer.forward(self, hidden_states, cross_attention_states, cross_attention_mask, attention_mask, full_text_row_masked_out_mask, position_ids, past_key_value, output_attentions, use_cache, cache_position, position_embeddings)
    948 residual = hidden_states
    949 hidden_states = self.input_layernorm(hidden_states)
--> 951 hidden_states, attn_weights, past_key_value = self.cross_attn(
    952     hidden_states=hidden_states,
    953     attention_mask=cross_attention_mask,
    954     cross_attention_states=cross_attention_states,
    955     past_key_value=past_key_value,
    956     output_attentions=output_attentions,
    957     cache_position=cache_position,
    958 )
    959 hidden_states = residual + self.cross_attn_attn_gate.tanh() * hidden_states
    961 residual = hidden_states

File ~/miniconda3/envs/llm/lib/python3.10/site-packages/torch/nn/modules/module.py:1511, in Module._wrapped_call_impl(self, *args, **kwargs)
   1509     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1510 else:
-> 1511     return self._call_impl(*args, **kwargs)

File ~/miniconda3/envs/llm/lib/python3.10/site-packages/torch/nn/modules/module.py:1520, in Module._call_impl(self, *args, **kwargs)
   1515 # If we don't have any hooks, we want to skip the rest of the logic in
   1516 # this function, and just call forward.
   1517 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1518         or _global_backward_pre_hooks or _global_backward_hooks
   1519         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1520     return forward_call(*args, **kwargs)
   1522 try:
   1523     result = None

File ~/miniconda3/envs/llm/lib/python3.10/site-packages/accelerate/hooks.py:170, in add_hook_to_module.<locals>.new_forward(module, *args, **kwargs)
    168         output = module._old_forward(*args, **kwargs)
    169 else:
--> 170     output = module._old_forward(*args, **kwargs)
    171 return module._hf_hook.post_forward(module, output)

File ~/miniconda3/envs/llm/lib/python3.10/site-packages/transformers/models/mllama/modeling_mllama.py:575, in MllamaTextCrossSdpaAttention.forward(self, hidden_states, cross_attention_states, past_key_value, attention_mask, output_attentions, use_cache, cache_position)
    570     key_states, value_states = (
    571         past_key_value.key_cache[self.layer_idx],
    572         past_key_value.value_cache[self.layer_idx],
    573     )
    574 else:
--> 575     raise ValueError(
    576         "Cross attention layer can't find neither `cross_attn_states` nor cached values for key/values!"
    577     )
    579 key_states = repeat_kv(key_states, self.num_key_value_groups)
    580 value_states = repeat_kv(value_states, self.num_key_value_groups)

ValueError: Cross attention layer can't find neither `cross_attn_states` nor cached values for key/values!
```
