Describe the bug
In Hybrid CP + SFTDataset, short sub-samples can be scheduled with local_cp_size=1. In that case get_batch_on_this_hybrid_cp_rank() does not create a dynamic cp_group and PackedSeqParams.cp_group is None.
TEDotProductAttention correctly treats cp_group=None with local_cp_size=1 as dynamically disabling CP, but the GPT RoPE/Yarn/MRoPE code paths pass cp_group=None into the rotary embedding module. RotaryEmbedding.forward() then falls back to its static self.cp_group and can slice the positional embedding using the global/static CP group.
This can make local_cp_size=1 samples use a sharded RoPE even though the token/hidden states are not CP-sharded.
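The mismatch can be illustrated with a small standalone sketch. It does not use Megatron code: `shard_positions` is a hypothetical helper, the contiguous slicing is a simplification of whatever scheme the real rotary embedding uses, and the sizes are made up for illustration.

```python
def shard_positions(positions, cp_size, cp_rank):
    """Hypothetical, simplified contiguous CP slice (the real scheme may differ)."""
    chunk = len(positions) // cp_size
    return positions[cp_rank * chunk:(cp_rank + 1) * chunk]


seq_len = 8                    # short sub-sample scheduled with local_cp_size=1
tokens = list(range(seq_len))  # hidden states stay unsharded on this rank
static_cp_size = 4             # assumed size of the static CP group

# Falling back to the static group slices RoPE positions for a full-length sequence.
rope_buggy = shard_positions(list(range(seq_len)), static_cp_size, cp_rank=1)

# With CP dynamically disabled, every position should be kept.
rope_expected = shard_positions(list(range(seq_len)), 1, cp_rank=0)
```

In this sketch `rope_buggy` covers only a quarter of the positions while `tokens` still holds the full sequence, which is the kind of mismatch described above.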
Steps/Code to reproduce bug
Run Hybrid CP with SFTDataset and include a packed sub-sample whose length is less than max_seqlen_per_dp_cp_rank, so that it is scheduled with local_cp_size=1.
Expected behavior
When PackedSeqParams.local_cp_size == 1 and cp_group is None, rotary embeddings should not fall back to the static CP group.
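One possible way to express this rule is sketched below. This is an assumption about the shape of a fix, not the actual Megatron implementation: `effective_rope_cp_size` and its parameters are hypothetical names, and `dynamic_cp_size` stands in for the size of PackedSeqParams.cp_group when that group exists.

```python
from typing import Optional


def effective_rope_cp_size(
    local_cp_size: Optional[int],
    dynamic_cp_size: Optional[int],
    static_cp_size: int,
) -> int:
    """Hypothetical helper: choose the CP size that RoPE should shard by."""
    if dynamic_cp_size is not None:
        # A dynamic cp_group was created for this sub-sample; use it.
        return dynamic_cp_size
    if local_cp_size == 1:
        # CP is dynamically disabled: do not fall back to the static group.
        return 1
    # No hybrid CP information available: keep the existing static behavior.
    return static_cp_size
```

The point of the sketch is only the ordering of the checks: local_cp_size == 1 with no dynamic group is handled before the static fallback.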
Additional context
Hybrid CP SFT packed samples with short sub-samples can get incorrectly sliced positional embeddings, causing a shape mismatch or incorrect RoPE application.