
Hybrid CP local_cp_size=1 can shard RoPE with the static CP group #5352

@dbfancier

Describe the bug

With Hybrid CP and SFTDataset, short sub-samples can be scheduled with local_cp_size=1. In that case, get_batch_on_this_hybrid_cp_rank() does not create a dynamic cp_group, so PackedSeqParams.cp_group is None.

TEDotProductAttention correctly interprets cp_group=None together with local_cp_size=1 as dynamically disabling CP. However, the GPT model's RoPE/YaRN/MRoPE path passes cp_group=None into the rotary embedding module, and RotaryEmbedding.forward() then falls back to its static self.cp_group. The positional embedding can therefore be sliced with the global/static CP group.
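The fallback can be modeled as below. This is a simplified sketch, not Megatron's actual code: StaticGroup, RotaryEmbeddingSketch, and resolve_cp_group are hypothetical names, and the conditional is an assumption about the shape of the fallback based on the description above. The point it illustrates is that None is read as "no group provided" rather than "CP disabled".

```python
from typing import Optional


class StaticGroup:
    """Hypothetical stand-in for a process group; only its size matters here."""

    def __init__(self, size: int):
        self.size = size


class RotaryEmbeddingSketch:
    """Hypothetical, simplified model of the fallback described in the issue."""

    def __init__(self, static_cp_group: StaticGroup):
        # Static CP group captured when the module is constructed.
        self.cp_group = static_cp_group

    def resolve_cp_group(self, cp_group: Optional[StaticGroup]) -> StaticGroup:
        # Problematic pattern (assumed): None means "not provided",
        # so the static group is used instead of "CP disabled".
        return cp_group if cp_group is not None else self.cp_group


rope = RotaryEmbeddingSketch(StaticGroup(size=4))
# A local_cp_size=1 sub-sample arrives with cp_group=None ...
effective = rope.resolve_cp_group(None)
print(effective.size)  # 4: RoPE would be sharded across the static CP group
```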

As a result, local_cp_size=1 sub-samples can be given a CP-sharded RoPE even though their tokens/hidden states are not CP-sharded.
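To see why this leads to a mismatch, the sketch below models CP slicing of a positional embedding with plain Python lists. It assumes the load-balanced scheme commonly used for causal context parallelism (split the sequence into 2*cp_size chunks; rank r keeps chunks r and 2*cp_size - 1 - r). The function name get_pos_emb_on_cp_rank_sketch is hypothetical and stands in for whatever slicing RotaryEmbedding performs.

```python
def get_pos_emb_on_cp_rank_sketch(pos_emb, cp_size, cp_rank):
    """Hypothetical load-balanced CP slicing: split the sequence into
    2*cp_size chunks and keep chunks cp_rank and 2*cp_size - 1 - cp_rank."""
    seq_len = len(pos_emb)
    num_chunks = 2 * cp_size
    assert seq_len % num_chunks == 0, "seq_len must be divisible by 2*cp_size"
    chunk = seq_len // num_chunks
    kept = []
    for idx in (cp_rank, num_chunks - 1 - cp_rank):
        kept.extend(pos_emb[idx * chunk:(idx + 1) * chunk])
    return kept


# A short sub-sample scheduled with local_cp_size=1: all tokens stay on this rank.
seq_len = 16
tokens_on_rank = list(range(seq_len))          # not CP-sharded
pos_emb = [f"pos{i}" for i in range(seq_len)]  # one entry per position

# Expected: no CP slicing, because CP is dynamically disabled.
correct = pos_emb
# Buggy: slicing with a static CP group of size 4, as CP rank 1.
buggy = get_pos_emb_on_cp_rank_sketch(pos_emb, cp_size=4, cp_rank=1)

print(len(tokens_on_rank), len(correct), len(buggy))  # 16 16 4
```

With cp_size=1 the same function returns the full embedding in order, which is what a local_cp_size=1 sub-sample needs; with the static group it returns only a quarter of the positions, which no longer lines up with the 16 unsharded tokens.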

Steps/Code to reproduce bug

Use Hybrid CP with SFTDataset and include a packed sub-sample whose length is less than max_seqlen_per_dp_cp_rank, so that it is scheduled with local_cp_size=1.
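The triggering condition can be sketched as follows. This does not reproduce Megatron's scheduler: PackedSeqParamsSketch and schedule_sub_sample are hypothetical, and the only rule taken from the issue is that a sub-sample shorter than max_seqlen_per_dp_cp_rank ends up with local_cp_size=1 and cp_group=None. How longer sub-samples are scheduled is an assumed placeholder.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class PackedSeqParamsSketch:
    """Hypothetical subset of PackedSeqParams relevant to this issue."""
    local_cp_size: int
    cp_group: Optional[object]


def schedule_sub_sample(length: int, max_seqlen_per_dp_cp_rank: int,
                        static_cp_size: int) -> PackedSeqParamsSketch:
    # From the issue: a short sub-sample gets local_cp_size=1 and no
    # dynamic cp_group is created, so cp_group stays None.
    if length < max_seqlen_per_dp_cp_rank:
        return PackedSeqParamsSketch(local_cp_size=1, cp_group=None)
    # Assumed placeholder for longer sub-samples; the real rule is not modeled.
    return PackedSeqParamsSketch(local_cp_size=static_cp_size, cp_group="dynamic_group")


lengths = [100, 3000, 512]
params = [schedule_sub_sample(n, max_seqlen_per_dp_cp_rank=1024, static_cp_size=4)
          for n in lengths]
# Sub-samples that hit the buggy RoPE path.
triggers = [n for n, p in zip(lengths, params)
            if p.local_cp_size == 1 and p.cp_group is None]
print(triggers)  # [100, 512]
```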

Expected behavior

When PackedSeqParams.local_cp_size == 1 and cp_group is None, rotary embeddings should not fall back to the static CP group.
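One way to express the expected behavior is to check the "CP disabled" case explicitly before any fallback. The sketch below is hypothetical (GroupSketch and resolve_rope_cp_size are illustrative names, not Megatron APIs) and assumes that local_cp_size is None when Hybrid CP is not in use, so that the non-hybrid path keeps its current static-group behavior.

```python
from typing import Optional


class GroupSketch:
    """Hypothetical stand-in for a process group; only its size matters."""

    def __init__(self, size: int):
        self.size = size


def resolve_rope_cp_size(cp_group: Optional[GroupSketch],
                         local_cp_size: Optional[int],
                         static_cp_group: GroupSketch) -> int:
    """Return the CP size RoPE should shard with (1 means no slicing)."""
    # CP dynamically disabled for this sub-sample: do not fall back.
    if local_cp_size == 1 and cp_group is None:
        return 1
    # A dynamic group was created for this sub-sample: use it.
    if cp_group is not None:
        return cp_group.size
    # No hybrid-CP information (assumed local_cp_size=None): static behavior.
    return static_cp_group.size


static = GroupSketch(size=4)
print(resolve_rope_cp_size(None, 1, static))            # 1
print(resolve_rope_cp_size(GroupSketch(2), 2, static))  # 2
print(resolve_rope_cp_size(None, None, static))         # 4
```

Ordering the checks this way keeps the fix local to the rotary embedding path: only the combination that TEDotProductAttention already treats as "CP disabled" changes behavior.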

Additional context

Hybrid CP SFT packed samples that contain short sub-samples may get incorrectly sliced positional embeddings, causing a shape mismatch or incorrect RoPE application.
