fix: handle mm_token_type_ids in collator and packing tests#10397
fix: handle mm_token_type_ids in collator and packing tests#10397markmochi200 wants to merge 3 commits into
Conversation
There was a problem hiding this comment.
Code Review
This pull request improves the handling of multi-modal token type IDs in the data collator, specifically for models like Qwen2VL and Gemma4. Key changes include robust extraction of mm_token_type_ids, updated slicing logic for packed sequences, and improved padding alignment. The test suite was also updated to reflect these changes. Review feedback recommends using dictionary comprehensions for feature slicing to enhance code maintainability and conciseness.
|
@Kuangdd01 Please address this PR asap. Thanks |
|
Thanks! please resolve these conflicts. |
| if "mm_token_type_ids" in inspect.signature(self.get_rope_func).parameters: | ||
| image_token_id = getattr(self.model.config, "image_token_id", None) | ||
| video_token_id = getattr(self.model.config, "video_token_id", None) | ||
| if image_token_id is not None or video_token_id is not None: | ||
| mm_token_type_ids = torch.zeros_like(features["input_ids"]) | ||
| if image_token_id is not None: | ||
| mm_token_type_ids[features["input_ids"] == image_token_id] = 1 | ||
| if video_token_id is not None: | ||
| mm_token_type_ids[features["input_ids"] == video_token_id] = 2 | ||
| rope_index_kwargs["mm_token_type_ids"] = mm_token_type_ids |
There was a problem hiding this comment.
why we remove this if-condition block?
| elif model_type == "gemma4": | ||
| # Gemma 4 text-only batches still require the field. | ||
| features["mm_token_type_ids"] = torch.zeros_like(features["input_ids"]) | ||
|
|
||
| # Keep token_type_ids present as well for Gemma 4 text-only robustness. | ||
| if model_type == "gemma4" and "token_type_ids" not in features: | ||
| features["token_type_ids"] = torch.zeros_like(features["input_ids"]) |
There was a problem hiding this comment.
feel confused about these
There was a problem hiding this comment.
Feel free to train gemma4-31b with a text-only JSONL file on pt stage and you'll reproduce the issue
There was a problem hiding this comment.
oh! Just remembered that huggingface/transformers#45454 may fix the issue in model forwarding?
What does this PR do?
Fixes an issue where
mm_token_type_idsis missing during training or RoPE computation with newer multimodal models (e.g., Gemma 4, Qwen2VL), causing runtime errors such as:Changes
mm_token_type_idsfrom processor outputs in the collatormm_token_type_idsfor Gemma 4 text-only batches when missingmm_token_type_idsthrough packed RoPE position-id computationmm_token_type_idswhen callingget_rope_index()directlyBefore submitting