Feat: Context Parallel for Eagle3 Training #745
base: main
Conversation
Auto-sync is disabled for draft pull requests in this repository. Workflows must be run manually.
Signed-off-by: h-guo18 <[email protected]>
Important: Review skipped. Auto incremental reviews are disabled on this repository; check the settings in the CodeRabbit UI.

📝 Walkthrough

Introduces TTT (training-time test) attention masking support for context parallelism in speculative decoding. Adds attention masking computation and ring attention patching utilities. Updates the training pipeline to conditionally apply CP-TTT patches when context parallelism is enabled. Refactors cache initialization and loss masking alignment.

Changes
Sequence Diagram(s)

```mermaid
sequenceDiagram
participant Training as Training<br/>Pipeline
participant CP as CP Config<br/>(cp_size > 1)
participant Patch as Ring Attention<br/>Patcher
participant RingAttn as Ring Attention<br/>Module
participant TTT as TTT Mask<br/>Generator
participant Loss as Loss<br/>Computation
Training->>CP: check cp_size > 1
activate CP
alt cp_size > 1
CP->>Patch: call patch_ring_attention_for_ttt()
activate Patch
Patch->>Patch: extract rank, size, query, key, ttt_step from frame
Patch->>RingAttn: replace _templated_ring_attention with patched version
Patch->>RingAttn: patch _SDPAMerger.step to skip benign shards
Patch->>RingAttn: disable CP load balancing
deactivate Patch
Training->>RingAttn: forward pass (attention computation)
activate RingAttn
RingAttn->>TTT: inject TTT mask into attention bias
activate TTT
TTT->>TTT: compute composite mask with kv_idx and ttt_step
TTT-->>RingAttn: return ttt_attention_mask
deactivate TTT
RingAttn-->>Training: attention output with TTT masking
deactivate RingAttn
else cp_size == 1
CP->>Training: skip patching
end
deactivate CP
Training->>Loss: forward pass (Eagle model)
activate Loss
Loss->>Loss: align loss_mask to eagle_logits shape
Loss->>Loss: compute softmax cross-entropy with aligned masks
Loss-->>Training: loss value
deactivate Loss
```

Estimated code review effort
🎯 4 (Complex) | ⏱️ ~50 minutes

🚥 Pre-merge checks
✅ Passed checks (3 passed)
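To make the branching in the diagram above concrete, here is a minimal, self-contained sketch of the control flow only. The two patch utilities are stubbed with prints (the real `patch_ring_attention_for_ttt` and `enable_cp_ttt_patch` come from this PR), and `training_step` and its body are hypothetical placeholders for the actual training loop.

```python
import contextlib

# Stand-ins for the utilities referenced in this PR (patch_ring_attention_for_ttt,
# enable_cp_ttt_patch). They are stubbed here so the control flow can run in
# isolation; the real implementations live in eagle_utils.py / transformers.py.
def patch_ring_attention_for_ttt():
    print("ring attention patched for TTT")

@contextlib.contextmanager
def enable_cp_ttt_patch():
    print("CP TTT patch enabled")
    try:
        yield
    finally:
        print("CP TTT patch disabled")

def training_step(cp_size: int) -> None:
    """Apply the CP-TTT patch only when context parallelism is active."""
    if cp_size > 1:
        patch_ring_attention_for_ttt()
        ctx = enable_cp_ttt_patch()
    else:
        ctx = contextlib.nullcontext()  # cp_size == 1: skip patching entirely
    with ctx:
        pass  # the Eagle forward/backward pass would run here

if __name__ == "__main__":
    training_step(cp_size=2)
    training_step(cp_size=1)
```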
Actionable comments posted: 7
🤖 Fix all issues with AI agents
In `@examples/speculative_decoding/eagle_utils.py`:
- Around line 661-662: The error message string in the conditional that checks
patch_enbabled and original_op against
torch.ops.aten._scaled_dot_product_cudnn_attention contains a typo ("cuddn");
update the ValueError text to use the correct spelling "cudnn" so the raised
message reads something like "CP TTT only supports cudnn attention now. Got:
{original_op}" while keeping the same condition and variables (patch_enbabled,
original_op, torch.ops.aten._scaled_dot_product_cudnn_attention).
- Around line 668-688: The patched_op wrapper uses
inspect.currentframe()/frame.f_back and populates rank,size,query,key,i,ttt_step
inside a try/except but then continues even if inspection failed, which can lead
to NameError; update patched_op to validate that inspect.currentframe() and
frame.f_back are not None and that f_back.f_locals contains the expected keys
before using them (inspect.currentframe, frame.f_back, f_back.f_locals, keys
"rank","size","query","key","i"); if any check fails, either re-raise the caught
exception or return/raise a clear error early so the function does not proceed
to call _get_sharded_ttt_msk or original_op with undefined variables; ensure
ttt_step is computed only after query/key are present and preserve the
original_op call path when inspection succeeds.
- Around line 700-701: Fix the typo in the inline comment above the config
assignment: change "permenantly" to "permanently" in the comment that precedes
the
torch.distributed.tensor.experimental._attention._cp_options.enable_load_balance
= False line so the comment reads "So need to be done permanently before
accelerate/hf trainer init."
- Around line 580-598: Delete the dead _compute_ttt_attention_mask function in
eagle_utils.py: remove the entire function definition (def
_compute_ttt_attention_mask(batch_size, seq_length, ttt_step, dtype) ->
torch.Tensor:) because it is unused and its docstring is misleading; ensure
there are no remaining references to this symbol and rely on the existing
implementation in transformers.py (the plugin’s _compute_ttt_attention_mask)
which handles flex_attention correctly.
- Around line 720-725: The function patched_sdpa_merger_step has an incorrect
return annotation: change its signature from "-> torch.Tensor" to "-> None"
because it performs in-place mutations and returns None (match original
_SDPAMerger.step); update the annotation on patched_sdpa_merger_step and ensure
the call to original_sdpa_merger_step is used only for its side effects. Also
optionally review the lse.sum() <= 0 check in patched_sdpa_merger_step to
confirm it correctly identifies blank shards (log-sum-exp can be negative) and
adjust the condition if needed.
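A rough sketch of the defensive frame inspection requested in the second item above. The local-variable names (rank, size, query, key, i) are taken from that comment, while the helper name and the two-frames-up lookup are assumptions about how it would be wired into the PR's patched_op.

```python
import inspect

REQUIRED_LOCALS = ("rank", "size", "query", "key", "i")

def _ring_attention_locals() -> dict:
    """Fetch and validate the ring-attention locals the CP TTT patch depends on.

    Raises early instead of letting a NameError surface later if torch's
    _templated_ring_attention no longer exposes the expected variables.
    (Assumed call pattern: patched_op() calls this helper, so the frame of
    interest is two levels up.)
    """
    frame = inspect.currentframe()
    caller = frame.f_back.f_back if frame is not None and frame.f_back is not None else None
    if caller is None:
        raise RuntimeError("CP TTT patch: could not inspect the calling frame.")
    missing = [name for name in REQUIRED_LOCALS if name not in caller.f_locals]
    if missing:
        raise RuntimeError(
            f"CP TTT patch: ring-attention locals {missing} not found; "
            "this torch version may be incompatible with the patch."
        )
    return {name: caller.f_locals[name] for name in REQUIRED_LOCALS}
```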
In `@examples/speculative_decoding/requirements.txt`:
- Around line 1-4: The requirements file currently leaves the dependency "wandb"
unpinned; update the requirements.txt entry "wandb" to a specific, tested
version (e.g., replace the bare token "wandb" with a pinned version like
"wandb==<chosen_version>") to ensure reproducible installs; choose and document
a stable release you validated against the existing pinned packages and commit
the updated line.
In `@modelopt/torch/speculative/plugins/transformers.py`:
- Around line 65-66: CACHED_SHARD_TTT_MASKS is defined but unused; either delete
the constant or add a clarifying comment and intended usage so it isn't flagged
as dead code—update the module-level definition near ENABLE_CP_TTT_PATCH by
removing the CACHED_SHARD_TTT_MASKS = {} line if it's not needed, or replace it
with a short doc comment (e.g., explaining expected keys/values and when it
should be populated) and keep the name as-is so future readers know its purpose;
ensure no other code references are required by also grepping for
CACHED_SHARD_TTT_MASKS before removal.
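If the constant were kept rather than deleted, the clarifying comment could look roughly like the following; the key layout is a guess for illustration only, not the PR's actual design.

```python
# Hypothetical doc comment for the otherwise-unused constant flagged above.
# Cache of per-rank sharded TTT attention masks, e.g. keyed by
# (rank, ttt_step, q_len, kv_len), populated lazily by the patched ring-attention
# op so each shard's mask is built only once per shape. Currently unpopulated.
CACHED_SHARD_TTT_MASKS: dict = {}
```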
🧹 Nitpick comments (6)
modelopt/torch/speculative/utils.py (1)
462-470: Consider exception safety in the context manager. The context manager correctly restores the flag in the `finally` block. However, the `sdpa_kernel` context manager wraps the `try/yield/finally`, which means if `sdpa_kernel` raises during entry, the flag might already be set to `True` but won't be reset. Consider this safer ordering:
♻️ Suggested improvement
```diff
 @contextlib.contextmanager
 def enable_cp_ttt_patch():
     """Context manager to enable CP TTT patch."""
-    modelopt.torch.speculative.plugins.transformers.ENABLE_CP_TTT_PATCH = True
-    with sdpa_kernel(SDPBackend.CUDNN_ATTENTION):
-        try:
-            yield
-        finally:
-            modelopt.torch.speculative.plugins.transformers.ENABLE_CP_TTT_PATCH = False
+    try:
+        modelopt.torch.speculative.plugins.transformers.ENABLE_CP_TTT_PATCH = True
+        with sdpa_kernel(SDPBackend.CUDNN_ATTENTION):
+            yield
+    finally:
+        modelopt.torch.speculative.plugins.transformers.ENABLE_CP_TTT_PATCH = False
```

modelopt/torch/speculative/plugins/transformers.py (1)
905-918: Consider making `enable_cp_ttt_patch()` conditional on CP being active. The `enable_cp_ttt_patch()` context manager is applied unconditionally during training, but the TODO on line 905 suggests the mask isn't used during CP training. This forces the CUDNN_ATTENTION backend even when `cp_size=1`, which may have unintended performance characteristics. Consider wrapping conditionally:
♻️ Suggested approach
```diff
+        ctx = enable_cp_ttt_patch() if ENABLE_CP_TTT_PATCH else contextlib.nullcontext()
-        with enable_cp_ttt_patch():
+        with ctx:
             _, eagle_input_hidden_states, eagle_logits, eagle_cache = self._eagle_forward(
```

This would require importing `contextlib` and checking `ENABLE_CP_TTT_PATCH` status.

examples/speculative_decoding/launch_train.sh (2)
97-104: Inconsistent argument naming: `--dp_size` vs `dp_shard_size`. The CLI argument is `--dp_size` but it sets `DP_SHARD_SIZE`, and the Python code uses `dp_shard_size`. Consider aligning the CLI argument name for consistency:

♻️ Suggested fix
```diff
-    --dp_size*)
+    --dp_shard_size*)
         if [[ "$1" != *=* ]]; then shift; fi
         DP_SHARD_SIZE="${1#*=}"
         ;;
```
139-140: Add validation for the CP_SIZE and GPU_COUNT relationship. The calculation `DP_SHARD_SIZE=$((GPU_COUNT/CP_SIZE))` could result in 0 if `CP_SIZE > GPU_COUNT`, or leave resources unused if `GPU_COUNT` is not evenly divisible by `CP_SIZE`. Consider adding validation:
♻️ Suggested validation
```diff
 CP_SIZE=${CP_SIZE:-1}
+if [[ $CP_SIZE -gt $GPU_COUNT ]]; then
+    echo "Error: cp_size ($CP_SIZE) cannot exceed GPU count ($GPU_COUNT)"
+    exit 1
+fi
+if [[ $((GPU_COUNT % CP_SIZE)) -ne 0 ]]; then
+    echo "Warning: GPU count ($GPU_COUNT) is not evenly divisible by cp_size ($CP_SIZE)"
+fi
 DP_SHARD_SIZE=${DP_SHARD_SIZE:-$((GPU_COUNT/CP_SIZE))}
```
141-147: ParallelismConfig API usage is correct; the existing comment sufficiently tracks the workaround. `ParallelismConfig` correctly accepts both `cp_size` and `dp_shard_size` parameters (verified against accelerate documentation). The `sp_backend = None` workaround for accelerate 1.12.0 is appropriate and already has a clear comment indicating removal after the upgrade to 1.13.0. Optionally, consider adding a linked TODO (e.g., with a GitHub issue or PR reference) to the existing comment to streamline tracking of the deprecation.
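For context, a minimal sketch of how the two degrees discussed here could be passed to accelerate's ParallelismConfig. The constructor arguments follow what the review verified against the accelerate documentation, but the import path, the concrete values, and the placement of the sp_backend workaround are assumptions.

```python
from accelerate import ParallelismConfig  # assumed import path

# Placeholder values; in main.py these come from the CLI / launch_train.sh.
cp_size, dp_shard_size = 2, 4

parallelism_config = ParallelismConfig(
    cp_size=cp_size,              # context-parallel degree
    dp_shard_size=dp_shard_size,  # FSDP shard degree (GPU_COUNT / cp_size by default)
)

# Workaround the review mentions for accelerate 1.12.0; slated for removal after 1.13.0.
parallelism_config.sp_backend = None
```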
examples/speculative_decoding/eagle_utils.py (1)
652-692: Frame inspection is fragile and tightly coupled to torch internals. The approach of using `inspect.currentframe().f_back.f_locals` to extract variables from PyTorch's internal `_templated_ring_attention` implementation is inherently fragile. Any change to variable names or control flow in PyTorch's implementation will silently break this code. Consider:
- Adding a comment documenting which PyTorch version this was tested against (torch 2.8.0 per PR).
- Adding a runtime assertion or version check to fail fast if the expected variables aren't found.
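One possible shape for the second bullet's fail-fast check, assuming torch 2.8.0 is the validated version as the first bullet notes; the accepted-version policy and the error wording are placeholders.

```python
import torch
from packaging import version  # available alongside torch installs

_TESTED_TORCH = "2.8.0"  # version the ring-attention patch was developed against (per the PR)

def check_torch_version_for_cp_ttt() -> None:
    """Fail fast if torch has moved past the minor release whose internal
    _templated_ring_attention locals this patch relies on."""
    current = version.parse(torch.__version__.split("+")[0])
    tested = version.parse(_TESTED_TORCH)
    if (current.major, current.minor) != (tested.major, tested.minor):
        raise RuntimeError(
            f"CP TTT ring-attention patch was only validated on torch {_TESTED_TORCH}; "
            f"found {torch.__version__}. Review the patch before using it."
        )
```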
Signed-off-by: h-guo18 <[email protected]>
Signed-off-by: h-guo18 <[email protected]>
Signed-off-by: h-guo18 <[email protected]>
Codecov Report

❌ Patch coverage is
Additional details and impacted files

@@ Coverage Diff @@
## main #745 +/- ##
==========================================
- Coverage 74.19% 74.18% -0.01%
==========================================
Files 192 192
Lines 19238 19255 +17
==========================================
+ Hits 14273 14284 +11
- Misses 4965 4971 +6

☔ View full report in Codecov by Sentry.
Signed-off-by: h-guo18 <[email protected]>
What does this PR do?
Type of change: New Feature
Overview:
--multi_gpu, fsdp_wrap_layer)

Usage
Testing
(Llama3.1-8B, Unsynthesized magpie)
Peak Mem Reserved
(llama3.1-8B, 8xH100, train_length=4k)
Max Training Length test
(llama3.1-8B, H100)
Before your PR is "Ready for review"
Additional Information
Summary by CodeRabbit
New Features
Enhancements
Chores