Adapt inputs_merger() Function for Scenarios Without Image Input (#5)
Browse files- Adapt inputs_merger() Function for Scenarios Without Image Input (3c44bd312913db88f1df002aed6040511d748ebb)
Co-authored-by: Folco Bertini <[email protected]>
- modeling_vmistral.py +30 -28
modeling_vmistral.py
CHANGED
@@ -1372,36 +1372,38 @@ class VMistralModel(VMistralPreTrainedModel):
|
|
1372 |
batch_size = input_ids.size(0)
|
1373 |
|
1374 |
if inputs_embeds is not None:
|
1375 |
-
vision_pipeline_output_seq_len = image_hidden_states.shape[1]
|
1376 |
-
vision_hidden_size = image_hidden_states.shape[2]
|
1377 |
new_inputs_embeds = inputs_embeds.clone()
|
1378 |
-
|
1379 |
-
|
1380 |
-
|
1381 |
-
|
1382 |
-
# Get the number of images for
|
1383 |
-
|
1384 |
-
|
1385 |
-
|
1386 |
-
|
1387 |
-
|
1388 |
-
|
1389 |
-
|
1390 |
-
|
1391 |
-
|
1392 |
-
|
1393 |
-
|
1394 |
-
|
1395 |
-
|
1396 |
-
|
1397 |
-
|
1398 |
-
|
1399 |
-
|
1400 |
-
|
1401 |
-
|
1402 |
-
|
|
|
|
|
|
|
|
|
|
|
1403 |
)
|
1404 |
-
)
|
1405 |
|
1406 |
return_dict = {}
|
1407 |
if inputs_embeds is not None:
|
|
|
1372 |
batch_size = input_ids.size(0)
|
1373 |
|
1374 |
if inputs_embeds is not None:
|
|
|
|
|
1375 |
new_inputs_embeds = inputs_embeds.clone()
|
1376 |
+
|
1377 |
+
if image_hidden_states is not None:
|
1378 |
+
vision_pipeline_output_seq_len = image_hidden_states.shape[1]
|
1379 |
+
vision_hidden_size = image_hidden_states.shape[2]
|
1380 |
+
# Get the number of images for each example
|
1381 |
+
num_images = (input_ids == self.image_token_id).sum(dim=-1) // self.image_seq_len
|
1382 |
+
cum_num_images = num_images.cumsum(dim=-1)
|
1383 |
+
for batch_idx in range(batch_size):
|
1384 |
+
# Get the number of images for this particular example
|
1385 |
+
example_num_images = num_images[batch_idx]
|
1386 |
+
# Get the image_hidden_states corresponding to True images for the example, so get rid of the padding images.
|
1387 |
+
start = 0 if batch_idx == 0 else cum_num_images[batch_idx - 1]
|
1388 |
+
end = cum_num_images[batch_idx]
|
1389 |
+
example_true_image_hidden_states = image_hidden_states[start:end]
|
1390 |
+
if (
|
1391 |
+
new_inputs_embeds[batch_idx][input_ids[batch_idx] == self.image_token_id].shape[0]
|
1392 |
+
!= example_num_images * vision_pipeline_output_seq_len
|
1393 |
+
):
|
1394 |
+
raise ValueError(
|
1395 |
+
"new_inputs_embeds to replace has shape[0]:"
|
1396 |
+
f" {new_inputs_embeds[batch_idx][input_ids[batch_idx] == self.image_token_id].shape[0]} but"
|
1397 |
+
" should have shape[0]:"
|
1398 |
+
f" {example_num_images}*{vision_pipeline_output_seq_len}={example_num_images * vision_pipeline_output_seq_len} "
|
1399 |
+
)
|
1400 |
+
# Insert the image_hidden_states
|
1401 |
+
new_inputs_embeds[batch_idx][input_ids[batch_idx] == self.image_token_id] = (
|
1402 |
+
example_true_image_hidden_states.view(
|
1403 |
+
example_num_images * vision_pipeline_output_seq_len,
|
1404 |
+
vision_hidden_size,
|
1405 |
+
)
|
1406 |
)
|
|
|
1407 |
|
1408 |
return_dict = {}
|
1409 |
if inputs_embeds is not None:
|