
Conversation

Contributor

@h-guo18 h-guo18 commented Jan 8, 2026

What does this PR do?

Type of change: New Feature

Overview:

  • Supported context parallel (CP) by patching torch ring attention (a simplified sketch follows this list);
  • Requires the following library versions for stable CP:
    • torch 2.8.0
    • transformers 5.0.0
    • accelerate 1.12.0
  • Moved to FSDP2;
  • Removed unused arguments in the training script (--multi_gpu, fsdp_wrap_layer).
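
A simplified, hypothetical sketch of the patching approach (names such as _templated_ring_attention, _cp_options, and the cudnn-only check are taken from the review comments below; the wrapper signature, the attn_bias keyword, and the get_shard_ttt_mask helper are illustrative assumptions, not the exact implementation):

import inspect

import torch
import torch.distributed.tensor.experimental._attention as cp_attn


def patch_ring_attention_for_ttt(get_shard_ttt_mask):
    """Wrap torch's ring attention so each CP ring step applies its TTT mask shard."""
    # Load balancing reorders sequence shards, which would invalidate the
    # precomputed TTT mask layout, so it is disabled before trainer init.
    cp_attn._cp_options.enable_load_balance = False

    original_ring_attention = cp_attn._templated_ring_attention

    def patched_ring_attention(mesh, seq_dim, op, *args, **kwargs):
        if op is not torch.ops.aten._scaled_dot_product_cudnn_attention:
            raise ValueError(f"CP TTT only supports cudnn attention now. Got: {op}")

        def patched_op(*op_args, **op_kwargs):
            # Read the ring loop's locals (rank, size, ring step i) to know which
            # KV shard this step attends to. Fragile: tied to torch 2.8.0 internals.
            ring_locals = inspect.currentframe().f_back.f_locals
            op_kwargs["attn_bias"] = get_shard_ttt_mask(
                rank=ring_locals["rank"],
                size=ring_locals["size"],
                ring_step=ring_locals["i"],
            )
            return op(*op_args, **op_kwargs)

        return original_ring_attention(mesh, seq_dim, patched_op, *args, **kwargs)

    cp_attn._templated_ring_attention = patched_ring_attention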

Usage

./launch_train.sh --model $MODEL \
    --output_dir $OUTPUT_DIR \
    --data $DATA \
    --num_epochs 0.1 \
    --train_bs 1 \
    --eagle_config eagle_config.json \
    --training_seq_len 1024 \
    --cp_size 2   # newly added

Testing

  • SDPA-level correctness: tested TTT attention with and without CP; diff < 1% (a sketch of the comparison is included after this list):
=== Compare context-parallel (CP) outputs and grads with non-CP ===
Forward output comparison (CP vs Non-CP):
  Absolute diff (adiff) cp_out vs out: 0.001953125
  Relative diff (rdiff) cp_out vs out: 0.00182342529296875
WQ (query proj) grad comparison (CP vs Non-CP):
  Absolute diff (adiff) cp_wq_grad vs wq_grad: 0.0078125
  Relative diff (rdiff) cp_wq_grad vs wq_grad: 0.00347900390625
WK (key proj) grad comparison (CP vs Non-CP):
  Absolute diff (adiff) cp_wk_grad vs wk_grad: 0.0078125
  Relative diff (rdiff) cp_wk_grad vs wk_grad: 0.002471923828125
WV (value proj) grad comparison (CP vs Non-CP):
  Absolute diff (adiff) cp_wv_grad vs wv_grad: 0.25
  Relative diff (rdiff) cp_wv_grad vs wv_grad: 0.0069580078125
==============================================================
  • E2E training accuracy
    (Llama3.1-8B, unsynthesized Magpie)
    [training accuracy plot]
  • Peak memory
    (Llama3.1-8B, 8xH100, train_length=4k)

    cp_size | max_memory_allocated (MB) | max_memory_reserved (MB)
    1       | 65040.20                  | 79018.00
    2       | 50409.17                  | 73098.00
    4       | 45120.92                  | 72052.00
    8       | 38882.12                  | 66484.00
  • Max training length test
    (Llama3.1-8B, H100)

    cp_size | 6k  | 12k | 24k | 48k
    1       |     | OOM | OOM | OOM
    2       |     |     | OOM | OOM
    4       |     |     |     | OOM
    8       |     |     |     |

    (Empty cells: the run fit in memory at that sequence length; OOM: out of memory.)
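
For reference, a minimal sketch of how the CP vs. non-CP diffs above could be computed; the actual test script is not part of this thread, and report_diff and the relative-diff epsilon are illustrative. The memory columns are the quantities torch exposes as torch.cuda.max_memory_allocated() and torch.cuda.max_memory_reserved().

import torch


def report_diff(name: str, cp_t: torch.Tensor, ref_t: torch.Tensor) -> None:
    """Print max absolute and relative difference between CP and non-CP tensors."""
    adiff = (cp_t - ref_t).abs().max().item()
    rdiff = ((cp_t - ref_t).abs() / ref_t.abs().clamp_min(1e-6)).max().item()
    print(f"  Absolute diff (adiff) {name}: {adiff}")
    print(f"  Relative diff (rdiff) {name}: {rdiff}")


# Peak-memory columns, in MB, after a training step:
# torch.cuda.max_memory_allocated() / 2**20, torch.cuda.max_memory_reserved() / 2**20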

Before your PR is "Ready for review"

  • Make sure you read and follow Contributor guidelines and your commits are signed.
  • Is this change backward compatible?: Yes/No
  • Did you write any new necessary tests?: Yes/No
  • Did you add or update any necessary documentation?: Yes/No
  • Did you update Changelog?: Yes/No

Additional Information

Summary by CodeRabbit

  • New Features

    • Added context parallelism (CP) and data parallelism shard size configuration parameters to training arguments.
  • Enhancements

    • Improved TTT attention masking support for speculative decoding workflows.
    • Enhanced training launch script with improved parallelism configuration handling.
  • Chores

    • Updated core dependencies: torch, transformers, accelerate, and wandb.
    • Added FSDP configuration file for distributed training setup.



copy-pr-bot bot commented Jan 8, 2026

Auto-sync is disabled for draft pull requests in this repository. Workflows must be run manually.

Contributors can view more details about this message here.

@h-guo18 h-guo18 self-assigned this Jan 8, 2026

copy-pr-bot bot commented Jan 9, 2026

This pull request requires additional validation before any workflows can run on NVIDIA's runners.

Pull request vetters can view their responsibilities here.

Contributors can view more details about this message here.

@h-guo18 h-guo18 marked this pull request as ready for review January 9, 2026 23:42
@h-guo18 h-guo18 requested a review from a team as a code owner January 9, 2026 23:42
Signed-off-by: h-guo18 <[email protected]>
Contributor

coderabbitai bot commented Jan 21, 2026

Important

Review skipped

Auto incremental reviews are disabled on this repository.

Please check the settings in the CodeRabbit UI or the .coderabbit.yaml file in this repository. To trigger a single review, invoke the @coderabbitai review command.

You can disable this status message by setting the reviews.review_status to false in the CodeRabbit configuration file.

📝 Walkthrough

Walkthrough

Introduces TTT (training-time test) attention masking support for context parallelism in speculative decoding. Adds attention masking computation and ring attention patching utilities. Updates training pipeline to conditionally apply CP-TTT patches when context parallelism is enabled. Refactors cache initialization and loss masking alignment.

Changes

Cohort / File(s) | Summary

TTT Masking Infrastructure
examples/speculative_decoding/eagle_utils.py, modelopt/torch/speculative/utils.py
  Introduces TTT attention mask generation (_compute_ttt_attention_mask) and ring attention patching (get_patched_templated_ring_attn, patch_ring_attention_for_ttt). Adds utility functions get_ttt_msk_func and context manager enable_cp_ttt_patch for conditional TTT mask application with the CUDNN backend.

Training Configuration & Arguments
examples/speculative_decoding/launch_train.sh, examples/speculative_decoding/main.py, examples/speculative_decoding/fsdp_config.json
  Adds CP and DP shard size arguments to the training pipeline. New TrainingArguments fields for cp_size and dp_shard_size. Conditional initialization of ParallelismConfig and a patching trigger when cp_size > 1. New FSDP v2 configuration file.

Speculative Decoding Model Logic
modelopt/torch/speculative/plugins/transformers.py
  Refactors TTT mask construction to use the external get_ttt_msk_func utility. Updates cache initialization from the legacy DynamicCache.from_legacy_cache to a constructor-based approach (illustrated after this table). Adds loss masking alignment to match the eagle logits shape. Introduces ENABLE_CP_TTT_PATCH and CACHED_SHARD_TTT_MASKS module constants.

Dependencies
examples/speculative_decoding/requirements.txt
  Replaces flash-attn, openai, py7zr, sentencepiece, and tensorboardX with pinned versions of accelerate (1.12.0), torch (2.8.0), transformers (5.0.0rc1), and wandb.
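
Illustration of the cache-initialization change mentioned for transformers.py (simplified; the real call sites and arguments in the plugin may differ):

from transformers import DynamicCache

# Old path: build the cache by converting a legacy tuple-of-tuples format.
# eagle_cache = DynamicCache.from_legacy_cache(past_key_values)

# New path: construct the cache directly and let forward passes populate it.
eagle_cache = DynamicCache()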

Sequence Diagram(s)

sequenceDiagram
    participant Training as Training<br/>Pipeline
    participant CP as CP Config<br/>(cp_size > 1)
    participant Patch as Ring Attention<br/>Patcher
    participant RingAttn as Ring Attention<br/>Module
    participant TTT as TTT Mask<br/>Generator
    participant Loss as Loss<br/>Computation

    Training->>CP: check cp_size > 1
    activate CP
    alt cp_size > 1
        CP->>Patch: call patch_ring_attention_for_ttt()
        activate Patch
        Patch->>Patch: extract rank, size, query, key, ttt_step from frame
        Patch->>RingAttn: replace _templated_ring_attention with patched version
        Patch->>RingAttn: patch _SDPAMerger.step to skip benign shards
        Patch->>RingAttn: disable CP load balancing
        deactivate Patch
        
        Training->>RingAttn: forward pass (attention computation)
        activate RingAttn
        RingAttn->>TTT: inject TTT mask into attention bias
        activate TTT
        TTT->>TTT: compute composite mask with kv_idx and ttt_step
        TTT-->>RingAttn: return ttt_attention_mask
        deactivate TTT
        RingAttn-->>Training: attention output with TTT masking
        deactivate RingAttn
    else cp_size == 1
        CP->>Training: skip patching
    end
    deactivate CP

    Training->>Loss: forward pass (Eagle model)
    activate Loss
    Loss->>Loss: align loss_mask to eagle_logits shape
    Loss->>Loss: compute softmax cross-entropy with aligned masks
    Loss-->>Training: loss value
    deactivate Loss

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~50 minutes

🚥 Pre-merge checks | ✅ 3
✅ Passed checks (3 passed)
Check name | Status | Explanation
Description Check | ✅ Passed | Check skipped - CodeRabbit’s high-level summary is enabled.
Title check | ✅ Passed | The title 'Feat: Context Parallel for Eagle3 Training' directly and clearly summarizes the main change: adding context parallelism support for Eagle3 training, which is the primary objective of this PR.
Docstring Coverage | ✅ Passed | No functions found in the changed files to evaluate docstring coverage. Skipping docstring coverage check.




Contributor

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 7

🤖 Fix all issues with AI agents
In `@examples/speculative_decoding/eagle_utils.py`:
- Around line 661-662: The error message string in the conditional that checks
patch_enbabled and original_op against
torch.ops.aten._scaled_dot_product_cudnn_attention contains a typo ("cuddn");
update the ValueError text to use the correct spelling "cudnn" so the raised
message reads something like "CP TTT only supports cudnn attention now. Got:
{original_op}" while keeping the same condition and variables (patch_enbabled,
original_op, torch.ops.aten._scaled_dot_product_cudnn_attention).
- Around line 668-688: The patched_op wrapper uses
inspect.currentframe()/frame.f_back and populates rank,size,query,key,i,ttt_step
inside a try/except but then continues even if inspection failed, which can lead
to NameError; update patched_op to validate that inspect.currentframe() and
frame.f_back are not None and that f_back.f_locals contains the expected keys
before using them (inspect.currentframe, frame.f_back, f_back.f_locals, keys
"rank","size","query","key","i"); if any check fails, either re-raise the caught
exception or return/raise a clear error early so the function does not proceed
to call _get_sharded_ttt_msk or original_op with undefined variables; ensure
ttt_step is computed only after query/key are present and preserve the
original_op call path when inspection succeeds.
- Around line 700-701: Fix the typo in the inline comment above the config
assignment: change "permenantly" to "permanently" in the comment that precedes
the
torch.distributed.tensor.experimental._attention._cp_options.enable_load_balance
= False line so the comment reads "So need to be done permanently before
accelerate/hf trainer init."
- Around line 580-598: Delete the dead _compute_ttt_attention_mask function in
eagle_utils.py: remove the entire function definition (def
_compute_ttt_attention_mask(batch_size, seq_length, ttt_step, dtype) ->
torch.Tensor:) because it is unused and its docstring is misleading; ensure
there are no remaining references to this symbol and rely on the existing
implementation in transformers.py (the plugin’s _compute_ttt_attention_mask)
which handles flex_attention correctly.
- Around line 720-725: The function patched_sdpa_merger_step has an incorrect
return annotation: change its signature from "-> torch.Tensor" to "-> None"
because it performs in-place mutations and returns None (match original
_SDPAMerger.step); update the annotation on patched_sdpa_merger_step and ensure
the call to original_sdpa_merger_step is used only for its side effects. Also
optionally review the lse.sum() <= 0 check in patched_sdpa_merger_step to
confirm it correctly identifies blank shards (log-sum-exp can be negative) and
adjust the condition if needed.

In `@examples/speculative_decoding/requirements.txt`:
- Around line 1-4: The requirements file currently leaves the dependency "wandb"
unpinned; update the requirements.txt entry "wandb" to a specific, tested
version (e.g., replace the bare token "wandb" with a pinned version like
"wandb==<chosen_version>") to ensure reproducible installs; choose and document
a stable release you validated against the existing pinned packages and commit
the updated line.

In `@modelopt/torch/speculative/plugins/transformers.py`:
- Around line 65-66: CACHED_SHARD_TTT_MASKS is defined but unused; either delete
the constant or add a clarifying comment and intended usage so it isn't flagged
as dead code—update the module-level definition near ENABLE_CP_TTT_PATCH by
removing the CACHED_SHARD_TTT_MASKS = {} line if it's not needed, or replace it
with a short doc comment (e.g., explaining expected keys/values and when it
should be populated) and keep the name as-is so future readers know its purpose;
ensure no other code references are required by also grepping for
CACHED_SHARD_TTT_MASKS before removal.
🧹 Nitpick comments (6)
modelopt/torch/speculative/utils.py (1)

462-470: Consider exception safety in the context manager.

The context manager correctly restores the flag in the finally block. However, the sdpa_kernel context manager wraps the try/yield/finally, which means if sdpa_kernel raises during entry, the flag might already be set to True but won't be reset.

Consider this safer ordering:

♻️ Suggested improvement
 `@contextlib.contextmanager`
 def enable_cp_ttt_patch():
     """Context manager to enable CP TTT patch."""
-    modelopt.torch.speculative.plugins.transformers.ENABLE_CP_TTT_PATCH = True
-    with sdpa_kernel(SDPBackend.CUDNN_ATTENTION):
-        try:
-            yield
-        finally:
-            modelopt.torch.speculative.plugins.transformers.ENABLE_CP_TTT_PATCH = False
+    try:
+        modelopt.torch.speculative.plugins.transformers.ENABLE_CP_TTT_PATCH = True
+        with sdpa_kernel(SDPBackend.CUDNN_ATTENTION):
+            yield
+    finally:
+        modelopt.torch.speculative.plugins.transformers.ENABLE_CP_TTT_PATCH = False
modelopt/torch/speculative/plugins/transformers.py (1)

905-918: Consider making enable_cp_ttt_patch() conditional on CP being active.

The enable_cp_ttt_patch() context manager is applied unconditionally during training, but the TODO on line 905 suggests the mask isn't used during CP training. This forces CUDNN_ATTENTION backend even when cp_size=1, which may have unintended performance characteristics.

Consider wrapping conditionally:

♻️ Suggested approach
+            ctx = enable_cp_ttt_patch() if ENABLE_CP_TTT_PATCH else contextlib.nullcontext()
-            with enable_cp_ttt_patch():
+            with ctx:
                 _, eagle_input_hidden_states, eagle_logits, eagle_cache = self._eagle_forward(

This would require importing contextlib and checking ENABLE_CP_TTT_PATCH status.

examples/speculative_decoding/launch_train.sh (2)

97-104: Inconsistent argument naming: --dp_size vs dp_shard_size.

The CLI argument is --dp_size but it sets DP_SHARD_SIZE and the Python code uses dp_shard_size. Consider aligning the CLI argument name for consistency:

♻️ Suggested fix
-    --dp_size*)
+    --dp_shard_size*)
       if [[ "$1" != *=* ]]; then shift; fi
       DP_SHARD_SIZE="${1#*=}"
       ;;

139-140: Add validation for CP_SIZE and GPU_COUNT relationship.

The calculation DP_SHARD_SIZE=$((GPU_COUNT/CP_SIZE)) could result in 0 if CP_SIZE > GPU_COUNT, or leave resources unused if GPU_COUNT is not evenly divisible by CP_SIZE.

Consider adding validation:

♻️ Suggested validation
 CP_SIZE=${CP_SIZE:-1}
+if [[ $CP_SIZE -gt $GPU_COUNT ]]; then
+  echo "Error: cp_size ($CP_SIZE) cannot exceed GPU count ($GPU_COUNT)"
+  exit 1
+fi
+if [[ $((GPU_COUNT % CP_SIZE)) -ne 0 ]]; then
+  echo "Warning: GPU count ($GPU_COUNT) is not evenly divisible by cp_size ($CP_SIZE)"
+fi
 DP_SHARD_SIZE=${DP_SHARD_SIZE:-$((GPU_COUNT/CP_SIZE))}
examples/speculative_decoding/main.py (1)

141-147: ParallelismConfig API usage is correct; existing comment sufficiently tracks the workaround.

The ParallelismConfig correctly accepts both cp_size and dp_shard_size parameters (verified against accelerate documentation). The sp_backend = None workaround for accelerate 1.12.0 is appropriate and already has a clear comment indicating removal after upgrade to 1.13.0.

Optionally, consider adding a linked TODO (e.g., with a GitHub issue or PR reference) to the existing comment to streamline tracking of the deprecation.
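
A hypothetical sketch of the conditional setup being discussed; only the cp_size/dp_shard_size keywords and the cp_size > 1 condition come from this thread, while the import path, helper name, and the sp_backend attribute location are assumptions:

from accelerate.parallelism_config import ParallelismConfig  # import path assumed


def build_parallelism_config(cp_size: int, dp_shard_size: int):
    """Hypothetical helper mirroring the conditional CP/DP setup in main.py."""
    if cp_size <= 1 and dp_shard_size <= 1:
        return None
    pconfig = ParallelismConfig(cp_size=cp_size, dp_shard_size=dp_shard_size)
    # Workaround for accelerate 1.12.0 noted above; drop after upgrading to 1.13.0.
    pconfig.sp_backend = None
    return pconfig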

examples/speculative_decoding/eagle_utils.py (1)

652-692: Frame inspection is fragile and tightly coupled to torch internals.

The approach of using inspect.currentframe().f_back.f_locals to extract variables from PyTorch's internal _templated_ring_attention implementation is inherently fragile. Any change to variable names or control flow in PyTorch's implementation will silently break this code.

Consider:

  1. Adding a comment documenting which PyTorch version this was tested against (torch 2.8.0 per PR).
  2. Adding a runtime assertion or version check to fail fast if the expected variables aren't found.
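
One way to implement the suggested fail-fast behavior (a sketch; the expected local names follow the review comments above, and the pinned version string and frame depth depend on where the helper is called from):

import inspect

import torch

_REQUIRED_LOCALS = ("rank", "size", "query", "key", "i")


def ring_loop_locals_or_fail() -> dict:
    """Fail fast if torch's ring-attention internals no longer look as expected."""
    if not torch.__version__.startswith("2.8."):
        raise RuntimeError(
            f"CP TTT patch was validated against torch 2.8.x only; got {torch.__version__}."
        )
    frame = inspect.currentframe()
    # Two hops up: skip this helper and the patched op to reach the ring loop frame.
    caller = frame.f_back.f_back if frame and frame.f_back else None
    if caller is None or any(k not in caller.f_locals for k in _REQUIRED_LOCALS):
        raise RuntimeError(
            f"torch's _templated_ring_attention internals changed; expected locals {_REQUIRED_LOCALS} not found."
        )
    return {k: caller.f_locals[k] for k in _REQUIRED_LOCALS}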


codecov bot commented Jan 21, 2026

Codecov Report

❌ Patch coverage is 29.41176% with 12 lines in your changes missing coverage. Please review.
✅ Project coverage is 74.18%. Comparing base (21a4010) to head (a9d58a9).
⚠️ Report is 1 commit behind head on main.

Files with missing lines Patch % Lines
modelopt/torch/speculative/utils.py 29.41% 12 Missing ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##             main     #745      +/-   ##
==========================================
- Coverage   74.19%   74.18%   -0.01%     
==========================================
  Files         192      192              
  Lines       19238    19255      +17     
==========================================
+ Hits        14273    14284      +11     
- Misses       4965     4971       +6     

☔ View full report in Codecov by Sentry.

Signed-off-by: h-guo18 <[email protected]>
