Refactor soap #131

Merged
skyw merged 12 commits into main from skyw/refactor_soap
Mar 17, 2026
@skyw
Contributor

skyw commented Mar 17, 2026

Remove 1D support and clean up the rest of the code.

Store plain tensors in optimizer state instead of lists for easier checkpoint (state_dict) handling.

No functional change. Removed torch.compile over tensordot, which would not be able to do anything anyway.

skyw added 3 commits March 16, 2026 17:07
Signed-off-by: Hao Wu <skyw@nvidia.com>
Signed-off-by: Hao Wu <skyw@nvidia.com>
Signed-off-by: Hao Wu <skyw@nvidia.com>
skyw requested a review from mkhona-nvidia March 17, 2026 00:36
@copy-pr-bot

copy-pr-bot bot commented Mar 17, 2026

This pull request requires additional validation before any workflows can run on NVIDIA's runners.

Pull request vetters can view their responsibilities here.

Contributors can view more details about this message here.

skyw added 6 commits March 16, 2026 18:44
Signed-off-by: Hao Wu <skyw@nvidia.com>
Signed-off-by: Hao Wu <skyw@nvidia.com>
Signed-off-by: Hao Wu <skyw@nvidia.com>
Signed-off-by: Hao Wu <skyw@nvidia.com>
Signed-off-by: Hao Wu <skyw@nvidia.com>
Signed-off-by: Hao Wu <skyw@nvidia.com>
@skyw
Contributor Author

skyw commented Mar 17, 2026

@greptile

@greptile-apps

greptile-apps bot commented Mar 17, 2026

Greptile Summary

This PR refactors the SOAP optimizer by removing 1D tensor support and replacing list-based state storage (state["GG"], state["Q"]) with individually named tensors (state["L"], state["R"], state["Q_L"], state["Q_R"]). This simplifies checkpoint handling (state_dict) and removes all the empty-tensor guard branches throughout the codebase. A new skip_non_grad_params parameter is added to _init_group to support Megatron-LM distributed checkpointing, where state must be initialized before gradients are available.
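The state layout change described above can be sketched as follows. This is a minimal illustration, not the PR's code: the key names come from the summary, while the 4x3 shape and the helper `all_values_are_tensors` are assumptions made for the example.

```python
import torch

# Hypothetical per-parameter state for a 4x3 weight. Key names follow the
# summary above; the shapes and helper below are assumptions for illustration.
rows, cols = 4, 3

old_state = {
    "GG": [torch.zeros(rows, rows), torch.zeros(cols, cols)],
    "Q": [torch.eye(rows), torch.eye(cols)],
}
new_state = {
    "L": torch.zeros(rows, rows),
    "R": torch.zeros(cols, cols),
    "Q_L": torch.eye(rows),
    "Q_R": torch.eye(cols),
}


def all_values_are_tensors(state):
    """True when every entry is a plain tensor, so generic per-tensor
    checkpoint code (clone, move, shard) needs no special cases."""
    return all(isinstance(v, torch.Tensor) for v in state.values())


old_flat = all_values_are_tensors(old_state)
new_flat = all_values_are_tensors(new_state)

# With named tensors, one generic comprehension can copy the whole state.
cpu_copy = {k: v.detach().clone().cpu() for k, v in new_state.items()}
```

With list-valued entries, the same comprehension would fail on `.detach()`, which is the kind of special-casing the named-tensor layout avoids.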

Key changes:

  • init_kronecker_factors now accepts a torch.Size instead of a tensor and returns a (L, R) tuple instead of a list
  • update_kronecker_factors simplified to two explicit matmuls (G @ G.T, G.T @ G) instead of a generic tensordot loop
  • @torch.compile was removed from the precondition function, which is called multiple times per step; if the removal is intentional, a comment should state why
  • Tests updated to reflect renamed state keys and the 2D-only contract; no tests cover the new TypeError paths for invalid input dimensions
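The first two changes in the list above might look roughly like the sketch below. The function names carry a `_sketch` suffix to mark them as stand-ins, and the signatures, the `beta` parameter name, and the use of `lerp_` are assumptions; only the shape-in, tuple-out contract and the two explicit matmuls come from the summary.

```python
import torch


def init_kronecker_factors_sketch(shape: torch.Size):
    """Return zero-initialized (L, R) for a 2D gradient of the given shape."""
    if len(shape) != 2:
        raise TypeError(f"expected a 2D shape, got {len(shape)}D")
    rows, cols = shape
    return torch.zeros(rows, rows), torch.zeros(cols, cols)


@torch.no_grad()
def update_kronecker_factors_sketch(L, R, grad, beta):
    """In-place moving average:
    L <- beta * L + (1 - beta) * G @ G.T
    R <- beta * R + (1 - beta) * G.T @ G
    """
    # lerp_(end, w) computes self + w * (end - self), i.e. (1 - w) * self + w * end.
    L.lerp_(grad @ grad.T, 1 - beta)
    R.lerp_(grad.T @ grad, 1 - beta)


grad = torch.tensor([[1.0, 2.0, 0.0], [0.0, 1.0, 1.0]])
L, R = init_kronecker_factors_sketch(grad.shape)
update_kronecker_factors_sketch(L, R, grad, beta=0.5)
```

Passing a `torch.Size` rather than a tensor means the factors can be allocated before any gradient exists, which fits the checkpointing motivation given in the summary.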

Confidence Score: 4/5

  • Safe to merge — the refactoring is logically consistent and well-tested for the 2D case; the main concern is an unexplained performance change.
  • The state-key renaming and 1D removal are clean and the reference-implementation comparison tests provide strong correctness coverage. The only notable concern is the silent removal of @torch.compile from the hot-path precondition function and the absence of tests for the new TypeError error paths.
  • emerging_optimizers/soap/soap.py around the precondition function (missing @torch.compile); tests/test_soap.py for the missing TypeError test coverage.
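A test for the uncovered TypeError paths could take roughly the following shape. This sketch does not call the real optimizer, whose constructor and check location are not shown in this conversation; `require_2d` is a hypothetical stand-in for the dimension check.

```python
import unittest

import torch


def require_2d(param: torch.Tensor) -> None:
    """Stand-in for the optimizer's dimension check (hypothetical helper)."""
    if param.dim() != 2:
        raise TypeError(f"2D parameters only, got {param.dim()}D")


class NonTwoDimErrorTest(unittest.TestCase):
    def test_rejects_1d_and_3d(self):
        for shape in [(8,), (2, 3, 4)]:
            with self.subTest(shape=shape):
                with self.assertRaises(TypeError):
                    require_2d(torch.zeros(shape))

    def test_accepts_2d(self):
        require_2d(torch.zeros(4, 3))  # must not raise


suite = unittest.defaultTestLoader.loadTestsFromTestCase(NonTwoDimErrorTest)
result = unittest.TextTestRunner(verbosity=0).run(suite)
```

In the real test file, the body of `require_2d` would be replaced by whatever call triggers the optimizer's own check.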

Important Files Changed

| Filename | Overview |
| --- | --- |
| emerging_optimizers/soap/soap.py | Core optimizer refactored to store named flat tensors (L, R, Q_L, Q_R) instead of lists (GG, Q), removed 1D support, added skip_non_grad_params for Megatron-LM checkpointing, and silently dropped @torch.compile from the hot-path precondition function. |
| emerging_optimizers/soap/soap_utils.py | Cleaned up by removing all empty-tensor (numel()==0) guard branches that were only needed for the now-deleted 1D preconditioning path. Logic is unchanged for 2D tensors. |
| tests/test_soap.py | Tests updated to match renamed state keys (L/R/Q_L/Q_R) and the new 2D-only API. The new error-path behavior (TypeError for non-2D tensors) is not explicitly tested. |
| tests/test_soap_utils.py | Removed tests for empty/zero-dimension Kronecker factor edge cases that are no longer reachable with the 2D-only constraint. Remaining tests are unaffected. |

Flowchart

%%{init: {'theme': 'neutral'}}%%
flowchart TD
    A[optimizer.step called] --> B[_init_group\nskip_non_grad_params=True]
    B --> C{p.dim != 2?}
    C -- TypeError --> D[raise TypeError]
    C -- ok --> E[Init state\nL, R = zeros\nQ_L, Q_R = eye\nexp_avg, exp_avg_sq = zeros]
    E --> F[For each param with grad]
    F --> G[Build local refs\nkronecker_factor_list = state L, R\neigenbasis_list = state Q_L, Q_R]
    G --> H[update_kronecker_factors\nL = beta*L + 1-beta * G@G.T\nR = beta*R + 1-beta * G.T@G]
    H --> I{Is eigenbasis\nupdate step?}
    I -- No --> J[skip]
    I -- Yes --> K{skip_update via\nadaptive criteria?}
    K -- skip --> J
    K -- update --> L[update_eigenbasis_and_momentum\nreturns new Q_L, Q_R]
    L --> M[state Q_L, Q_R = updated\nRebind local eigenbasis_list]
    M --> N{step >= adam_warmup?}
    J --> N
    N -- No --> O[grad_projected = grad\nprecond_update = adam_update]
    N -- Yes --> P[precondition grad\nvia eigenbasis_list]
    P --> Q[adam_update on projected grad]
    Q --> R[precondition back\nto original basis]
    R --> S[clip update RMS\napply update to param]
    O --> S
    S --> T[step += 1]
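The projection steps near the end of the flowchart (precondition into the eigenbasis, then back to the original basis) can be sketched as below. The function name, the `transpose` flag, and the use of `torch.linalg.eigh` to build orthonormal bases are assumptions for illustration; the Adam update that would sit between the two projections is omitted.

```python
import torch


def precondition_sketch(grad, eigenbasis_list, transpose=False):
    """Rotate a 2D tensor into the eigenbasis, or back out of it
    when transpose=True (flag name is an assumption)."""
    Q_L, Q_R = eigenbasis_list
    if transpose:
        return Q_L @ grad @ Q_R.T
    return Q_L.T @ grad @ Q_R


torch.manual_seed(0)
grad = torch.randn(4, 3, dtype=torch.float64)

# Orthonormal bases from the symmetric factors; eigh returns
# (eigenvalues, eigenvectors) and the eigenvectors are the columns.
_, Q_L = torch.linalg.eigh(grad @ grad.T)
_, Q_R = torch.linalg.eigh(grad.T @ grad)

projected = precondition_sketch(grad, (Q_L, Q_R))
restored = precondition_sketch(projected, (Q_L, Q_R), transpose=True)
```

Because Q_L and Q_R are orthogonal, projecting and then projecting back recovers the input, so any change to the update comes only from the step applied in the rotated basis.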

Last reviewed commit: 0433d3c

skyw added 3 commits March 17, 2026 10:30
Signed-off-by: Hao Wu <skyw@nvidia.com>
Signed-off-by: Hao Wu <skyw@nvidia.com>
Signed-off-by: Hao Wu <skyw@nvidia.com>
skyw marked this pull request as ready for review March 17, 2026 17:52
@skyw
Contributor Author

skyw commented Mar 17, 2026

/ok to test 0433d3c

@github-actions

Test Results

   48 files  ±0     98 suites  ±0   1m 18s ⏱️ -1s
1 008 tests  - 2  1 008 ✅  - 1  0 💤 ±0  0 ❌  - 1 
2 247 runs   - 4  2 247 ✅  - 3  0 💤 ±0  0 ❌  - 1 

Results for commit 0433d3c. ± Comparison against base commit 4acd7c4.

This pull request removes 6 and adds 4 tests. Note that renamed tests count towards both.

Removed:
__main__.SoapFunctionsTest ‑ test_adam_warmup_steps0 (1)
__main__.SoapFunctionsTest ‑ test_adam_warmup_steps1 (2)
__main__.SoapFunctionsTest ‑ test_adam_warmup_steps2 (3)
__main__.SoapFunctionsTest ‑ test_init_preconditioner_multidim_tensor_shapes
__main__.SoapUtilsTest ‑ test_get_eigenbasis_eigh2 (dims=[64, 0, 32])
__main__.SoapUtilsTest ‑ test_get_eigenbasis_qr_empty_factor

Added:
__main__.SoapFunctionsTest ‑ test_adam_warmup_steps_has_ql_qr0 (1)
__main__.SoapFunctionsTest ‑ test_adam_warmup_steps_has_ql_qr1 (2)
__main__.SoapFunctionsTest ‑ test_adam_warmup_steps_has_ql_qr2 (3)
__main__.SoapFunctionsTest ‑ test_init_kronecker_factors_2d_tensor_shapes

@codecov

codecov bot commented Mar 17, 2026

Codecov Report

❌ Patch coverage is 95.23810% with 1 line in your changes missing coverage. Please review.

| Files with missing lines | Patch % | Lines |
| --- | --- | --- |
| emerging_optimizers/soap/soap.py | 95.23% | 0 Missing and 1 partial ⚠️ |

skyw merged commit 8278454 into main Mar 17, 2026
17 checks passed
skyw deleted the skyw/refactor_soap branch March 17, 2026 20:25