
feat: add Qwen 3.5 MoE model support #2007

Closed
samsja wants to merge 2 commits into main from feature/qwen3-5-moe

Conversation


@samsja samsja commented Mar 10, 2026

Summary

  • Adds custom Qwen 3.5 MoE (GatedDeltaNet + MoE) model implementation with HF weight conversion
  • Integrates EP support (MoE layers auto-detected by apply_ep(), shared experts stay replicated)
  • Adds CP support by patching Qwen3_5MoeGatedFlashAttention in substitute_ring_attn()
  • Supports router replay via routed_experts argument for deterministic expert routing

Test plan

  • Unit tests: forward pass numerical match vs HF, weight roundtrip, router replay, CP patching
  • Multi-GPU: verify EP shards experts correctly across ranks
  • Multi-GPU: verify CP ring attention works on full_attention layers
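The EP sharding check above assumes each rank owns a disjoint slice of the routed experts. A toy sketch of that assumed semantics (the function name and contiguous-slice layout are illustrative, not PrimeRL's actual `apply_ep()` logic):

```python
def shard_experts(num_experts: int, ep_size: int, rank: int) -> list[int]:
    """Return the expert indices owned by `rank` under expert parallelism.

    Routed experts are split into contiguous slices, one per EP rank;
    shared experts (not shown) would stay replicated on every rank.
    """
    assert num_experts % ep_size == 0, "experts must divide evenly across ranks"
    per_rank = num_experts // ep_size
    return list(range(rank * per_rank, (rank + 1) * per_rank))

print(shard_experts(8, ep_size=4, rank=0))  # [0, 1]
print(shard_experts(8, ep_size=4, rank=3))  # [6, 7]
```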

🤖 Generated with Claude Code


Note

Medium Risk
Introduces a large new custom model implementation (attention + MoE + weight conversion) and extends the global attention monkey-patching, which could affect training correctness and distributed runs for the models involved.

Overview
Adds a new custom Qwen3_5Moe model stack (config + modeling + HF↔PrimeRL MoE weight conversion) and registers it with PrimeRL’s AutoModelForCausalLMPrimeRL so impl=custom/auto can load it.

Extends substitute_ring_attn() to also patch Qwen3_5MoeGatedFlashAttention for CP ring-attention, and removes the hard error that previously forced VLM models to use impl='hf'.
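The patching approach described here can be illustrated generically. The stand-in class and function names below mirror the PR's naming but do not reproduce its code; the real `substitute_ring_attn()` patches the actual attention class in the transformers module.

```python
class GatedFlashAttention:
    """Stand-in for Qwen3_5MoeGatedFlashAttention."""

    def forward(self, x):
        return f"flash({x})"

def ring_forward(self, x):
    """Stand-in for the context-parallel ring-attention forward."""
    return f"ring({x})"

def substitute_ring_attn():
    # Class-level patch: every instance, existing or future, now runs
    # the ring-attention variant when forward() is called.
    GatedFlashAttention.forward = ring_forward

attn = GatedFlashAttention()
substitute_ring_attn()
print(attn.forward("q"))  # ring(q) — pre-existing instances are patched too
```

Patching at the class level (rather than per instance) is what lets a single call cover every attention layer in an already-constructed model.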

Adds GPU unit tests validating PrimeRL vs HF forward/grad closeness, lossless weight conversion roundtrip, routed_experts replay behavior, and the new ring-attn patching hook.

Written by Cursor Bugbot for commit 24ab69e. This will update automatically on new commits.


@cursor cursor bot left a comment


Cursor Bugbot has reviewed your changes and found 2 potential issues.

Autofix Details

Bugbot Autofix prepared fixes for both issues found in the latest run.

  • ✅ Fixed: Test uses position_ids starting from 1 instead of 0
    • Changed position_ids from torch.arange(1, 101) to torch.arange(0, 100) in both test functions to use standard 0-indexed positions.
  • ✅ Fixed: Gated RMSNorm fallback crashes when gate is None
    • Removed the misleading gate=None default from Qwen3_5MoeRMSNormGated.forward() since gate is always required.


Or push these changes by commenting:

@cursor push 798dd3b51a
Preview (798dd3b51a)
diff --git a/src/prime_rl/trainer/models/qwen3_5_moe/modeling_qwen3_5_moe.py b/src/prime_rl/trainer/models/qwen3_5_moe/modeling_qwen3_5_moe.py
--- a/src/prime_rl/trainer/models/qwen3_5_moe/modeling_qwen3_5_moe.py
+++ b/src/prime_rl/trainer/models/qwen3_5_moe/modeling_qwen3_5_moe.py
@@ -87,7 +87,7 @@
         self.weight = nn.Parameter(torch.ones(hidden_size))
         self.variance_epsilon = eps
 
-    def forward(self, hidden_states, gate=None):
+    def forward(self, hidden_states, gate):
         input_dtype = hidden_states.dtype
         hidden_states = hidden_states.to(torch.float32)
         variance = hidden_states.pow(2).mean(-1, keepdim=True)

diff --git a/tests/unit/train/models/test_qwen3_5_moe.py b/tests/unit/train/models/test_qwen3_5_moe.py
--- a/tests/unit/train/models/test_qwen3_5_moe.py
+++ b/tests/unit/train/models/test_qwen3_5_moe.py
@@ -50,7 +50,7 @@
 
     with torch.device("cuda"), default_dtype(torch.float32):
         input_ids = torch.randint(0, hf_model.config.vocab_size, (1, 100))
-        position_ids = torch.arange(1, 101).unsqueeze(0)
+        position_ids = torch.arange(0, 100).unsqueeze(0)
 
     hf_output = hf_model(input_ids, position_ids=position_ids)
     prime_output = prime_model(input_ids, position_ids=position_ids)
@@ -96,7 +96,7 @@
 
     with torch.device("cuda"), default_dtype(torch.float32):
         input_ids = torch.randint(0, prime_model.config.vocab_size, (1, 100))
-        position_ids = torch.arange(1, 101).unsqueeze(0)
+        position_ids = torch.arange(0, 100).unsqueeze(0)
 
     out_normal = prime_model(input_ids, position_ids=position_ids)


with torch.device("cuda"), default_dtype(torch.float32):
input_ids = torch.randint(0, hf_model.config.vocab_size, (1, 100))
position_ids = torch.arange(1, 101).unsqueeze(0)

Test uses position_ids starting from 1 instead of 0

Medium Severity

position_ids is constructed as torch.arange(1, 101) (starting from 1) instead of the standard torch.arange(0, 100). This produces positions [1..100] instead of [0..99]. While both models receive the same IDs so the comparison is consistent, this means the test doesn't exercise the standard position encoding start. More critically, if the test were run with flash attention, the cu_seqlens logic in Qwen3_5MoeModel.forward would compute incorrect cumulative sequence lengths because it relies on position_ids == 0 to detect sequence boundaries and expects the first element to be 0.
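The `cu_seqlens` failure mode described here can be sketched directly. This is an assumed reconstruction of the boundary-detection logic (the real code in `Qwen3_5MoeModel.forward` may differ); it shows why positions starting at 1 hide every sequence boundary in a packed batch.

```python
import torch

def cu_seqlens_from_position_ids(position_ids: torch.Tensor) -> torch.Tensor:
    """Derive flash-attn-style cumulative sequence lengths from position_ids,
    treating position_ids == 0 as the start-of-sequence marker."""
    flat = position_ids.flatten()
    starts = (flat == 0).nonzero(as_tuple=True)[0]  # indices where a sequence begins
    return torch.cat([starts, flat.new_tensor([flat.numel()])]).to(torch.int32)

# Two packed sequences of lengths 3 and 2, correctly 0-indexed:
good = torch.tensor([[0, 1, 2, 0, 1]])
print(cu_seqlens_from_position_ids(good))  # boundaries at 0, 3, 5

# Positions starting at 1: no element equals 0, so no boundary is found
# and the whole batch is treated as one sequence.
bad = torch.tensor([[1, 2, 3, 1, 2]])
print(cu_seqlens_from_position_ids(bad))
```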

Additional Locations (1)


hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
hidden_states = self.weight * hidden_states.to(input_dtype)
hidden_states = hidden_states * F.silu(gate.to(torch.float32))
return hidden_states.to(input_dtype)

Gated RMSNorm fallback crashes when gate is None

Low Severity

Qwen3_5MoeRMSNormGated.forward() declares gate=None as default but unconditionally calls gate.to(torch.float32) on line 96, which would raise an AttributeError if gate is actually None. The default parameter is misleading — this method always requires a gate tensor. While current call sites always provide a gate, the misleading signature could cause errors if reused elsewhere.
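Assembling the diff hunk and the quoted context above gives a runnable sketch of the gated RMSNorm with the autofix's required-`gate` signature (the class name is shortened; the body follows the quoted lines):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class RMSNormGated(nn.Module):
    """Gated RMSNorm: gate is a required argument, per the autofix."""

    def __init__(self, hidden_size: int, eps: float = 1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states: torch.Tensor, gate: torch.Tensor) -> torch.Tensor:
        input_dtype = hidden_states.dtype
        hidden_states = hidden_states.to(torch.float32)
        variance = hidden_states.pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
        hidden_states = self.weight * hidden_states.to(input_dtype)
        # The SiLU-gated product is unconditional, which is why gate=None
        # could never have worked as a fallback.
        hidden_states = hidden_states * F.silu(gate.to(torch.float32))
        return hidden_states.to(input_dtype)

norm = RMSNormGated(8)
y = norm(torch.randn(2, 8), torch.randn(2, 8))
print(y.shape)  # torch.Size([2, 8])
```

With `gate` required, calling `norm(x)` without a gate now fails immediately with a `TypeError` instead of an `AttributeError` deep inside the forward.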


Adds a custom Qwen 3.5 MoE (GatedDeltaNet + MoE) implementation with:
- HF <-> PrimeRL weight conversion (fused/unfused expert formats)
- Expert Parallelism support (MoE layers auto-detected by apply_ep)
- Context Parallelism support (ring attention patching for flash attention layers)
- Router replay via routed_experts argument
- Unit tests for forward pass, weight roundtrip, router replay, and CP patching

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
@samsja force-pushed the feature/qwen3-5-moe branch from ccae96b to 38e3036 on March 12, 2026 at 05:45

@cursor cursor bot left a comment


Cursor Bugbot has reviewed your changes and found 1 potential issue.

Bugbot Autofix is OFF. To automatically fix reported issues with cloud agents, enable autofix in the Cursor dashboard.

raise ValueError(
f"VLM models only support impl='hf', but got impl='{config.impl}' (resolved to '{impl_to_use}'). "
f"Set impl='hf' or impl='auto' in your model config."
)

Removed VLM guard silently breaks custom model loading

High Severity

The guard if is_vlm and impl_to_use != "hf": raise ValueError(...) was removed, but the SUPPORTED_VLM_PATTERNS in vlm.py includes "Qwen/Qwen3.5*", which matches text-only Qwen3.5 MoE models (e.g., Qwen/Qwen3.5-MoE-35B-A3B). Because is_vlm is checked first and forces model_cls = AutoModelForImageTextToText, the new custom Qwen3_5MoeForCausalLM implementation can never be reached through get_model() for models with standard HuggingFace names. Previously, the guard would raise an explicit error; now it silently loads the wrong model class.
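The pattern collision can be demonstrated with a glob match. The matching mechanism is an assumption (the real `vlm.py` may match differently); only the pattern string and the example model name come from the review comment above.

```python
from fnmatch import fnmatch

# Pattern from SUPPORTED_VLM_PATTERNS in vlm.py, per the review comment.
SUPPORTED_VLM_PATTERNS = ["Qwen/Qwen3.5*"]

def is_vlm(model_name: str) -> bool:
    """Assumed glob-style VLM detection."""
    return any(fnmatch(model_name, p) for p in SUPPORTED_VLM_PATTERNS)

# A text-only MoE checkpoint still matches the VLM pattern, so get_model()
# would route it to AutoModelForImageTextToText instead of the new
# Qwen3_5MoeForCausalLM.
print(is_vlm("Qwen/Qwen3.5-MoE-35B-A3B"))  # True
```

Without the removed guard, this misrouting no longer raises; the wrong model class loads silently, which is what makes the severity High.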


@samsja samsja closed this Mar 20, 2026