Conversation
Signed-off-by: Hao Wu <skyw@nvidia.com>
Signed-off-by: Hao Wu <skyw@nvidia.com>
Signed-off-by: Hao Wu <skyw@nvidia.com>
Signed-off-by: Hao Wu <skyw@nvidia.com>
Signed-off-by: Hao Wu <skyw@nvidia.com>
Signed-off-by: Hao Wu <skyw@nvidia.com>
Signed-off-by: Hao Wu <skyw@nvidia.com>
Signed-off-by: Hao Wu <skyw@nvidia.com>
Signed-off-by: Hao Wu <skyw@nvidia.com>
|
@greptile |
Greptile SummaryThis PR refactors the SOAP optimizer by removing 1D tensor support and replacing list-based state storage ( Key changes:
Confidence Score: 4/5
Important Files Changed
Flowchart%%{init: {'theme': 'neutral'}}%%
flowchart TD
A[optimizer.step called] --> B[_init_group\nskip_non_grad_params=True]
B --> C{p.dim != 2?}
C -- TypeError --> D[raise TypeError]
C -- ok --> E[Init state\nL, R = zeros\nQ_L, Q_R = eye\nexp_avg, exp_avg_sq = zeros]
E --> F[For each param with grad]
F --> G[Build local refs\nkronecker_factor_list = state L, R\neigenbasis_list = state Q_L, Q_R]
G --> H[update_kronecker_factors\nL = beta*L + 1-beta * G@G.T\nR = beta*R + 1-beta * G.T@G]
H --> I{Is eigenbasis\nupdate step?}
I -- No --> J[skip]
I -- Yes --> K{skip_update via\nadaptive criteria?}
K -- skip --> J
K -- update --> L[update_eigenbasis_and_momentum\nreturns new Q_L, Q_R]
L --> M[state Q_L, Q_R = updated\nRebind local eigenbasis_list]
M --> N{step >= adam_warmup?}
J --> N
N -- No --> O[grad_projected = grad\nprecond_update = adam_update]
N -- Yes --> P[precondition grad\nvia eigenbasis_list]
P --> Q[adam_update on projected grad]
Q --> R[precondition back\nto original basis]
R --> S[clip update RMS\napply update to param]
O --> S
S --> T[step += 1]
Last reviewed commit: 0433d3c |
Signed-off-by: Hao Wu <skyw@nvidia.com>
Signed-off-by: Hao Wu <skyw@nvidia.com>
Signed-off-by: Hao Wu <skyw@nvidia.com>
|
/ok to test 0433d3c |
Test Results 48 files ±0 98 suites ±0 1m 18s ⏱️ -1s Results for commit 0433d3c. ± Comparison against base commit 4acd7c4. This pull request removes 6 and adds 4 tests. Note that renamed tests count towards both. |
Codecov Report❌ Patch coverage is
📢 Thoughts on this report? Let us know! |
Remove 1d support and clean up rest of the code.
Store plain tensor in optimizer states instead of list for easy checkpoint (state_dict) handling.
No functionality change. Removed torch.compile over tensordot, which won't be able to do anything anyway