
MoE expert parallelism + sequence parallelism#45408

Open
3outeille wants to merge 2 commits into refactor-tp-dtensor from moe-sequence-parallel

Conversation

@3outeille
Member

Summary

  • Extends the TPStyle API (from the FSDP + TP integration refactor, #45028) with MoE expert parallelism and sequence parallelism support
  • Adds PackedColwiseParallel, MoEExpertsParallel, PrepareModuleInputOutput, and _AllReduceBackward custom ParallelStyle subclasses
  • Extends TPStyle with moe_experts, packed_colwise, activation, module, and loss_parallel kinds
  • Adds _StridedShard handling in core_model_loading.py for interleaved gate_up_proj weights
  • Adds MoE model configs for mixtral, deepseek_v3, and qwen3 with sequence parallelism plans
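The interleaved gate_up_proj layout is what motivates the _StridedShard handling: the fused weight stacks the gate and up projections along the output dimension, so a plain contiguous column-wise split would hand some ranks only gate rows and others only up rows. Sharding each half independently keeps matching gate/up slices together on every rank. A minimal single-process sketch of the index math (the function name and the two-block layout are illustrative assumptions, not the loader's API):

```python
def strided_shard_rows(n_rows: int, tp_size: int, rank: int, num_blocks: int = 2):
    """Row indices a TP rank owns when a fused weight of `num_blocks`
    stacked projections (e.g. gate + up) is sharded per-block.

    Each block is split across `tp_size` ranks independently, so rank r
    holds the r-th slice of the gate rows AND the r-th slice of the up
    rows, instead of a single contiguous chunk of the fused matrix.
    """
    block = n_rows // num_blocks      # rows per projection (gate, up)
    per_rank = block // tp_size       # rows of each projection per rank
    rows = []
    for b in range(num_blocks):
        start = b * block + rank * per_rank
        rows.extend(range(start, start + per_rank))
    return rows

# 8-row fused weight (4 gate rows + 4 up rows), 2 TP ranks:
print(strided_shard_rows(8, 2, 0))  # [0, 1, 4, 5] -> gate rows 0-1, up rows 4-5
print(strided_shard_rows(8, 2, 1))  # [2, 3, 6, 7] -> gate rows 2-3, up rows 6-7
```

In the PR itself this index pattern is expressed through DTensor's _StridedShard placement rather than explicit index lists; the sketch only shows why a strided placement is needed for fused weights.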

Part of the distributed training API chain: #44989

Chain: main ← #44989 ← #44083 ← #44974 ← #45028 ← this PR ← orchestration+save PR

Review question

Are the custom ParallelStyle subclasses correct for expert sharding + sequence parallelism?

Test plan

  • Verify MoE expert sharding produces correct DTensor placements
  • Test sequence parallelism with allgather/split hooks
  • Run existing TP mixin tests to ensure no regression
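The allgather/split hook pair in the test plan can be reasoned about with a toy round-trip: sequence parallelism scatters the sequence dimension across TP ranks for the cheap elementwise/norm ops, then allgathers the full sequence before attention. A plain-Python stand-in for the two collectives (the real hooks in PrepareModuleInputOutput operate on DTensors, not lists):

```python
def sp_split(hidden, tp_size, rank):
    """Stand-in for the sequence-parallel input hook: each rank keeps
    its contiguous chunk of the sequence dimension."""
    chunk = len(hidden) // tp_size
    return hidden[rank * chunk:(rank + 1) * chunk]

def sp_allgather(shards):
    """Stand-in for the output hook: re-assemble the full sequence
    before attention/MLP by concatenating every rank's chunk in order."""
    return [tok for shard in shards for tok in shard]

seq = list(range(8))                            # 8 token positions
shards = [sp_split(seq, 2, r) for r in range(2)]
print(shards)                                   # [[0, 1, 2, 3], [4, 5, 6, 7]]
assert sp_allgather(shards) == seq              # round-trip is lossless
```

The invariant the tests should pin down is exactly this round-trip: split followed by allgather must reproduce the original sequence order, for every rank count that divides the sequence length.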

- Add PackedColwiseParallel for fused gate_up_proj weights
- Add MoEExpertsParallel with per-expert DTensor sharding
- Add PrepareModuleInputOutput for SP allgather/split hooks
- Add _AllReduceBackward for MoE routing weight gradients
- Extend TPStyle with moe_experts, packed_colwise, activation, module kinds
- Add _StridedShard handling in core_model_loading for interleaved weights
- Add MoE model configs: mixtral, deepseek_v3, qwen3 with SP plans
- DTensor rotary_pos_emb guard for mixtral
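The _AllReduceBackward commit follows the standard identity-forward / all-reduce-backward pattern: the routing weights pass through unchanged in the forward pass, but their gradients must be summed across TP ranks in the backward pass. A hedged single-process sketch with torch.autograd.Function, where the collective is replaced by a local stand-in (`fake_all_reduce` and the world size of 2 are assumptions for illustration, not the PR's code, which would call dist.all_reduce):

```python
import torch

def fake_all_reduce(t, world_size=2):
    # Stand-in for dist.all_reduce(t, op=SUM): pretend every rank
    # produced the same local gradient, so the sum is a scaling.
    return t * world_size

class AllReduceBackward(torch.autograd.Function):
    """Identity in forward; sums gradients across ranks in backward."""
    @staticmethod
    def forward(ctx, x):
        return x.view_as(x)   # identity, but registers a grad node

    @staticmethod
    def backward(ctx, grad):
        return fake_all_reduce(grad)

x = torch.ones(3, requires_grad=True)
y = AllReduceBackward.apply(x)
y.sum().backward()
print(x.grad)  # tensor([2., 2., 2.]) -- local grad of ones, summed over 2 ranks
```

The design point is that the forward output is numerically identical to the input, so the op is free at inference time; only the backward graph changes, which is exactly what a gradient-synchronization hook for replicated routing weights needs.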
@github-actions
Contributor

[For maintainers] Suggested jobs to run (before merge)

run-slow: deepseek_v3, mixtral, qwen3

@3outeille 3outeille force-pushed the refactor-tp-dtensor branch from 34a5085 to eb428cc Compare April 14, 2026 09:54
@3outeille 3outeille force-pushed the moe-sequence-parallel branch from e04c7d9 to 24ca327 Compare April 14, 2026 09:54
