ValueError: Cross attention layer can't find neither `cross_attn_states` nor cached values for key/values!
I downloaded the Hugging Face model meta-llama/Llama-3.2-90B-Vision-Instruct and ran the example inference code, but ran into an issue.
output = model.generate(**inputs, max_new_tokens=30, do_sample=False)
When running the code above, the following error occurred:
ValueError: Cross attention layer can't find neither 'cross_attn_states' nor cached values for key/values!
The inputs contained the following keys:
['input_ids', 'attention_mask', 'pixel_values', 'aspect_ratio_ids', 'aspect_ratio_mask', 'cross_attention_mask']
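For reference, the inputs were prepared roughly like this, following the usual Mllama model-card pattern (the image URL and the prompt text below are placeholders, not necessarily the exact ones I used):

```python
import requests
import torch
from PIL import Image
from transformers import AutoProcessor, MllamaForConditionalGeneration

model_id = "meta-llama/Llama-3.2-90B-Vision-Instruct"

# Load the model sharded across the available GPUs, plus its processor.
model = MllamaForConditionalGeneration.from_pretrained(
    model_id,
    torch_dtype=torch.bfloat16,
    device_map="auto",
)
processor = AutoProcessor.from_pretrained(model_id)

# Placeholder image and prompt; any image/text pair produces the same call pattern.
image = Image.open(requests.get("https://example.com/some_image.jpg", stream=True).raw)
messages = [
    {
        "role": "user",
        "content": [
            {"type": "image"},
            {"type": "text", "text": "Describe this image."},
        ],
    }
]

# Build the multimodal prompt and tensor inputs; this should yield exactly the keys listed above.
input_text = processor.apply_chat_template(messages, add_generation_prompt=True)
inputs = processor(image, input_text, add_special_tokens=False, return_tensors="pt").to(model.device)

output = model.generate(**inputs, max_new_tokens=30, do_sample=False)
print(processor.decode(output[0]))
```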
When I checked the source code of the Transformers library, it seemed that model.generate should be able to proceed without cross_attn_states being passed in the inputs: the model appears to build cross_attn_states itself from pixel_values during the first forward pass, and later decoding steps should fall back on the cached key/value states. So encountering this error is quite confusing. I would be really grateful if anyone could provide a solution to this issue.
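To make that reasoning concrete, here is a small self-contained toy of the control flow I expect generate() to follow for this model. It is my paraphrase of the mllama code path, not the real implementation: on the prefill step the vision features stand in for cross_attention_states and the cross-attention layer caches its key/values; on later decode steps the layer reuses that cache, so the ValueError branch should only be reachable if prefill never produced any states.

```python
# Toy sketch (my paraphrase, not the real mllama implementation) of the flow I expect:
# prefill builds cross-attention K/V from the image, decode steps reuse the cache.

class ToyCrossAttentionLayer:
    def __init__(self):
        self.kv_cache = None  # stands in for past_key_value.key_cache / value_cache

    def forward(self, cross_attention_states=None):
        if cross_attention_states is not None:
            # Prefill: project the vision states to K/V and cache them.
            self.kv_cache = f"K/V projected from {cross_attention_states!r}"
        elif self.kv_cache is None:
            # The branch I am somehow hitting, even though pixel_values were passed.
            raise ValueError(
                "Cross attention layer can't find neither `cross_attn_states` nor cached values for key/values!"
            )
        return self.kv_cache  # attention math omitted in this toy


def toy_generate(pixel_values, max_new_tokens):
    layer = ToyCrossAttentionLayer()
    for step in range(max_new_tokens):
        if step == 0:
            # Prefill: the top-level forward should build cross_attention_states
            # from pixel_values via the vision tower and projector.
            states = f"vision features of {pixel_values!r}"
        else:
            # Decode: no states are passed; the layer should fall back to its cache.
            states = None
        layer.forward(cross_attention_states=states)
    return "no ValueError raised"


print(toy_generate("pixel_values", max_new_tokens=30))
```

From the frame at generation/utils.py:3254 in the traceback (the call in the branch after `is_prefill = False`), the error seems to be raised on a decode step rather than on the first prefill forward, which would suggest the cross-attention key/values were never cached during prefill. The full traceback is below: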
```
---------------------------------------------------------------------------
ValueError Traceback (most recent call last)
Cell In[12], line 1
----> 1 output = model.generate(**inputs, max_new_tokens=30, do_sample=False)
File ~/miniconda3/envs/llm/lib/python3.10/site-packages/torch/utils/_contextlib.py:115, in context_decorator.<locals>.decorate_context(*args, **kwargs)
112 @functools.wraps(func)
113 def decorate_context(*args, **kwargs):
114 with ctx_factory():
--> 115 return func(*args, **kwargs)
File ~/miniconda3/envs/llm/lib/python3.10/site-packages/transformers/generation/utils.py:2252, in GenerationMixin.generate(self, inputs, generation_config, logits_processor, stopping_criteria, prefix_allowed_tokens_fn, synced_gpus, assistant_model, streamer, negative_prompt_ids, negative_prompt_attention_mask, **kwargs)
2244 input_ids, model_kwargs = self._expand_inputs_for_generation(
2245 input_ids=input_ids,
2246 expand_size=generation_config.num_return_sequences,
2247 is_encoder_decoder=self.config.is_encoder_decoder,
2248 **model_kwargs,
2249 )
2251 # 12. run sample (it degenerates to greedy search when `generation_config.do_sample=False`)
-> 2252 result = self._sample(
2253 input_ids,
2254 logits_processor=prepared_logits_processor,
2255 stopping_criteria=prepared_stopping_criteria,
2256 generation_config=generation_config,
2257 synced_gpus=synced_gpus,
2258 streamer=streamer,
2259 **model_kwargs,
2260 )
2262 elif generation_mode in (GenerationMode.BEAM_SAMPLE, GenerationMode.BEAM_SEARCH):
2263 # 11. prepare beam search scorer
2264 beam_scorer = BeamSearchScorer(
2265 batch_size=batch_size,
2266 num_beams=generation_config.num_beams,
(...)
2271 max_length=generation_config.max_length,
2272 )
File ~/miniconda3/envs/llm/lib/python3.10/site-packages/transformers/generation/utils.py:3254, in GenerationMixin._sample(self, input_ids, logits_processor, stopping_criteria, generation_config, synced_gpus, streamer, **model_kwargs)
3252 is_prefill = False
3253 else:
-> 3254 outputs = model_forward(**model_inputs, return_dict=True)
3256 # synced_gpus: don't waste resources running the code we don't need; kwargs must be updated before skipping
3257 model_kwargs = self._update_model_kwargs_for_generation(
3258 outputs,
3259 model_kwargs,
3260 is_encoder_decoder=self.config.is_encoder_decoder,
3261 )
File ~/miniconda3/envs/llm/lib/python3.10/site-packages/torch/nn/modules/module.py:1511, in Module._wrapped_call_impl(self, *args, **kwargs)
1509 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1510 else:
-> 1511 return self._call_impl(*args, **kwargs)
File ~/miniconda3/envs/llm/lib/python3.10/site-packages/torch/nn/modules/module.py:1520, in Module._call_impl(self, *args, **kwargs)
1515 # If we don't have any hooks, we want to skip the rest of the logic in
1516 # this function, and just call forward.
1517 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1518 or _global_backward_pre_hooks or _global_backward_hooks
1519 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1520 return forward_call(*args, **kwargs)
1522 try:
1523 result = None
File ~/miniconda3/envs/llm/lib/python3.10/site-packages/accelerate/hooks.py:170, in add_hook_to_module.<locals>.new_forward(module, *args, **kwargs)
168 output = module._old_forward(*args, **kwargs)
169 else:
--> 170 output = module._old_forward(*args, **kwargs)
171 return module._hf_hook.post_forward(module, output)
File ~/miniconda3/envs/llm/lib/python3.10/site-packages/transformers/models/mllama/modeling_mllama.py:2127, in MllamaForConditionalGeneration.forward(self, input_ids, pixel_values, aspect_ratio_mask, aspect_ratio_ids, attention_mask, cross_attention_mask, cross_attention_states, position_ids, past_key_values, inputs_embeds, labels, use_cache, output_attentions, output_hidden_states, return_dict, cache_position, num_logits_to_keep)
2124 cross_attention_mask = cross_attention_mask[:, :, cache_position]
2125 full_text_row_masked_out_mask = full_text_row_masked_out_mask[:, :, cache_position]
-> 2127 outputs = self.language_model(
2128 input_ids=input_ids,
2129 attention_mask=attention_mask,
2130 position_ids=position_ids,
2131 cross_attention_states=cross_attention_states,
2132 cross_attention_mask=cross_attention_mask,
2133 full_text_row_masked_out_mask=full_text_row_masked_out_mask,
2134 past_key_values=past_key_values,
2135 use_cache=use_cache,
2136 inputs_embeds=inputs_embeds,
2137 labels=labels,
2138 output_hidden_states=output_hidden_states,
2139 output_attentions=output_attentions,
2140 return_dict=return_dict,
2141 cache_position=cache_position,
2142 num_logits_to_keep=num_logits_to_keep,
2143 )
2145 return outputs
File ~/miniconda3/envs/llm/lib/python3.10/site-packages/torch/nn/modules/module.py:1511, in Module._wrapped_call_impl(self, *args, **kwargs)
1509 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1510 else:
-> 1511 return self._call_impl(*args, **kwargs)
File ~/miniconda3/envs/llm/lib/python3.10/site-packages/torch/nn/modules/module.py:1520, in Module._call_impl(self, *args, **kwargs)
1515 # If we don't have any hooks, we want to skip the rest of the logic in
1516 # this function, and just call forward.
1517 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1518 or _global_backward_pre_hooks or _global_backward_hooks
1519 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1520 return forward_call(*args, **kwargs)
1522 try:
1523 result = None
File ~/miniconda3/envs/llm/lib/python3.10/site-packages/transformers/models/mllama/modeling_mllama.py:1935, in MllamaForCausalLM.forward(self, input_ids, attention_mask, position_ids, cross_attention_states, cross_attention_mask, full_text_row_masked_out_mask, past_key_values, inputs_embeds, labels, use_cache, output_attentions, output_hidden_states, return_dict, cache_position, num_logits_to_keep, **loss_kwargs)
1932 return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1934 # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
-> 1935 outputs = self.model(
1936 input_ids=input_ids,
1937 cross_attention_states=cross_attention_states,
1938 attention_mask=attention_mask,
1939 position_ids=position_ids,
1940 cross_attention_mask=cross_attention_mask,
1941 full_text_row_masked_out_mask=full_text_row_masked_out_mask,
1942 past_key_values=past_key_values,
1943 inputs_embeds=inputs_embeds,
1944 use_cache=use_cache,
1945 output_attentions=output_attentions,
1946 output_hidden_states=output_hidden_states,
1947 return_dict=return_dict,
1948 cache_position=cache_position,
1949 )
1951 hidden_states = outputs[0]
1952 logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :]).float()
File ~/miniconda3/envs/llm/lib/python3.10/site-packages/torch/nn/modules/module.py:1511, in Module._wrapped_call_impl(self, *args, **kwargs)
1509 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1510 else:
-> 1511 return self._call_impl(*args, **kwargs)
File ~/miniconda3/envs/llm/lib/python3.10/site-packages/torch/nn/modules/module.py:1520, in Module._call_impl(self, *args, **kwargs)
1515 # If we don't have any hooks, we want to skip the rest of the logic in
1516 # this function, and just call forward.
1517 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1518 or _global_backward_pre_hooks or _global_backward_hooks
1519 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1520 return forward_call(*args, **kwargs)
1522 try:
1523 result = None
File ~/miniconda3/envs/llm/lib/python3.10/site-packages/transformers/models/mllama/modeling_mllama.py:1797, in MllamaTextModel.forward(self, input_ids, attention_mask, position_ids, cross_attention_states, cross_attention_mask, full_text_row_masked_out_mask, past_key_values, inputs_embeds, use_cache, output_attentions, output_hidden_states, return_dict, cache_position)
1782 layer_outputs = self._gradient_checkpointing_func(
1783 decoder_layer.__call__,
1784 hidden_states,
(...)
1794 position_embeddings,
1795 )
1796 else:
-> 1797 layer_outputs = decoder_layer(
1798 hidden_states,
1799 cross_attention_states=cross_attention_states,
1800 cross_attention_mask=cross_attention_mask,
1801 attention_mask=causal_mask,
1802 full_text_row_masked_out_mask=full_text_row_masked_out_mask,
1803 position_ids=position_ids,
1804 past_key_value=past_key_values,
1805 output_attentions=output_attentions,
1806 use_cache=use_cache,
1807 cache_position=cache_position,
1808 position_embeddings=position_embeddings,
1809 )
1811 hidden_states = layer_outputs[0]
1813 if use_cache:
File ~/miniconda3/envs/llm/lib/python3.10/site-packages/torch/nn/modules/module.py:1511, in Module._wrapped_call_impl(self, *args, **kwargs)
1509 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1510 else:
-> 1511 return self._call_impl(*args, **kwargs)
File ~/miniconda3/envs/llm/lib/python3.10/site-packages/torch/nn/modules/module.py:1520, in Module._call_impl(self, *args, **kwargs)
1515 # If we don't have any hooks, we want to skip the rest of the logic in
1516 # this function, and just call forward.
1517 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1518 or _global_backward_pre_hooks or _global_backward_hooks
1519 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1520 return forward_call(*args, **kwargs)
1522 try:
1523 result = None
File ~/miniconda3/envs/llm/lib/python3.10/site-packages/accelerate/hooks.py:170, in add_hook_to_module.<locals>.new_forward(module, *args, **kwargs)
168 output = module._old_forward(*args, **kwargs)
169 else:
--> 170 output = module._old_forward(*args, **kwargs)
171 return module._hf_hook.post_forward(module, output)
File ~/miniconda3/envs/llm/lib/python3.10/site-packages/transformers/models/mllama/modeling_mllama.py:951, in MllamaCrossAttentionDecoderLayer.forward(self, hidden_states, cross_attention_states, cross_attention_mask, attention_mask, full_text_row_masked_out_mask, position_ids, past_key_value, output_attentions, use_cache, cache_position, position_embeddings)
948 residual = hidden_states
949 hidden_states = self.input_layernorm(hidden_states)
--> 951 hidden_states, attn_weights, past_key_value = self.cross_attn(
952 hidden_states=hidden_states,
953 attention_mask=cross_attention_mask,
954 cross_attention_states=cross_attention_states,
955 past_key_value=past_key_value,
956 output_attentions=output_attentions,
957 cache_position=cache_position,
958 )
959 hidden_states = residual + self.cross_attn_attn_gate.tanh() * hidden_states
961 residual = hidden_states
File ~/miniconda3/envs/llm/lib/python3.10/site-packages/torch/nn/modules/module.py:1511, in Module._wrapped_call_impl(self, *args, **kwargs)
1509 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1510 else:
-> 1511 return self._call_impl(*args, **kwargs)
File ~/miniconda3/envs/llm/lib/python3.10/site-packages/torch/nn/modules/module.py:1520, in Module._call_impl(self, *args, **kwargs)
1515 # If we don't have any hooks, we want to skip the rest of the logic in
1516 # this function, and just call forward.
1517 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1518 or _global_backward_pre_hooks or _global_backward_hooks
1519 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1520 return forward_call(*args, **kwargs)
1522 try:
1523 result = None
File ~/miniconda3/envs/llm/lib/python3.10/site-packages/accelerate/hooks.py:170, in add_hook_to_module.<locals>.new_forward(module, *args, **kwargs)
168 output = module._old_forward(*args, **kwargs)
169 else:
--> 170 output = module._old_forward(*args, **kwargs)
171 return module._hf_hook.post_forward(module, output)
File ~/miniconda3/envs/llm/lib/python3.10/site-packages/transformers/models/mllama/modeling_mllama.py:575, in MllamaTextCrossSdpaAttention.forward(self, hidden_states, cross_attention_states, past_key_value, attention_mask, output_attentions, use_cache, cache_position)
570 key_states, value_states = (
571 past_key_value.key_cache[self.layer_idx],
572 past_key_value.value_cache[self.layer_idx],
573 )
574 else:
--> 575 raise ValueError(
576 "Cross attention layer can't find neither `cross_attn_states` nor cached values for key/values!"
577 )
579 key_states = repeat_kv(key_states, self.num_key_value_groups)
580 value_states = repeat_kv(value_states, self.num_key_value_groups)
ValueError: Cross attention layer can't find neither `cross_attn_states` nor cached values for key/values!
```