Skip to content

fix: handle mm_token_type_ids in collator and packing tests#10397

Open
markmochi200 wants to merge 3 commits into
hiyouga:mainfrom
markmochi200:dev_fix_gemma4_mm_token_type_ids
Open

fix: handle mm_token_type_ids in collator and packing tests#10397
markmochi200 wants to merge 3 commits into
hiyouga:mainfrom
markmochi200:dev_fix_gemma4_mm_token_type_ids

Conversation

@markmochi200
Copy link
Copy Markdown

@markmochi200 markmochi200 commented Apr 16, 2026

What does this PR do?

Fixes an issue where mm_token_type_ids is missing during training or RoPE computation with newer multimodal models (e.g., Gemma 4, Qwen2VL), causing runtime errors such as:

mm_token_type_ids is required as a model input when training

Changes

  • preserve and correctly pad mm_token_type_ids from processor outputs in the collator
  • synthesize zero mm_token_type_ids for Gemma 4 text-only batches when missing
  • propagate mm_token_type_ids through packed RoPE position-id computation
  • fix packed per-sample slicing for RoPE computation
  • update the packing test helper to include mm_token_type_ids when calling get_rope_index() directly

Before submitting

Copy link
Copy Markdown
Contributor

@gemini-code-assist gemini-code-assist Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request improves the handling of multi-modal token type IDs in the data collator, specifically for models like Qwen2VL and Gemma4. Key changes include robust extraction of mm_token_type_ids, updated slicing logic for packed sequences, and improved padding alignment. The test suite was also updated to reflect these changes. Review feedback recommends using dictionary comprehensions for feature slicing to enhance code maintainability and conciseness.

Comment thread src/llamafactory/data/collator.py
Comment thread src/llamafactory/data/collator.py
@markmochi200
Copy link
Copy Markdown
Author

@Kuangdd01 Please address this PR asap. Thanks

@Kuangdd01
Copy link
Copy Markdown
Collaborator

Thanks! please resolve these conflicts.

@hiyouga hiyouga requested a review from Kuangdd01 May 6, 2026 16:35
Comment on lines -176 to -185
if "mm_token_type_ids" in inspect.signature(self.get_rope_func).parameters:
image_token_id = getattr(self.model.config, "image_token_id", None)
video_token_id = getattr(self.model.config, "video_token_id", None)
if image_token_id is not None or video_token_id is not None:
mm_token_type_ids = torch.zeros_like(features["input_ids"])
if image_token_id is not None:
mm_token_type_ids[features["input_ids"] == image_token_id] = 1
if video_token_id is not None:
mm_token_type_ids[features["input_ids"] == video_token_id] = 2
rope_index_kwargs["mm_token_type_ids"] = mm_token_type_ids
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why we remove this if-condition block?

Comment on lines +471 to +477
elif model_type == "gemma4":
# Gemma 4 text-only batches still require the field.
features["mm_token_type_ids"] = torch.zeros_like(features["input_ids"])

# Keep token_type_ids present as well for Gemma 4 text-only robustness.
if model_type == "gemma4" and "token_type_ids" not in features:
features["token_type_ids"] = torch.zeros_like(features["input_ids"])
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

feel confused about these

Copy link
Copy Markdown
Author

@markmochi200 markmochi200 May 8, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Feel free to train gemma4-31b with a text-only JSONL file on pt stage and you'll reproduce the issue

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

oh! Just remembered that huggingface/transformers#45454 may fix the issue in model forwarding?

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

pending This problem is yet to be addressed

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants