
Commit fd5d5cb

Commit message: fix
Signed-off-by: junq <[email protected]>
2 parents: e20fbab + e35fca4

File tree: 97 files changed (+3807 / -2631 lines)


cpp/tensorrt_llm/common/envUtils.cpp

Lines changed: 6 additions & 0 deletions
@@ -366,6 +366,12 @@ bool getEnvForceDeterministicMOE()
     return forceDeterministic;
 }
 
+bool getEnvMOEDisableFinalizeFusion()
+{
+    static bool const moeDisableFinalizeFusion = getBoolEnv("TRTLLM_MOE_DISABLE_FINALIZE_FUSION");
+    return moeDisableFinalizeFusion;
+}
+
 bool getEnvForceDeterministicAttention()
 {
     static bool const forceDeterministic

cpp/tensorrt_llm/common/envUtils.h

Lines changed: 3 additions & 0 deletions
@@ -86,6 +86,9 @@ bool getEnvForceDeterministic();
 // Force deterministic behavior for MoE plugin.
 bool getEnvForceDeterministicMOE();
 
+// Disable finalize fusion in MoE plugin
+bool getEnvMOEDisableFinalizeFusion();
+
 // Force deterministic behavior for attention plugin.
 bool getEnvForceDeterministicAttention();
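Taken together, the two envUtils changes add a process-wide switch for turning off the MoE finalize-fusion path. Below is a minimal sketch of how the getter behaves at runtime; the main() wrapper and the tensorrt_llm::common namespace qualification are assumptions for illustration, and only getEnvMOEDisableFinalizeFusion() and the TRTLLM_MOE_DISABLE_FINALIZE_FUSION variable come from this diff.

#include <cstdio>
#include "tensorrt_llm/common/envUtils.h"

int main()
{
    // Run with TRTLLM_MOE_DISABLE_FINALIZE_FUSION=1 to request the unfused finalize path.
    // The getter caches the environment lookup in a function-local static, so the variable
    // is read once per process.
    bool const disabled = tensorrt_llm::common::getEnvMOEDisableFinalizeFusion();
    std::printf("MoE finalize fusion disabled: %s\n", disabled ? "yes" : "no");
    return 0;
}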

cpp/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/epilogue/collective/epilogue_moe_finalize.hpp

Lines changed: 0 additions & 568 deletions
This file was deleted.

cpp/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/epilogue/fusion/sm90_visitor_scatter.hpp

Lines changed: 547 additions & 0 deletions
Large diffs are not rendered by default.

cpp/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm_configs.h

Lines changed: 0 additions & 5 deletions
@@ -133,11 +133,6 @@ enum class CutlassTileConfigSM100
     CtaShape128x256x128B,
     CtaShape128x128x256B,
     CtaShape128x256x256B,
-
-    // M=256
-    CtaShape256x64x128B,
-    CtaShape256x128x128B,
-    CtaShape256x256x128B,
 };
 
 enum class CutlassTileConfigSM120

cpp/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/util/gather_tensor.hpp

Lines changed: 3 additions & 3 deletions
@@ -19,7 +19,7 @@
 #include "cute/tensor.hpp"
 #include "cute/util/print.hpp"
 
-namespace tensorrt_llm::cutlass_extensions
+namespace cutlass::util
 {
 
 /// Function object that applies an index to its argument
@@ -81,7 +81,7 @@ struct CustomStride
     template <class Div>
     CUTE_HOST_DEVICE constexpr friend auto safe_div(CustomStride const& s, Div const& div)
     {
-        return CustomStride<Func, decltype(safe_div(s.stride_, div))>(s.func_, safe_div(s.stride_, div));
+        return CustomStride<Func, decltype(cute::safe_div(s.stride_, div))>(s.func_, cute::safe_div(s.stride_, div));
     }
 
     // Circumvent the requirement on make_layout that shape and stride are integral
@@ -116,7 +116,7 @@ CUTLASS_HOST_DEVICE auto make_gather_tensor(Iterator iter, Shape const& shape, S
     Layout gather_layout = make_custom_stride_layout(stride, static_cast<Func&&>(func));
     return make_tensor(iter, ComposedLayout{gather_layout, offset, matrix_layout});
 }
-} // namespace tensorrt_llm::cutlass_extensions
+} // namespace cutlass::util
 
 namespace cute
 {

cpp/tensorrt_llm/kernels/cutlass_kernels/cutlass_heuristic.cpp

Lines changed: 41 additions & 51 deletions
@@ -377,72 +377,62 @@ std::vector<CutlassGemmConfig> get_candidate_configs_sm100(CutlassGemmConfig::Ca
     if (config & CutlassGemmConfig::GROUPED_GEMM)
     {
         std::vector<CutlassGemmConfig> candidate_configs;
-        if ((config & CutlassGemmConfig::FP4_ONLY) != 0)
+        if (config & CutlassGemmConfig::FP4_ONLY)
         {
             candidate_configs.push_back(CutlassGemmConfig{CutlassTileConfigSM100::CtaShape128x128x128B,
                 MainloopScheduleType::AUTO, EpilogueScheduleType::AUTO, ClusterShape::ClusterShape_1x1x1});
-            candidate_configs.push_back(CutlassGemmConfig{CutlassTileConfigSM100::CtaShape256x128x128B,
+            candidate_configs.push_back(CutlassGemmConfig{CutlassTileConfigSM100::CtaShape128x128x128B,
                 MainloopScheduleType::AUTO, EpilogueScheduleType::AUTO, ClusterShape::ClusterShape_2x1x1});
+            candidate_configs.push_back(CutlassGemmConfig{CutlassTileConfigSM100::CtaShape128x128x128B,
+                MainloopScheduleType::AUTO, EpilogueScheduleType::AUTO, ClusterShape::ClusterShape_1x2x1});
             candidate_configs.push_back(CutlassGemmConfig{CutlassTileConfigSM100::CtaShape128x256x128B,
                 MainloopScheduleType::AUTO, EpilogueScheduleType::AUTO, ClusterShape::ClusterShape_1x1x1});
-            candidate_configs.push_back(CutlassGemmConfig{CutlassTileConfigSM100::CtaShape256x256x128B,
-                MainloopScheduleType::AUTO, EpilogueScheduleType::AUTO, ClusterShape::ClusterShape_2x1x1});
             candidate_configs.push_back(CutlassGemmConfig{CutlassTileConfigSM100::CtaShape128x256x128B,
-                MainloopScheduleType::AUTO, EpilogueScheduleType::AUTO, ClusterShape::ClusterShape_1x2x1});
-            candidate_configs.push_back(CutlassGemmConfig{CutlassTileConfigSM100::CtaShape256x64x128B,
+                MainloopScheduleType::AUTO, EpilogueScheduleType::AUTO, ClusterShape::ClusterShape_2x1x1});
+            candidate_configs.push_back(CutlassGemmConfig{CutlassTileConfigSM100::CtaShape128x64x128B,
                 MainloopScheduleType::AUTO, EpilogueScheduleType::AUTO, ClusterShape::ClusterShape_2x1x1});
             candidate_configs.push_back(CutlassGemmConfig{CutlassTileConfigSM100::CtaShape128x64x128B,
                 MainloopScheduleType::AUTO, EpilogueScheduleType::AUTO, ClusterShape::ClusterShape_1x1x1});
             return candidate_configs;
         }
 
-        for (int cluster_m = 1; cluster_m <= 2; cluster_m++)
+        std::vector<std::pair<CutlassTileConfigSM100, ClusterShape>> tile_configs{
+            {CutlassTileConfigSM100::CtaShape128x128x128B, ClusterShape::ClusterShape_1x1x1},
+            {CutlassTileConfigSM100::CtaShape128x256x128B, ClusterShape::ClusterShape_1x1x1},
+            {CutlassTileConfigSM100::CtaShape128x64x128B, ClusterShape::ClusterShape_1x2x1},
+            {CutlassTileConfigSM100::CtaShape128x128x128B, ClusterShape::ClusterShape_1x2x1},
+            {CutlassTileConfigSM100::CtaShape64x128x128B, ClusterShape::ClusterShape_2x1x1},
+            {CutlassTileConfigSM100::CtaShape64x256x128B, ClusterShape::ClusterShape_2x1x1},
+            {CutlassTileConfigSM100::CtaShape64x64x128B, ClusterShape::ClusterShape_2x2x1},
+            {CutlassTileConfigSM100::CtaShape64x128x128B, ClusterShape::ClusterShape_2x2x1},
+            {CutlassTileConfigSM100::CtaShape64x64x128B, ClusterShape::ClusterShape_2x1x1},
+            {CutlassTileConfigSM100::CtaShape128x64x128B, ClusterShape::ClusterShape_2x1x1},
+            {CutlassTileConfigSM100::CtaShape128x128x128B, ClusterShape::ClusterShape_2x1x1},
+            {CutlassTileConfigSM100::CtaShape128x256x128B, ClusterShape::ClusterShape_2x1x1},
+            {CutlassTileConfigSM100::CtaShape128x64x128B, ClusterShape::ClusterShape_2x2x1},
+            {CutlassTileConfigSM100::CtaShape128x128x128B, ClusterShape::ClusterShape_2x2x1},
+            {CutlassTileConfigSM100::CtaShape128x32x128B, ClusterShape::ClusterShape_1x1x1},
+            {CutlassTileConfigSM100::CtaShape64x64x128B, ClusterShape::ClusterShape_1x1x1},
+            {CutlassTileConfigSM100::CtaShape64x32x128B, ClusterShape::ClusterShape_1x2x1},
+            {CutlassTileConfigSM100::CtaShape64x128x128B, ClusterShape::ClusterShape_1x1x1},
+            {CutlassTileConfigSM100::CtaShape64x64x128B, ClusterShape::ClusterShape_1x2x1},
+            {CutlassTileConfigSM100::CtaShape64x256x128B, ClusterShape::ClusterShape_1x1x1},
+            {CutlassTileConfigSM100::CtaShape64x128x128B, ClusterShape::ClusterShape_1x2x1},
+            {CutlassTileConfigSM100::CtaShape128x64x128B, ClusterShape::ClusterShape_1x1x1},
+            {CutlassTileConfigSM100::CtaShape128x32x128B, ClusterShape::ClusterShape_1x2x1},
+        };
+
+        if (config & CutlassGemmConfig::FP8_ONLY)
         {
-            bool Is2SM = cluster_m == 2;
-            for (int cluster_n = 1; cluster_n <= 2; cluster_n++)
-            {
-                std::vector base = {// M=128
-                    CutlassTileConfigSM100::CtaShape128x128x128B, CutlassTileConfigSM100::CtaShape128x256x128B};
-
-                if (Is2SM)
-                {
-                    if (cluster_n == 1)
-                    {
-                        base.push_back(CutlassTileConfigSM100::CtaShape128x64x128B);
-                        base.push_back(CutlassTileConfigSM100::CtaShape256x64x128B);
-                    }
-
-                    std::vector twosm = {// M=256
-                        CutlassTileConfigSM100::CtaShape256x128x128B, CutlassTileConfigSM100::CtaShape256x256x128B};
-                    std::copy(twosm.begin(), twosm.end(), std::back_inserter(base));
-                }
-                else
-                {
-                    if (cluster_n == 1)
-                    {
-                        base.push_back(CutlassTileConfigSM100::CtaShape128x32x128B);
-                        if ((config & CutlassGemmConfig::FP8_ONLY) != 0)
-                        {
-                            base.push_back(CutlassTileConfigSM100::CtaShape128x16x128B);
-                        }
-                    }
-
-                    std::vector onesm{CutlassTileConfigSM100::CtaShape64x64x128B,
-                        CutlassTileConfigSM100::CtaShape64x128x128B, CutlassTileConfigSM100::CtaShape64x256x128B,
-                        CutlassTileConfigSM100::CtaShape128x64x128B};
-                    std::copy(onesm.begin(), onesm.end(), std::back_inserter(base));
-                }
+            tile_configs.push_back({CutlassTileConfigSM100::CtaShape128x16x128B, ClusterShape::ClusterShape_1x1x1});
+            // TODO: re-enable when handled by the MoE GEMM dispatch
+            // tile_configs.push_back({ CutlassTileConfigSM100::CtaShape128x8x256B, ClusterShape::ClusterShape_1x1x1 });
+        }
 
-                constexpr std::array cluster_shapes
-                    = {std::array{ClusterShape::ClusterShape_1x1x1, ClusterShape::ClusterShape_1x2x1},
-                        std::array{ClusterShape::ClusterShape_2x1x1, ClusterShape::ClusterShape_2x2x1}};
-                auto cluster = cluster_shapes[cluster_m - 1][cluster_n - 1];
-                for (auto tile : base)
-                {
-                    CutlassGemmConfig config{tile, MainloopScheduleType::AUTO, EpilogueScheduleType::AUTO, cluster};
-                    candidate_configs.push_back(config);
-                }
-            }
+        for (auto [tile, cluster] : tile_configs)
+        {
+            CutlassGemmConfig config{tile, MainloopScheduleType::AUTO, EpilogueScheduleType::AUTO, cluster};
+            candidate_configs.push_back(config);
         }
         return candidate_configs;
     }

cpp/tensorrt_llm/kernels/cutlass_kernels/include/moe_gemm_kernels.h

Lines changed: 13 additions & 31 deletions
@@ -37,11 +37,6 @@
 
 namespace tensorrt_llm::kernels::cutlass_kernels
 {
-template <class T>
-constexpr auto transpose_stride(T const& t)
-{
-    return cute::prepend(cute::prepend(cute::take<2, cute::rank_v<T>>(t), cute::get<0>(t)), cute::get<1>(t));
-}
 
 template <typename AType, typename BType, typename BScaleType, typename OType>
 struct GroupedGemmInput
@@ -72,8 +67,6 @@ struct GroupedGemmInput
 
 struct TmaWarpSpecializedGroupedGemmInput
 {
-    template <class T>
-    using TransposeStride = decltype(transpose_stride<T>(T{}));
     template <class Tag>
     using TransposeLayoutTag = std::conditional_t<std::is_same_v<Tag, cutlass::layout::RowMajor>,
         cutlass::layout::ColumnMajor, cutlass::layout::RowMajor>;
@@ -86,6 +79,7 @@ struct TmaWarpSpecializedGroupedGemmInput
     using LayoutA = TransposeLayoutTag<cutlass::layout::RowMajor>;    // Layout type for A matrix operand
     using LayoutB = TransposeLayoutTag<cutlass::layout::ColumnMajor>; // Layout type for B matrix operand
     using LayoutC = TransposeLayoutTag<cutlass::layout::RowMajor>;    // Layout type for C matrix operand
+    using LayoutD = TransposeLayoutTag<cutlass::layout::RowMajor>;    // Layout type for D matrix operand
 
     constexpr static int NVFP4BlockScaleVectorSize = 16;
     constexpr static int MXFPXBlockScaleVectorSize = 32;
@@ -121,6 +115,7 @@ struct TmaWarpSpecializedGroupedGemmInput
     using StrideB
         = std::remove_pointer_t<cutlass::detail::TagToStrideA_t<LayoutB*>>; // Use A because they will be swapped
     using StrideC = std::remove_pointer_t<cutlass::detail::TagToStrideC_t<LayoutC*>>;
+    using StrideD = std::remove_pointer_t<cutlass::detail::TagToStrideC_t<LayoutD*>>;
 
 #ifdef ENABLE_FP8
     template <class T>
@@ -147,37 +142,26 @@ struct TmaWarpSpecializedGroupedGemmInput
     StrideC* stride_c = nullptr;
     void const** ptr_c = nullptr;
 
-    struct DefaultEpilogue
-    {
-        using LayoutD = TransposeLayoutTag<cutlass::layout::RowMajor>; // Layout type for D matrix operand
-        using StrideD = std::remove_pointer_t<cutlass::detail::TagToStrideC_t<LayoutD*>>;
-
-        StrideD* stride_d = nullptr;
-        void** ptr_d = nullptr;
-    };
+    // D is used in all cases except fused finalize
+    StrideD* stride_d = nullptr;
+    void** ptr_d = nullptr;
 
     struct FusedFinalizeEpilogue
     {
-        using StrideFinalOutput = DefaultEpilogue::StrideD;
-        using StrideBias = TransposeStride<cute::Stride<cute::_0, cute::_1, int>>;
-        using StrideRouterScales = TransposeStride<cute::Stride<cute::_1, cute::_0>>;
+        using StrideFinalOutput = cutlass::detail::TagToStrideC_t<LayoutD>;
 
         void* ptr_final_output = nullptr;
         StrideFinalOutput stride_final_output{};
 
-        void const* ptr_bias = nullptr;
-        StrideBias stride_bias{};
-
-        float const* ptr_router_scales = nullptr;
-        StrideRouterScales stride_router_scales{};
+        void const** ptr_bias = nullptr;
+        float const** ptr_router_scales = nullptr;
 
-        int64_t const* ptr_expert_first_token_offset = nullptr;
-        int const* ptr_source_token_index = nullptr;
+        int const** ptr_source_token_index = nullptr;
+        int num_rows_in_final_output = 0;
 
-        size_t num_rows_in_final_output = 0;
+        bool use_reduction = true;
     };
 
-    DefaultEpilogue default_epilogue;
     FusedFinalizeEpilogue fused_finalize_epilogue;
 
     enum class EpilogueFusion
@@ -235,7 +219,7 @@ struct TmaWarpSpecializedGroupedGemmInput
     uint8_t* gemm_workspace = nullptr;
     size_t gemm_workspace_size = 0;
 
-    static std::array<size_t, 17> workspaceBuffers(int num_experts, FpXBlockScalingType scaling_type);
+    static std::array<size_t, 20> workspaceBuffers(int num_experts, FpXBlockScalingType scaling_type);
 
     static size_t workspaceSize(int num_experts, FpXBlockScalingType scaling_type);
 
@@ -247,9 +231,7 @@ struct TmaWarpSpecializedGroupedGemmInput
         return stride_a != nullptr && ptr_a != nullptr;
     }
 
-    void setFinalizeFusionParams(void* final_output, float const* router_scales,
-        int64_t const* expert_first_token_offset, int const* source_token_index, void const* bias, int hidden_size,
-        int num_output_tokens);
+    void setFinalizeFusionParams(void* final_output, int hidden_size, int num_output_tokens, bool use_reduction);
 
     std::string toString() const;
 };
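With router scales, source-token indices, bias, and the expert offset table now carried elsewhere, the fused-finalize setup call shrinks to four arguments. A minimal sketch of the new call shape follows; the buffer values are placeholders for illustration, and only the signature and the use_reduction flag come from this diff.

// Hypothetical call site; real call sites live inside the MoE runner.
TmaWarpSpecializedGroupedGemmInput gemm2_input{};
void* final_output = nullptr;        // placeholder for the [num_output_tokens, hidden_size] output buffer
int const hidden_size = 4096;        // placeholder
int const num_output_tokens = 1024;  // placeholder
// use_reduction maps to FusedFinalizeEpilogue::use_reduction (default true); presumably it
// controls whether expert contributions are reduced into the final output or simply scattered.
gemm2_input.setFinalizeFusionParams(final_output, hidden_size, num_output_tokens, /*use_reduction=*/true);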

cpp/tensorrt_llm/kernels/cutlass_kernels/include/moe_kernels.h

Lines changed: 12 additions & 8 deletions
@@ -495,7 +495,8 @@ class CutlassMoeFCRunnerInterface
         void const* weights1, void const* weights2, float const* alpha_scale_flat1, float const* alpha_scale_flat2,
         TmaWarpSpecializedGroupedGemmInput::ElementSF const* fp4_act_flat1,
         TmaWarpSpecializedGroupedGemmInput::ElementSF const* fp4_act_flat2, QuantParams quant_params, void const* bias1,
-        void const* bias2, void* gemm1_output, void* gemm2_output, cudaStream_t stream)
+        void const* bias2, void* gemm1_output, void* gemm2_output, float const* router_scales,
+        int const* permuted_row_to_unpermuted_row, cudaStream_t stream)
         = 0;
 
     virtual std::pair<TmaWarpSpecializedGroupedGemmInput, TmaWarpSpecializedGroupedGemmInput>
@@ -512,13 +513,13 @@ class CutlassMoeFCRunnerInterface
     virtual size_t getGemmWorkspaceSize(int num_experts_per_node) const = 0;
 
     bool is_profiler = false;
-    bool use_deterministic_hopper_reduce_ = false;
+    bool use_fused_finalize_ = true;
 };
 
 // Assumes inputs activations are row major. Weights need to be preprocessed by th_op/weight_quantize.cc .
 // Nested in a class to avoid multiple calls to cudaGetDeviceProperties as this call can be expensive.
 // Avoid making several duplicates of this class.
-template <typename T, /*The type used for activations*/
+template <typename T, /* The type used for activations */
     typename WeightType, /* The type for the MoE weights */
     typename OutputType = T, /* The type for the MoE final output */
     typename InputType = T, /* The type for the MoE input */
@@ -709,7 +710,8 @@ class CutlassMoeFCRunner : public CutlassMoeFCRunnerInterface
         void const* weights1, void const* weights2, float const* alpha_scale_flat1, float const* alpha_scale_flat2,
         TmaWarpSpecializedGroupedGemmInput::ElementSF const* fp4_act_flat1,
         TmaWarpSpecializedGroupedGemmInput::ElementSF const* fp4_act_flat2, QuantParams quant_params, void const* bias1,
-        void const* bias2, void* gemm1_output, void* gemm2_output, cudaStream_t stream) override
+        void const* bias2, void* gemm1_output, void* gemm2_output, float const* router_scales,
+        int const* permuted_row_to_unpermuted_row, cudaStream_t stream) override
     {
         return Self::computeStridesTmaWarpSpecialized(expert_first_token_offset, layout_info1, layout_info2, num_tokens,
             expanded_num_tokens, gemm1_n, gemm1_k, gemm2_n, gemm2_k, num_experts_per_node,
@@ -718,7 +720,8 @@ class CutlassMoeFCRunner : public CutlassMoeFCRunnerInterface
             alpha_scale_flat1, alpha_scale_flat2, fp4_act_flat1, fp4_act_flat2, quant_params,
             reinterpret_cast<ScaleBiasType const*>(bias1), reinterpret_cast<ScaleBiasType const*>(bias2),
             reinterpret_cast<UnfusedGemmOutputType*>(gemm1_output),
-            reinterpret_cast<UnfusedGemmOutputType*>(gemm2_output), stream);
+            reinterpret_cast<UnfusedGemmOutputType*>(gemm2_output), router_scales, permuted_row_to_unpermuted_row,
+            stream);
     }
 
     std::pair<TmaWarpSpecializedGroupedGemmInput, TmaWarpSpecializedGroupedGemmInput>
@@ -760,7 +763,8 @@ class CutlassMoeFCRunner : public CutlassMoeFCRunnerInterface
         float const* alpha_scale_flat2, TmaWarpSpecializedGroupedGemmInput::ElementSF const* fp4_act_flat1,
         TmaWarpSpecializedGroupedGemmInput::ElementSF const* fp4_act_flat2, QuantParams quant_params,
         ScaleBiasType const* bias1, ScaleBiasType const* bias2, UnfusedGemmOutputType* gemm1_output,
-        UnfusedGemmOutputType* gemm2_output, cudaStream_t stream);
+        UnfusedGemmOutputType* gemm2_output, float const* router_scales, int const* permuted_row_to_unpermuted_row,
+        cudaStream_t stream);
     static std::pair<TmaWarpSpecializedGroupedGemmInput, TmaWarpSpecializedGroupedGemmInput>
     computeStridesTmaWarpSpecializedLowLatency(TmaWarpSpecializedGroupedGemmInput layout_info1,
         TmaWarpSpecializedGroupedGemmInput layout_info2, int64_t num_tokens, int64_t gemm1_n, int64_t gemm1_k,
@@ -790,8 +794,8 @@ class CutlassMoeFCRunner : public CutlassMoeFCRunnerInterface
 
     bool mayHaveFinalizeFused() const
     {
-        return moe_gemm_runner_.supportsTmaWarpSpecialized() && moe_gemm_runner_.getSM() == 90
-            && !use_deterministic_hopper_reduce_ && !use_w4_groupwise;
+        return moe_gemm_runner_.supportsTmaWarpSpecialized() && moe_gemm_runner_.getSM() >= 90 && use_fused_finalize_
+            && !use_w4_groupwise;
     }
 
     // TODO: This should eventually take the quant params to give more flexibility
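Note that mayHaveFinalizeFused() now keys off the renamed use_fused_finalize_ flag and accepts SM90 and newer rather than SM90 only. This diff does not show where the flag is populated; one plausible wiring to the new environment switch, purely as an illustrative sketch, would be:

// Hypothetical wiring (not shown in this diff): disable the fused path via the new env switch.
// `runner` stands for a CutlassMoeFCRunner instance; only use_fused_finalize_,
// mayHaveFinalizeFused(), and getEnvMOEDisableFinalizeFusion() come from this change.
runner.use_fused_finalize_ = !tensorrt_llm::common::getEnvMOEDisableFinalizeFusion();
bool const may_fuse = runner.mayHaveFinalizeFused();
// may_fuse additionally requires TMA warp-specialized support, SM >= 90, and no W4 groupwise weights.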
