TMA warp specialized inner persistent #5681
base: main
Conversation
Force-pushed from 05d47c1 to 51e2f1a
Review updated until commit 31318b0

Description

Relevant files:
- Enhancement
- Tests
PR Reviewer Guide

Here are some key observations to aid the review process:
- 🧪 PR contains tests
- ⚡ Recommended focus areas for review: Warp Specialization Logic
Force-pushed from 51e2f1a to e6eb0db

!test
Greptile Summary

This PR implements warp specialized inner persistent scheduling for TMA-based normalization kernels. The changes refactor the scheduler to support circular buffering with specialized compute warp groups, splitting the original [...]

Major changes:
Issues identified in previous threads:
Confidence Score: 3/5
Important Files Changed
Sequence Diagram

sequenceDiagram
participant User
participant Heuristics as getInnerPersistentHeuristics
participant Dispatcher as scheduleInnerPersistent
participant Setup as setupPersistentSchedule
participant Multiwave as scheduleInnerPersistentMultiwave
participant WarpSpec as scheduleInnerPersistentWarpSpecialized
participant Apply as applyVectorization
User->>Heuristics: Request scheduler params
Note over Heuristics: Calculate bdimx, persistent_batch_size
Heuristics->>Heuristics: Compute n_stages (iter & smem limited)
alt n_stages >= 2 && bdimx == 128
Note over Heuristics: Enable Warp Specialization
Heuristics->>Heuristics: Set n_grouped_rows = 2
Heuristics->>Heuristics: Set bdimy = n_compute_warp_groups (2)
alt bdimy > 1 (TIDy specialization)
Heuristics->>Heuristics: bdimy += 1 (add TMA warp)
Heuristics->>Heuristics: Calculate register sharing
else TIDx specialization
Heuristics->>Heuristics: bdimx += kWarpSpecializationPaddedThreads
end
Heuristics->>Heuristics: Set circular_buffer_options
end
Heuristics-->>User: Return InnerNormTmaParams
User->>Dispatcher: scheduleInnerPersistent(fusion, params)
alt circular_buffer_options enabled
Dispatcher->>WarpSpec: scheduleInnerPersistentWarpSpecialized
WarpSpec->>Setup: setupPersistentSchedule
Setup->>Setup: Project persistent buffers
Setup->>Setup: Cache inputs (TMA & LDG)
Setup->>Setup: Cache outputs
Setup-->>WarpSpec: Return ScheduleSetupResult
WarpSpec->>WarpSpec: Split iteration domain (Group, TIDy, BIDx)
WarpSpec->>WarpSpec: Transform reduction domain (b, us, x, v)
WarpSpec->>WarpSpec: rFactor for 2-stage reduction
WarpSpec->>WarpSpec: Propagate transformations
WarpSpec->>WarpSpec: Apply Group parallelization
WarpSpec->>Apply: applyVectorization
WarpSpec->>WarpSpec: Apply circular buffering to TMA loads
else multiwave mode
Dispatcher->>Multiwave: scheduleInnerPersistentMultiwave
Multiwave->>Setup: setupPersistentSchedule
Setup-->>Multiwave: Return ScheduleSetupResult
Multiwave->>Multiwave: Schedule TMA loads (BIDx, Bulk)
Multiwave->>Multiwave: Transform reduction domain
Multiwave->>Multiwave: rFactor for 2-stage reduction
Multiwave->>Apply: applyVectorization
end
Apply->>Apply: Vectorize LDG loads
Apply->>Apply: Vectorize output writes
Apply->>Apply: Vectorize smem-to-reg loads
Dispatcher-->>User: Scheduled fusion ready
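For readers skimming the diagram, below is a minimal C++ sketch of the warp-specialization branch it describes (the `alt n_stages >= 2 && bdimx == 128` alternative). The identifiers come from the diagram and the review snippets further down; the struct, the constant's value, and the function itself are illustrative assumptions rather than the PR's actual code.

```cpp
// Minimal sketch, assuming illustrative types and values; mirrors the
// "alt n_stages >= 2 && bdimx == 128" branch in the sequence diagram.
#include <algorithm>
#include <cstdint>

struct WarpSpecHeuristicSketch {
  int64_t bdimx = 0;              // threads along TIDx (128 = one warp group)
  int64_t bdimy = 1;              // warp groups along TIDy
  bool warp_specialized = false;  // stands in for circular_buffer_options
};

// Hypothetical constant for this sketch; the real one lives in the nvFuser sources.
constexpr int64_t kWarpSpecializationPaddedThreadsSketch = 128;

WarpSpecHeuristicSketch pickWarpSpecialization(
    int64_t smem_limited_stages,
    int64_t iter_limited_stages,
    int64_t initial_bdimx,
    bool specialize_on_tidy) {
  WarpSpecHeuristicSketch h;
  h.bdimx = initial_bdimx;

  const int64_t n_stages = std::min(smem_limited_stages, iter_limited_stages);
  if (n_stages >= 2 && h.bdimx == 128) {
    h.warp_specialized = true;
    h.bdimy = 2;  // n_compute_warp_groups, per the diagram
    if (specialize_on_tidy) {
      h.bdimy += 1;  // one extra warp group along TIDy issues the TMA loads
    } else {
      h.bdimx += kWarpSpecializationPaddedThreadsSketch;  // pad along TIDx
    }
  }
  return h;
}
```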
Additional Comments (1)

- csrc/scheduler/normalization_inner_tma.cpp, lines 23-29 (syntax): Duplicate includes; lines 23-25 are repeated at lines 27-29.
6 files reviewed, 2 comments
Force-pushed from 5fdd761 to a53a37c

!test
6 files reviewed, 1 comment
!test
6 files reviewed, no comments
!test
3 files reviewed, 2 comments
    int64_t total_threads = bdimx * bdimy * bdimz;
    if (total_threads > 256) {
      int64_t reg_per_thread = getRegPerThreadGivenThreadsPerSM(total_threads);
      int64_t computation_threads =
          total_threads - kWarpSpecializationPaddedThreads;
      ws.num_registers = scheduler_utils::getRegisterSharing(
          reg_per_thread,
          computation_threads,
          kWarpSpecializationPaddedThreads);
logic: Register sharing calculation includes TMA warp in total_threads (e.g., 128 * 3 * 1 = 384), but computation_threads should exclude the full TMA warp dimension. For TIDy specialization, this should be bdimx * (bdimy - 1) = 128 * 2 = 256 computation threads, matching the pattern in normalization_inner_outer_tma_ws.cpp:286-292.
Suggested change:

    int64_t total_threads = bdimx * bdimy * bdimz;
    if (total_threads > 256) {
      int64_t reg_per_thread = getRegPerThreadGivenThreadsPerSM(total_threads);
      int64_t computation_threads = bdimx * (bdimy - 1);
      ws.num_registers = scheduler_utils::getRegisterSharing(
          reg_per_thread,
          computation_threads,
          kWarpSpecializationPaddedThreads);
    }
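A quick worked example of the arithmetic in the comment above, using the quoted numbers (bdimx = 128, bdimy = 3, bdimz = 1). The padded-thread value is an assumption for illustration; only the two formulas come from the snippet and the suggestion.

```cpp
// Worked example of the two computation_threads formulas discussed above.
// kPaddedThreads = 128 is an assumed value for illustration only.
#include <cstdint>
#include <iostream>

int main() {
  const int64_t bdimx = 128, bdimy = 3, bdimz = 1;
  const int64_t kPaddedThreads = 128;

  const int64_t total_threads = bdimx * bdimy * bdimz;        // 384
  const int64_t by_padding = total_threads - kPaddedThreads;  // 256
  const int64_t by_warp_group = bdimx * (bdimy - 1);          // 256 for TIDy spec.

  // The two agree only when kPaddedThreads happens to equal a full warp
  // group (bdimx); bdimx * (bdimy - 1) expresses the intent directly.
  std::cout << total_threads << " " << by_padding << " " << by_warp_group << "\n";
  return 0;
}
```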
    if (n_stages >= 2 && bdimx == 128) {
      gdimx = sm_count;
      bdimx = 128; // 4 warps per warp group
logic: The condition checks bdimx == 128 before entering the block, but bdimx was set on line 69 to std::min(4 * warp_size, after_vect), which may not be 128. Then line 98 reassigns it to 128. Either the condition should be bdimx >= 128, or it should check a different variable. Is the intent to only enable warp specialization when the initial bdimx heuristic is exactly 128, or should it work for other values?
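To make the concern concrete, here is a small sketch showing when the initial expression quoted in the comment, std::min(4 * warp_size, after_vect), would not yield 128 (the warp_size and after_vect values are illustrative):

```cpp
// Illustrative only: shows when bdimx = std::min(4 * warp_size, after_vect)
// falls below 128, in which case the warp-specialized branch would be skipped
// even with n_stages >= 2.
#include <algorithm>
#include <cstdint>
#include <initializer_list>
#include <iostream>

int main() {
  const int64_t warp_size = 32;
  for (int64_t after_vect : {64, 96, 128, 256}) {
    const int64_t bdimx = std::min(4 * warp_size, after_vect);
    std::cout << "after_vect=" << after_vect << " -> bdimx=" << bdimx
              << (bdimx == 128 ? " (WS eligible)" : " (WS skipped)") << "\n";
  }
  return 0;
}
```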
!test
    const int64_t n_stages = std::min(smem_limited_stages, iter_limited_stages);
    if (n_stages >= 2 && bdimx == 128) {
      gdimx = sm_count;
      bdimx = 128; // 4 warps per warp group
The if condition already checks that bdimx == 128, so this is redundant.
rdspring1 left a comment:
void scheduleInnerPersistent(Fusion* fusion, const InnerNormTmaParams* params) is 290 lines.
Is it necessary to add warp specialization to the same function?
There are several if-statements for each schedule decision. It is difficult to reason about the schedule holistically.
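One shape the refactor could take, and roughly what the sequence diagram above suggests the PR moves toward, is a thin dispatcher over two dedicated scheduling functions. A hedged sketch, with the parameter type and the function bodies reduced to placeholders (the real signatures take Fusion* and const InnerNormTmaParams*):

```cpp
// Sketch of the dispatch structure implied by the sequence diagram; the
// placeholder struct and the "Sketch" suffixes are assumptions, not PR code.
struct ParamsSketch {
  bool warp_specialized = false;  // stands in for circular_buffer_options
};

void scheduleInnerPersistentMultiwaveSketch(const ParamsSketch&) {
  // multiwave path: schedule TMA loads (BIDx, Bulk), rFactor, vectorize, ...
}

void scheduleInnerPersistentWarpSpecializedSketch(const ParamsSketch&) {
  // warp-specialized path: split iteration domain, rFactor, vectorize,
  // apply circular buffering to TMA loads, ...
}

void scheduleInnerPersistentSketch(const ParamsSketch& p) {
  // One decision point instead of per-step if-statements.
  if (p.warp_specialized) {
    scheduleInnerPersistentWarpSpecializedSketch(p);
  } else {
    scheduleInnerPersistentMultiwaveSketch(p);
  }
}
```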
3 files reviewed, 1 comment
!test
This reverts commit 005f500.
3 files reviewed, 1 comment
!test
Greptile's behavior is changing! From now on, if a review finishes with no comments, we will not post an additional "statistics" comment to confirm that our review found nothing to comment on. However, you can confirm that we reviewed your changes in the status check section. This feature can be toggled off in your Code Review Settings by deselecting "Create a status check for each PR".
!test
Register sharing utils is done in #5679.