
Conversation

@liqiangxl
Collaborator

@liqiangxl liqiangxl commented Dec 15, 2025

Register sharing utils are done in #5679.

@github-actions

github-actions bot commented Dec 15, 2025

Review updated until commit 31318b0

Description

  • Split large scheduleInnerPersistent function into modular helper functions for better maintainability

  • Add warp specialization support with circular buffering for improved TMA performance

  • Rename rows_per_block parameter to n_grouped_rows for clarity

  • Enhance vectorization handling with dedicated helper functions

  • Update test parameters to use makeContigConcreteTensor and adjust batch sizes

Changes walkthrough

Relevant files

Enhancement

normalization_inner_tma.cpp: Modularize scheduler with warp specialization support
csrc/scheduler/normalization_inner_tma.cpp

  • Split scheduleInnerPersistent into modular functions:
    setupPersistentSchedule, applyVectorization,
    scheduleInnerPersistentMultiwave,
    scheduleInnerPersistentWarpSpecialized
  • Add warp specialization logic with circular buffering, register
    sharing, and launch parameter configuration
  • Implement dispatch logic to choose between multiwave and warp
    specialized execution
  • Enhance vectorization with helper functions for better code
    organization
  • Add ScheduleSetupResult struct to hold common setup results
  • +363/-79

normalization_inner_tma.h: Rename rows_per_block to n_grouped_rows parameter
csrc/scheduler/normalization_inner_tma.h

  • Rename rows_per_block parameter to n_grouped_rows throughout the
    header
  • Update comparison, toString, and hash methods to use the new
    parameter name
  • +4/-4

Tests

test_persistent_buffer.cpp: Update test parameters and tensor creation
tests/cpp/test_persistent_buffer.cpp

  • Change tensor creation from makeContigTensor to
    makeContigConcreteTensor with explicit dimensions
  • Update test batch size parameters to use deviceSMCount()/2 and 2048
  • +3/-3

    PR Reviewer Guide

    Here are some key observations to aid the review process:

    🧪 PR contains tests
    ⚡ Recommended focus areas for review
    Warp Specialization Logic

    The warp specialization logic (lines 96-136) contains complex conditions for enabling warp specialized execution. The heuristic for determining when to use warp specialization (n_stages >= 2 && bdimx == 128) and the register sharing calculation should be carefully validated to ensure it doesn't cause performance regressions or incorrect behavior on different GPU architectures.

    if (n_stages >= 2 && bdimx == 128) {
      gdimx = sm_count;
      bdimx = 128; // 4 warps per warp group
      bdimy = n_compute_warp_groups;
      bdimz = 1; // warp specialized kernel requires static CTA shape
      params->n_grouped_rows = n_rows_per_compute_warp_group;
      ParallelType ws_pt = bdimy > 1 ? ParallelType::TIDy : ParallelType::TIDx;
      WarpSpecialized ws(ws_pt);
      if (ws_pt == ParallelType::TIDy) {
        bdimy += 1;
        ws.stage_slice_position = 3;
        // Limitation in grouped reduction runtime function
        NVF_ERROR(bdimx == 128, "bdimx must be 128 for TIDy warp specialization");
        NVF_ERROR(
            params->n_grouped_rows > 1,
            "n_grouped_rows must be greater than 1 for TIDy warp specialization");
      } else {
        bdimx += kWarpSpecializationPaddedThreads;
      }
      int64_t total_threads = bdimx * bdimy * bdimz;
      if (total_threads > 256) {
        int64_t reg_per_thread = getRegPerThreadGivenThreadsPerSM(total_threads);
        int64_t computation_threads =
            total_threads - kWarpSpecializationPaddedThreads;
        ws.num_registers = scheduler_utils::getRegisterSharing(
            reg_per_thread,
            computation_threads,
            kWarpSpecializationPaddedThreads);
      }
      CircularBufferOptions circular_buffer_options{
          .type = ws, .stage = n_stages, .prefetch = n_stages - 1};
      params->circular_buffer_options = circular_buffer_options;
      // Set launch parameters
      params->lparams = LaunchParams(
          gdimx,
          LaunchParams::UNINITIALIZED_VAL,
          LaunchParams::UNINITIALIZED_VAL,
          bdimx,
          bdimy,
          bdimz);
    }
    Launch Parameter Validation

    The warp specialized scheduler has multiple NVF_ERROR checks for launch parameters (lines 108-111). These constraints should be documented and validated against the intended use cases to ensure they don't unnecessarily restrict functionality or cause runtime failures.

    NVF_ERROR(bdimx == 128, "bdimx must be 128 for TIDy warp specialization");
    NVF_ERROR(
        params->n_grouped_rows > 1,
        "n_grouped_rows must be greater than 1 for TIDy warp specialization");
    Refactored Function Interfaces

    The code has been significantly refactored into multiple helper functions (setupPersistentSchedule, applyVectorization, etc.). The interfaces and data flow between these functions should be reviewed to ensure no functionality was lost in the refactoring and that the separation of concerns improves maintainability.

    // Helper struct to hold results from common schedule setup
    struct ScheduleSetupResult {
      TensorView* reduction_tv;
      std::vector<TensorView*> reduction_tvs;
      std::vector<TensorView*> dummy_outputs;
      std::vector<TensorView*> ldg_tvs;
      std::vector<TensorView*> tma_tvs;
      std::vector<TensorView*> smem2reg_tvs;
      std::vector<std::pair<TensorView*, int64_t>> cached_outputs;
    };
    
    // Common setup for persistent schedule: projects buffers, caches
    // inputs/outputs, and categorizes them into TMA loads, LDG loads, and
    // smem-to-reg caches
    ScheduleSetupResult setupPersistentSchedule(
        Fusion* fusion,
        const InnerNormTmaParams* params) {
      ScheduleSetupResult result;
    
      // Use the reduction tensor as the starting point for scheduling
      result.reduction_tvs = scheduler_utils::getReductionTvs(fusion);
      result.reduction_tv = result.reduction_tvs.at(0);
    
      const scheduler_utils::PersistentBufferInfo persistent_info =
          scheduler_utils::persistentBuffers(fusion);
      if (params->project_persistent_buffers) {
        result.dummy_outputs = reduction_scheduler_utils::projectPersistentBuffers(
            fusion, persistent_info, params->project_persistent_buffers);
        for (auto output : result.dummy_outputs) {
          fusion->addOutput(output);
        }
      } else {
        NVF_ERROR(
            false, "Non-projectable buffers are not supported in TMA version yet");
      }
    
      // Identify persistent inputs that will be cached in shared memory using
      // TMA load (CpAsyncBulk) for efficient async memory transfers
      std::unordered_set<TensorView*> persistent_inputs{
          persistent_info.projectable_buffer_inputs.begin(),
          persistent_info.projectable_buffer_inputs.end()};
      for (auto buffer : persistent_info.persistent_buffers) {
        if (buffer->isFusionInput()) {
          persistent_inputs.insert(buffer);
        }
      }
    
      // Categorize cached inputs into TMA loads and regular LDG loads
      // - tma_tvs: Persistent inputs loaded via TMA (CpAsyncBulk) to shared memory
      // - ldg_tvs: Non-persistent inputs loaded via regular LDG instructions
      // - smem2reg_tvs: Intermediate register caches for vectorized smem->reg loads
      const auto& cached_inputs = scheduler_utils::cacheInputs(fusion, true);
      for (auto [tv, input_idx] : cached_inputs) {
        auto input = fusion->inputs().at(input_idx)->as<TensorView>();
        if (!params->tma_load_non_persistent_buffers &&
            !persistent_inputs.contains(input)) {
          // Non-persistent input: use regular global load
          result.ldg_tvs.push_back(tv);
          continue;
        }
        if (auto load_op = dynamic_cast<LoadStoreOp*>(tv->definition())) {
          // Persistent input: use TMA load to shared memory
          load_op->setOpType(LoadStoreOpType::CpAsyncBulk);
          tv->setMemoryType(MemoryType::Shared);
          result.tma_tvs.push_back(tv);
          if (!params->vectorize_load_smem_to_regs) {
            continue;
          }
          // Create register cache for vectorized smem->reg loads
          auto regs_cache = tv->cacheAfter();
          result.smem2reg_tvs.push_back(regs_cache);
          // Recompute the register cache for each additional consumer so that
          // it is no longer persistent. Similar to projecting to inputs, here
          // we project to the shared memory buffer.
          const auto& consumers = ir_utils::consumerTvsOf(regs_cache);
          for (auto consumer : consumers | std::views::drop(1)) {
            auto cached_tv_replicate = RecomputeTv::recompute(regs_cache, {tv});
            ir_utils::replaceValInExprInputs(
                consumer->definition(), regs_cache, cached_tv_replicate);
            result.smem2reg_tvs.push_back(cached_tv_replicate);
          }
        } else {
          result.ldg_tvs.push_back(tv);
        }
      }
    
      // Cache outputs in registers to enable vectorized writes to global memory
      const auto& cached_outputs =
          scheduler_utils::cacheAndForkOutputs(fusion, /*unroll=*/true);
      result.cached_outputs.assign(cached_outputs.begin(), cached_outputs.end());
    
      return result;
    }
    
    // Helper to find vectorization position (one past TIDx axis)
    std::optional<int64_t> getVectorizationPos(TensorView* tv) {
      auto it = std::find_if(
          tv->domain()->loop().begin(),
          tv->domain()->loop().end(),
          [](const IterDomain* id) {
            return id->getParallelType() == ParallelType::TIDx;
          });
      if (it == tv->domain()->loop().end()) {
        return std::nullopt;
      }
      return (int64_t)std::distance(tv->domain()->loop().begin(), it) + 1;
    }
    
    // Apply vectorization to LDG loads, outputs, and smem-to-reg loads
    void applyVectorization(
        Fusion* fusion,
        const InnerNormTmaParams* params,
        const ScheduleSetupResult& setup) {
      // Apply vectorization to non-TMA global loads
      for (auto tv : setup.ldg_tvs) {
        auto vect_pos = getVectorizationPos(tv);
        if (vect_pos.has_value()) {
          tv->axis(vect_pos.value())->parallelize(ParallelType::Vectorize);
        }
      }
    
      // Apply vectorization to global output writes for better memory bandwidth
      for (auto [_, output_idx] : setup.cached_outputs) {
        auto output = fusion->outputs()[output_idx]->as<TensorView>();
        auto vect_pos = getVectorizationPos(output);
        if (vect_pos.has_value()) {
          output->axis(vect_pos.value())->parallelize(ParallelType::Vectorize);
        }
      }
    
      // Apply vectorization to shared memory to register loads
      if (params->vectorize_load_smem_to_regs) {
        for (auto tv : setup.smem2reg_tvs) {
          auto vect_pos = getVectorizationPos(tv);
          if (vect_pos.has_value()) {
            tv->axis(vect_pos.value())->parallelize(ParallelType::Vectorize);
          }
        }
      }
    }
    
    // Schedule inner persistent kernel for multi-wave execution (no warp
    // specialization)
    void scheduleInnerPersistentMultiwave(
        Fusion* fusion,
        const InnerNormTmaParams* params) {
      FusionGuard fg(fusion);
    
      auto setup = setupPersistentSchedule(fusion, params);
      TensorView* reduction_tv = setup.reduction_tv;
      const auto& reduction_tvs = setup.reduction_tvs;
      const auto& tma_tvs = setup.tma_tvs;
    
      // Schedule TMA loads with shape [I, R] where:
      //   I = iteration dimension (batch dimension)
      //   R = reduction dimension
      // Parallelization strategy:
      //   - axis(0): BIDx - each block handles one or more batch elements
      //   - axis(1): Bulk - TMA asynchronously copies entire reduction dimension
      int64_t iteration_pos = 0, reduction_pos = 1;
      if (params->n_grouped_rows > 1) {
        // [I, R] -> [I/TIDy, TIDy, R]
        reduction_tv->split(iteration_pos, params->n_grouped_rows);
        reduction_tv->axis(iteration_pos)->parallelize(ParallelType::BIDx);
        reduction_tv->axis(iteration_pos + 1)->parallelize(ParallelType::TIDy);
        reduction_pos = iteration_pos + 2;
      } else {
        reduction_tv->axis(iteration_pos)->parallelize(ParallelType::BIDx);
      }
    
      TransformPropagator propagator(reduction_tv);
      MaxLogicalDomainInfoSpanningTree(reduction_tv).traverse(&propagator);
      reduction_tv->axis(reduction_pos)->parallelize(ParallelType::Bulk);
      scheduler_utils::parallelizeAllLike(reduction_tv, tma_tvs);
      // Reset reduction_tv's reduction axis back to Serial (only TMA loads use
      // Bulk)
      reduction_tv->axis(reduction_pos)->parallelize(ParallelType::Serial);
    
      // Transform reduction domain for efficient computation:
      //   [I, R] -> [I, b, us, x, v]
      // Where:
      //   - v: vectorization factor (elements per vector instruction)
      //   - x: thread dimension (TIDx, threads cooperating on reduction)
      //   - b: persistent batch size (register persistent buffer size)
      //   - us: unswitch dimension (loop optimization, reduces control flow
      //   overhead)
      reduction_tv->split(reduction_pos, params->vectorization_factor);
      reduction_tv->split(reduction_pos, params->persistent_batch_size, false);
      reduction_tv->split(reduction_pos, 1);
      reduction_tv->axis(reduction_pos + 1)->parallelize(ParallelType::Unswitch);
      reduction_tv->axis(reduction_pos + 2)->parallelize(ParallelType::TIDx);
      reduction_tv->axis(reduction_pos + 2)->padToMultipleOfWarp();
    
      // Create rfactor tensor to separate thread-local reduction from block
      // reduction. This enables a two-stage reduction:
      //   1. Thread-local vectorized reduction (across b and v dimensions)
      //   2. Block-level reduction (across x dimension using warp/block primitives)
      // rfactor axes: {reduction_pos, vectorize_pos} corresponding to b and v
      // dimensions
      int64_t vectorize_pos = reduction_pos + 3;
      auto reference_tv = reduction_tv->rFactor({reduction_pos, vectorize_pos});
    
      // Propagate transformations from reference_tv to all non-TMA tensors
      // TMA tensors keep their simple [BIDx, Bulk] schedule
      std::vector<TensorView*> non_tma_tvs =
          ir_utils::allTvsExcept(fusion, {tma_tvs.begin(), tma_tvs.end()});
      TransformPropagator non_tma_propagator(reference_tv);
      SetSelector selector({non_tma_tvs.begin(), non_tma_tvs.end()});
      MaxLogicalDomainInfoSpanningTree(reference_tv, &selector)
          .traverse(&non_tma_propagator);
    
      // If reduction_tv is rfactored, rfactor all reductions.
      // Also update non_tma_tvs to include the newly rfactored tvs.
      if (reference_tv != reduction_tv) {
        reduction_scheduler_utils::propagateRFactor(
            reference_tv, reduction_tv, reduction_tvs);
        non_tma_tvs =
            ir_utils::allTvsExcept(fusion, {tma_tvs.begin(), tma_tvs.end()});
      }
      scheduler_utils::parallelizeAllLike(reference_tv, non_tma_tvs);
    
      // Apply vectorization
      applyVectorization(fusion, params, setup);
    
      // Remove dummy outputs that were used for persistent buffer projection
      for (auto output : setup.dummy_outputs) {
        fusion->removeOutput(output);
      }
    
      // Apply aggressive inlining to reduce register pressure and improve locality
      // Exclude ldg_tvs if pre-loading is enabled to control issue order
      std::unordered_set<TensorView*> exclude_tvs;
      if (params->pre_load_ldg_tvs) {
        exclude_tvs.insert(setup.ldg_tvs.begin(), setup.ldg_tvs.end());
      }
      std::vector<TensorView*> inline_most_tvs =
          ir_utils::allTvsExcept(fusion, exclude_tvs);
      inlineMost(inline_most_tvs);
    
      // Refine cache policies for optimal memory hierarchy usage
      refineCachePolicy(fusion);
    }
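
    The dispatch between these two paths is not included in the excerpt above. Based on the changes walkthrough and the sequence diagram below, it presumably looks roughly like the following sketch; the function names come from the PR, while the body here is an assumption rather than the actual implementation:

    // Hypothetical sketch of the dispatcher (not the PR's actual code): take the
    // warp specialized path when the heuristics enabled circular buffering with
    // warp specialization, otherwise fall back to the multiwave path.
    void scheduleInnerPersistent(Fusion* fusion, const InnerNormTmaParams* params) {
      if (params->circular_buffer_options.isEnable() &&
          std::holds_alternative<WarpSpecialized>(
              params->circular_buffer_options.type)) {
        scheduleInnerPersistentWarpSpecialized(fusion, params);
      } else {
        scheduleInnerPersistentMultiwave(fusion, params);
      }
    }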

    Base automatically changed from llu/norm_auto_multi_wave to main December 16, 2025 01:43
    @liqiangxl liqiangxl marked this pull request as ready for review December 16, 2025 13:50
    @liqiangxl
    Collaborator Author

    !test

    @greptile-apps
    Contributor

    greptile-apps bot commented Dec 16, 2025

    Greptile Summary

    This PR implements warp specialized inner persistent scheduling for TMA-based normalization kernels. The changes refactor the scheduler to support circular buffering with specialized compute warp groups, splitting the original scheduleInnerPersistent function into separate multiwave and warp-specialized paths.

    Major changes:

    • Added warp specialization support with TIDy-based compute warp groups (2 groups with 2 rows each)
    • Introduced setupPersistentSchedule helper to extract common setup logic for buffer projection and input/output caching
    • Created scheduleInnerPersistentMultiwave for non-warp-specialized scheduling (original behavior)
    • Created scheduleInnerPersistentWarpSpecialized for warp-specialized scheduling with circular buffering
    • Extracted applyVectorization helper for consistent vectorization across both scheduling paths
    • Renamed rows_per_block to n_grouped_rows for clarity
    • Updated test to use concrete tensor shapes

    Issues identified in previous threads:

    • Register sharing calculation fragility (lines 115-123): uses a subtraction-based approach that differs from the reference implementation
    • Error message mismatch (line 110): the condition checks > 1, but prior review comments suggested the message was incorrect
    • Warp specialization condition (line 96): the hard-coded check for bdimx == 128 may be too restrictive

    Confidence Score: 3/5

    • This PR is safe to merge with moderate risk due to register sharing calculation fragility
    • The core refactoring is well-structured and test coverage exists, but the register sharing calculation (lines 115-123) uses a subtraction-based approach that differs from the reference implementation in normalization_inner_outer_tma_ws.cpp:286-292. While it happens to produce correct results for the specific case of bdimx=128 and bdimy=2, the logic is fundamentally fragile and may break if thread dimensions change. The previous review threads have thoroughly documented these issues. The refactoring itself is clean and improves code organization.
    • csrc/scheduler/normalization_inner_tma.cpp needs attention for register sharing calculation robustness

    Important Files Changed

    csrc/scheduler/normalization_inner_tma.cpp: Refactored to support warp specialized inner persistent scheduling with TMA. Register sharing calculation on line 119 has potential fragility issues already documented in previous threads.
    csrc/scheduler/normalization_inner_tma.h: Simple rename from rows_per_block to n_grouped_rows for clarity. No functional changes or issues.
    tests/cpp/test_persistent_buffer.cpp: Test updated to use concrete tensor shapes and improved comment clarity. Good test coverage for both small and large batch sizes.

    Sequence Diagram

    sequenceDiagram
        participant User
        participant Heuristics as getInnerPersistentHeuristics
        participant Dispatcher as scheduleInnerPersistent
        participant Setup as setupPersistentSchedule
        participant Multiwave as scheduleInnerPersistentMultiwave
        participant WarpSpec as scheduleInnerPersistentWarpSpecialized
        participant Apply as applyVectorization
    
        User->>Heuristics: Request scheduler params
        
        Note over Heuristics: Calculate bdimx, persistent_batch_size
        Heuristics->>Heuristics: Compute n_stages (iter & smem limited)
        
        alt n_stages >= 2 && bdimx == 128
            Note over Heuristics: Enable Warp Specialization
            Heuristics->>Heuristics: Set n_grouped_rows = 2
            Heuristics->>Heuristics: Set bdimy = n_compute_warp_groups (2)
            
            alt bdimy > 1 (TIDy specialization)
                Heuristics->>Heuristics: bdimy += 1 (add TMA warp)
                Heuristics->>Heuristics: Calculate register sharing
            else TIDx specialization
                Heuristics->>Heuristics: bdimx += kWarpSpecializationPaddedThreads
            end
            
            Heuristics->>Heuristics: Set circular_buffer_options
        end
        
        Heuristics-->>User: Return InnerNormTmaParams
        
        User->>Dispatcher: scheduleInnerPersistent(fusion, params)
        
        alt circular_buffer_options enabled
            Dispatcher->>WarpSpec: scheduleInnerPersistentWarpSpecialized
            WarpSpec->>Setup: setupPersistentSchedule
            Setup->>Setup: Project persistent buffers
            Setup->>Setup: Cache inputs (TMA & LDG)
            Setup->>Setup: Cache outputs
            Setup-->>WarpSpec: Return ScheduleSetupResult
            
            WarpSpec->>WarpSpec: Split iteration domain (Group, TIDy, BIDx)
            WarpSpec->>WarpSpec: Transform reduction domain (b, us, x, v)
            WarpSpec->>WarpSpec: rFactor for 2-stage reduction
            WarpSpec->>WarpSpec: Propagate transformations
            WarpSpec->>WarpSpec: Apply Group parallelization
            WarpSpec->>Apply: applyVectorization
            WarpSpec->>WarpSpec: Apply circular buffering to TMA loads
            
        else multiwave mode
            Dispatcher->>Multiwave: scheduleInnerPersistentMultiwave
            Multiwave->>Setup: setupPersistentSchedule
            Setup-->>Multiwave: Return ScheduleSetupResult
            
            Multiwave->>Multiwave: Schedule TMA loads (BIDx, Bulk)
            Multiwave->>Multiwave: Transform reduction domain
            Multiwave->>Multiwave: rFactor for 2-stage reduction
            Multiwave->>Apply: applyVectorization
        end
        
        Apply->>Apply: Vectorize LDG loads
        Apply->>Apply: Vectorize output writes
        Apply->>Apply: Vectorize smem-to-reg loads
        
        Dispatcher-->>User: Scheduled fusion ready
    

    Contributor

    @greptile-apps greptile-apps bot left a comment


    Additional Comments (1)

    1. csrc/scheduler/normalization_inner_tma.cpp, line 23-29 (link)

      syntax: Duplicate includes - lines 23-25 are repeated at lines 27-29.

    6 files reviewed, 2 comments


    @liqiangxl
    Collaborator Author

    !test

    Contributor

    @greptile-apps greptile-apps bot left a comment


    6 files reviewed, 1 comment


    @liqiangxl
    Collaborator Author

    !test

    Contributor

    @greptile-apps greptile-apps bot left a comment


    6 files reviewed, no comments


    @liqiangxl liqiangxl requested a review from rdspring1 December 16, 2025 20:26
    @liqiangxl
    Collaborator Author

    !test

    Contributor

    @greptile-apps greptile-apps bot left a comment


    3 files reviewed, 2 comments


    Comment on lines +115 to +123
    int64_t total_threads = bdimx * bdimy * bdimz;
    if (total_threads > 256) {
      int64_t reg_per_thread = getRegPerThreadGivenThreadsPerSM(total_threads);
      int64_t computation_threads =
          total_threads - kWarpSpecializationPaddedThreads;
      ws.num_registers = scheduler_utils::getRegisterSharing(
          reg_per_thread,
          computation_threads,
          kWarpSpecializationPaddedThreads);
    Contributor


    logic: Register sharing calculation includes TMA warp in total_threads (e.g., 128 * 3 * 1 = 384), but computation_threads should exclude the full TMA warp dimension. For TIDy specialization, this should be bdimx * (bdimy - 1) = 128 * 2 = 256 computation threads, matching the pattern in normalization_inner_outer_tma_ws.cpp:286-292.

    Suggested change:

    int64_t total_threads = bdimx * bdimy * bdimz;
    if (total_threads > 256) {
      int64_t reg_per_thread = getRegPerThreadGivenThreadsPerSM(total_threads);
      int64_t computation_threads = bdimx * (bdimy - 1);
      ws.num_registers = scheduler_utils::getRegisterSharing(
          reg_per_thread,
          computation_threads,
          kWarpSpecializationPaddedThreads);
    }
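
    For concreteness, a standalone arithmetic sketch of the thread accounting above (not nvFuser code); it assumes kWarpSpecializationPaddedThreads == 128, which is what the 128 * 3 * 1 = 384 and 128 * 2 = 256 numbers in this comment imply:

    #include <cstdint>
    #include <iostream>

    int main() {
      const int64_t kPaddedThreads = 128; // assumed kWarpSpecializationPaddedThreads
      const int64_t bdimx = 128;          // 4 warps per compute warp group
      const int64_t bdimy = 3;            // 2 compute warp groups + 1 TMA warp group
      const int64_t bdimz = 1;

      const int64_t total_threads = bdimx * bdimy * bdimz;           // 384
      const int64_t by_subtraction = total_threads - kPaddedThreads; // 256
      const int64_t by_dimension = bdimx * (bdimy - 1);              // 256

      // The two formulas agree here only because bdimx equals the padded thread
      // count; changing bdimx or bdimy would make them diverge.
      std::cout << total_threads << " " << by_subtraction << " " << by_dimension
                << std::endl;
      return 0;
    }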

    Comment on lines +96 to +98
    if (n_stages >= 2 && bdimx == 128) {
      gdimx = sm_count;
      bdimx = 128; // 4 warps per warp group
    Contributor


    logic: The condition checks bdimx == 128 before entering the block, but bdimx was set on line 69 to std::min(4 * warp_size, after_vect) which may not be 128. Then line 98 reassigns it to 128. Either the condition should be bdimx >= 128 or should check a different variable. Is the intent to only enable warp specialization when the initial bdimx heuristic is exactly 128, or should it work for other values?

    @liqiangxl
    Collaborator Author

    !test

    const int64_t n_stages = std::min(smem_limited_stages, iter_limited_stages);
    if (n_stages >= 2 && bdimx == 128) {
      gdimx = sm_count;
      bdimx = 128; // 4 warps per warp group
    Collaborator


    The if condition already checks that bdimx == 128, so this is redundant.

    Collaborator

    @rdspring1 rdspring1 left a comment


    void scheduleInnerPersistent(Fusion* fusion, const InnerNormTmaParams* params) is 290 lines.

    Is it necessary to add warp specialization to the same function?

    There are several if-statements for each schedule decision. It is difficult to reason about the schedule holistically.

    Contributor

    @greptile-apps greptile-apps bot left a comment


    3 files reviewed, 1 comment


    @liqiangxl
    Collaborator Author

    !test

    Contributor

    @greptile-apps greptile-apps bot left a comment


    3 files reviewed, 1 comment


    @liqiangxl
    Collaborator Author

    !test

    @greptile-apps
    Contributor

    greptile-apps bot commented Dec 20, 2025

    Greptile's behavior is changing!

    From now on, if a review finishes with no comments, we will not post an additional "statistics" comment to confirm that our review found nothing to comment on. However, you can confirm that we reviewed your changes in the status check section.

    This feature can be toggled off in your Code Review Settings by deselecting "Create a status check for each PR".

    @liqiangxl
    Collaborator Author

    !test
