
Commit 6e57bb0

Merge branch 'main' into clean_cuda_graph

2 parents 84ad219 + e30d9ac

22 files changed: +1334 additions, -273 deletions

README.md

Lines changed: 2 additions & 2 deletions
@@ -7,8 +7,8 @@ TensorRT-LLM
 [![Documentation](https://img.shields.io/badge/docs-latest-brightgreen.svg?style=flat)](https://nvidia.github.io/TensorRT-LLM/)
 [![python](https://img.shields.io/badge/python-3.12-green)](https://www.python.org/downloads/release/python-3123/)
 [![python](https://img.shields.io/badge/python-3.10-green)](https://www.python.org/downloads/release/python-31012/)
-[![cuda](https://img.shields.io/badge/cuda-12.9.1-green)](https://developer.nvidia.com/cuda-downloads)
-[![trt](https://img.shields.io/badge/TRT-10.11.0-green)](https://developer.nvidia.com/tensorrt)
+[![cuda](https://img.shields.io/badge/cuda-13.0.0-green)](https://developer.nvidia.com/cuda-downloads)
+[![trt](https://img.shields.io/badge/TRT-10.13.2-green)](https://developer.nvidia.com/tensorrt)
 [![version](https://img.shields.io/badge/release-1.1.0rc6-green)](./tensorrt_llm/version.py)
 [![license](https://img.shields.io/badge/license-Apache%202-blue)](./LICENSE)

cpp/tensorrt_llm/deep_ep/CMakeLists.txt

Lines changed: 1 addition & 1 deletion
@@ -1,4 +1,4 @@
-set(DEEP_EP_COMMIT 515a311f290eb6d9592fcccfcc80c40f5123ca72)
+set(DEEP_EP_COMMIT be2582ffe69b5e7d61c3bc9bf7a5316bc48261f9)
 set(NVSHMEM_URL_HASH
     SHA256=eb2c8fb3b7084c2db86bd9fd905387909f1dfd483e7b45f7b3c3d5fcf5374b5a)

cpp/tensorrt_llm/kernels/cutlass_kernels/include/moe_kernels.h

Lines changed: 1 addition & 0 deletions
@@ -553,6 +553,7 @@ class CutlassMoeFCRunner : public CutlassMoeFCRunnerInterface
         || std::is_same_v<T, __nv_fp8_e5m2>) &&!std::is_same_v<WeightType, cutlass::uint4b_t>;
     static constexpr bool use_w4afp8
         = std::is_same_v<WeightType, cutlass::uint4b_t> && std::is_same_v<T, __nv_fp8_e4m3>;
+    static constexpr bool use_fp8_input = std::is_same_v<InputType, __nv_fp8_e4m3>;
     static_assert(!std::is_same_v<BackBoneType, __nv_fp8_e4m3>, "Current logic requires backbone type to be >=16-bits");
     static_assert(!std::is_same_v<OutputType, __nv_fp8_e4m3>, "Current logic requires output type to be >=16-bits");
 #else

cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_kernels.cu

Lines changed: 3 additions & 2 deletions
@@ -1625,7 +1625,7 @@ void expandInputRowsKernelLauncher(InputActivationsType const* unpermuted_input,
     else if constexpr (std::is_same_v<ExpandedActivationsType, __nv_fp8_e4m3>
         && std::is_same_v<InputActivationsType, __nv_fp8_e4m3>)
     {
-        TLLM_CHECK_WITH_INFO(!prequant_scales, "NVFP4 is not supported for AWQ");
+        TLLM_CHECK_WITH_INFO(!prequant_scales, "FP8 is not supported for AWQ");
         return quant_params.mxfp8_mxfp4.fc1.weight_block_scale
             ? &expandInputRowsKernel<InputActivationsType, ExpandedActivationsType,
                 TmaWarpSpecializedGroupedGemmInput::FpXBlockScalingType::MXFPX, false>
@@ -3689,7 +3689,7 @@ void CutlassMoeFCRunner<T, WeightType, OutputType, InputType, BackBoneType, Enab
         permuted_token_final_scales_, permuted_row_to_unpermuted_row_, num_rows, hidden_size, experts_per_token,
         num_experts_per_node, quant_params, use_per_expert_act_scale, expert_first_token_offset_,
         fc1_fp4_act_scale_, input_sf, swizzled_input_sf,
-        use_w4afp8 ? quant_params.groupwise.fc1.act_scales : nullptr, stream);
+        (use_w4afp8 && !use_fp8_input) ? quant_params.groupwise.fc1.act_scales : nullptr, stream);
     auto const* gemm1_input = gemm1_input_expand;
 
     sync_check_cuda_error(stream);
@@ -4755,6 +4755,7 @@ template class CutlassMoeFCRunner<__nv_fp8_e4m3, cutlass::uint4b_t, half, half>;
 template class CutlassMoeFCRunner<__nv_fp8_e4m3, __nv_fp8_e4m3, __nv_bfloat16>;
 template class CutlassMoeFCRunner<__nv_bfloat16, __nv_fp8_e4m3, __nv_bfloat16>;
 template class CutlassMoeFCRunner<__nv_fp8_e4m3, cutlass::uint4b_t, __nv_bfloat16, __nv_bfloat16>;
+template class CutlassMoeFCRunner<__nv_fp8_e4m3, cutlass::uint4b_t, __nv_bfloat16, __nv_fp8_e4m3>;
 #endif
 #endif
 #ifdef ENABLE_FP4
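
Note on the change above: the new use_fp8_input trait declared in moe_kernels.h combines with use_w4afp8 to gate the groupwise fc1 activation scales, so a W4A8 runner whose input is already FP8 no longer forwards act scales to the expand kernel. Below is a minimal standalone sketch of that compile-time gate; the types fp8_e4m3_t and uint4b_t are hypothetical stand-ins for __nv_fp8_e4m3 and cutlass::uint4b_t, not the real CUDA/CUTLASS types.

// Minimal sketch (stand-in types) of the (use_w4afp8 && !use_fp8_input) gate used above.
#include <iostream>
#include <type_traits>

// Hypothetical stand-ins for __nv_fp8_e4m3 and cutlass::uint4b_t.
struct fp8_e4m3_t {};
struct uint4b_t {};

template <typename T, typename WeightType, typename InputType>
struct MoeTraits
{
    static constexpr bool use_w4afp8
        = std::is_same_v<WeightType, uint4b_t> && std::is_same_v<T, fp8_e4m3_t>;
    static constexpr bool use_fp8_input = std::is_same_v<InputType, fp8_e4m3_t>;

    // Activation scales are only needed when the input still has to be quantized on the fly,
    // i.e. when it is not FP8 yet.
    static constexpr bool pass_act_scales = use_w4afp8 && !use_fp8_input;
};

int main()
{
    // Non-FP8 input (float as a stand-in) into a W4A8 runner: scales are forwarded.
    std::cout << MoeTraits<fp8_e4m3_t, uint4b_t, float>::pass_act_scales << '\n';      // prints 1
    // FP8 input into a W4A8 runner (the newly instantiated combination): they are not.
    std::cout << MoeTraits<fp8_e4m3_t, uint4b_t, fp8_e4m3_t>::pass_act_scales << '\n'; // prints 0
}

With these stand-ins, the non-FP8-input instantiation prints 1 (scales forwarded) and the new FP8-input instantiation prints 0 (scales dropped), mirroring the call-site change above.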

cpp/tensorrt_llm/kernels/trtllmGenKernels/batchedGemm/KernelRunner.cpp

Lines changed: 45 additions & 47 deletions
@@ -21,7 +21,8 @@
 #include "tensorrt_llm/common/envUtils.h"
 #include "trtllmGen_bmm_export/BatchedGemmInterface.h"
 #include "trtllmGen_bmm_export/trtllm/gen/DtypeDecl.h"
-// DO NOT include logger.h before BatchedGemmInterface.h as it #undef TLLM_LOG_INFO and co.
+// DO NOT include cudaUtils.h and logger.h before BatchedGemmInterface.h as it #undef TLLM_LOG_INFO and co.
+#include "tensorrt_llm/common/cudaUtils.h"
 #include "tensorrt_llm/common/logger.h"
 
 namespace tensorrt_llm
@@ -306,6 +307,8 @@ std::vector<int64_t> TrtllmGenBatchedGemmRunner::getValidConfigIndices(int32_t m
     auto const bmm = BatchedGemmInterface();
     auto const configs = bmm.getBatchedGemmConfigs();
 
+    int32_t multiProcessorCount = tensorrt_llm::common::getMultiProcessorCount();
+
     BatchedGemmData gemmData;
     // Dims
     gemmData.mProblemDimensions.mNumBatches = numBatches;
@@ -319,73 +322,68 @@ std::vector<int64_t> TrtllmGenBatchedGemmRunner::getValidConfigIndices(int32_t m
     gemmData.mProblemDimensions.mRank = 0;
     gemmData.mProblemDimensions.mWorldSize = 1;
     gemmData.mProblemDimensions.mMaxNumCtasInTokenDim = maxNumCtasInBatchDim;
-    // Tier 0: K < tileK, prefer higher efficiency.
-    auto cmpTier0 = [&configs, &gemmData](int64_t idx0, int64_t idx1)
+    auto cmpFunc = [&configs, &gemmData, &bmm, &multiProcessorCount](int64_t idx0, int64_t idx1)
     {
         auto const& optionsA = configs[idx0].mOptions;
         auto const& optionsB = configs[idx1].mOptions;
         int32_t sizeK = gemmData.mProblemDimensions.mK;
-        // Both waste computation, prefer higher efficiency.
-        if (sizeK <= optionsA.mTileK && sizeK <= optionsB.mTileK)
-        {
-            double eff_a = (double) sizeK / optionsA.mTileK;
-            double eff_b = (double) sizeK / optionsB.mTileK;
-            return eff_a > eff_b;
-        }
-        // If either can be utilized, sort by tileK.
-        else
+
+        // Tier 0: K < tileK, prefer higher efficiency.
+        if (optionsA.mTileK != optionsB.mTileK)
        {
-            return optionsA.mTileK > optionsB.mTileK;
+            // Both waste computation, prefer higher efficiency.
+            if (sizeK <= optionsA.mTileK && sizeK <= optionsB.mTileK)
+            {
+                double eff_a = (double) sizeK / optionsA.mTileK;
+                double eff_b = (double) sizeK / optionsB.mTileK;
+                return eff_a > eff_b;
+            }
+            // If either can be utilized, sort by tileK.
+            else
+            {
+                return optionsA.mTileK > optionsB.mTileK;
+            }
         }
-    };
-    // Tier 1: When tileK is the same, prefer unroll loop 2x for mma.
-    auto cmpTier1 = [&configs](int64_t idx0, int64_t idx1)
-    {
-        auto const& optionsA = configs[idx0].mOptions;
-        auto const& optionsB = configs[idx1].mOptions;
-        if (optionsA.mTileK == optionsB.mTileK)
+
+        // Tier 1: When tileK is the same, prefer unroll loop 2x for mma.
+        if (optionsA.mUseUnrollLoop2xForMma != optionsB.mUseUnrollLoop2xForMma)
         {
             return optionsA.mUseUnrollLoop2xForMma;
         }
-        return false;
-    };
-    // Tier 2+: When previous comparators are the same, prefer higher tileM.
-    auto cmpTier2 = [&configs](int64_t idx0, int64_t idx1)
-    {
-        auto const& optionsA = configs[idx0].mOptions;
-        auto const& optionsB = configs[idx1].mOptions;
-        if (optionsA.mTileK == optionsB.mTileK && optionsA.mUseUnrollLoop2xForMma == optionsB.mUseUnrollLoop2xForMma)
+
+        // Tier 2+: When previous comparators are the same, prefer higher tileM.
+        if (optionsA.mTileM != optionsB.mTileM)
         {
             return optionsA.mTileM > optionsB.mTileM;
         }
-        return false;
-    };
-    // Tier 2+: When previous comparators are the same, and when number of estimated CTAs is on the larger side, prefer
-    // persistent tile scheduler. The threshold is hardcoded as >148 CTAs at the moment.
-    auto cmpTier3 = [&configs, &gemmData](int64_t idx0, int64_t idx1)
-    {
-        int32_t sizeM = gemmData.mProblemDimensions.mM;
-        int32_t sizeN = gemmData.mProblemDimensions.mN;
-        auto const& optionsA = configs[idx0].mOptions;
-        auto const& optionsB = configs[idx1].mOptions;
-        if (optionsA.mTileK == optionsB.mTileK && optionsA.mUseUnrollLoop2xForMma == optionsB.mUseUnrollLoop2xForMma
-            && optionsA.mTileM == optionsB.mTileM)
+
+        // Tier 2+: When previous comparators are the same, prefer higher tileN.
+        if (optionsA.mTileN != optionsB.mTileN)
+        {
+            return optionsA.mTileN > optionsB.mTileN;
+        }
+
+        // Tier 2+: When previous comparators are the same, and when the number of estimated CTAs is on the larger side,
+        // prefer persistent tile scheduler.
+        if (optionsA.mTileScheduler != optionsB.mTileScheduler)
         {
-            int64_t numTilesM = divUp(sizeM, optionsA.mTileM);
-            int64_t numTilesN = divUp(sizeN, optionsA.mTileN);
-            if (numTilesM * numTilesN > 148)
+            auto options = bmm.getOptionsFromConfigAndData(configs[idx0], gemmData);
+            auto numCtas = bmm.getNumCtas(options, gemmData.mProblemDimensions.mMaxNumCtasInTokenDim);
+            if (numCtas > multiProcessorCount)
             {
                 return optionsA.mTileScheduler == batchedGemm::gemm::TileScheduler::Persistent;
             }
+            else
+            {
+                return optionsB.mTileScheduler == batchedGemm::gemm::TileScheduler::Persistent;
+            }
         }
+
         return false;
     };
     // Sort configs by options.
     std::vector<int64_t> sortedIndices = mPassingConfigIndices;
-    std::sort(sortedIndices.begin(), sortedIndices.end(), cmpTier0);
-    std::sort(sortedIndices.begin(), sortedIndices.end(), cmpTier1);
-    std::sort(sortedIndices.begin(), sortedIndices.end(), cmpTier2);
-    std::sort(sortedIndices.begin(), sortedIndices.end(), cmpTier3);
+    std::sort(sortedIndices.begin(), sortedIndices.end(), cmpFunc);
 
     // Special rules for corner cases, if applicable.
     std::vector<int64_t> prioritizedIndices = prioritizePredefinedConfigs(m, n, k, sortedIndices, configs);
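
Note on the change above: the four separate cmpTier* sort passes are folded into a single comparator, and the hardcoded >148-CTA threshold is replaced by comparing the estimated CTA count (bmm.getNumCtas on options built via getOptionsFromConfigAndData) against the device's SM count from getMultiProcessorCount. A condensed sketch of the resulting tier ordering follows; the TileOptions struct, the sortConfigs helper, and the single problem-level estimatedNumCtas parameter are simplifications for a self-contained example, not the real interfaces.

// Condensed sketch of the consolidated comparator (stand-in structs; the real code builds
// options via bmm.getOptionsFromConfigAndData(), estimates CTAs with bmm.getNumCtas(), and
// reads the SM count from tensorrt_llm::common::getMultiProcessorCount()).
#include <algorithm>
#include <cstdint>
#include <vector>

enum class TileScheduler { Static, Persistent };

struct TileOptions
{
    int32_t tileM, tileN, tileK;
    bool useUnrollLoop2xForMma;
    TileScheduler tileScheduler;
};

std::vector<int64_t> sortConfigs(
    std::vector<TileOptions> const& configs, int32_t sizeK, int32_t estimatedNumCtas, int32_t smCount)
{
    std::vector<int64_t> indices(configs.size());
    for (size_t i = 0; i < indices.size(); ++i)
        indices[i] = static_cast<int64_t>(i);

    auto cmp = [&](int64_t a, int64_t b)
    {
        auto const& oa = configs[a];
        auto const& ob = configs[b];
        // Tier 0: when both tiles exceed K, prefer the one wasting less work; otherwise the larger tileK.
        if (oa.tileK != ob.tileK)
        {
            if (sizeK <= oa.tileK && sizeK <= ob.tileK)
                return (double) sizeK / oa.tileK > (double) sizeK / ob.tileK;
            return oa.tileK > ob.tileK;
        }
        // Tier 1: prefer the 2x-unrolled MMA loop.
        if (oa.useUnrollLoop2xForMma != ob.useUnrollLoop2xForMma)
            return oa.useUnrollLoop2xForMma;
        // Tier 2+: prefer larger tileM, then larger tileN.
        if (oa.tileM != ob.tileM)
            return oa.tileM > ob.tileM;
        if (oa.tileN != ob.tileN)
            return oa.tileN > ob.tileN;
        // Last tier: prefer the persistent scheduler only when the grid oversubscribes the SMs.
        if (oa.tileScheduler != ob.tileScheduler)
        {
            bool oversubscribed = estimatedNumCtas > smCount;
            return oversubscribed ? oa.tileScheduler == TileScheduler::Persistent
                                  : ob.tileScheduler == TileScheduler::Persistent;
        }
        return false; // equivalent under every tier
    };
    std::sort(indices.begin(), indices.end(), cmp);
    return indices;
}

A single problem-level CTA estimate is used here only to keep the sketch self-contained; the real comparator queries the estimate per config, and configs that tie on every earlier tier already share the same tile shape.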

cpp/tensorrt_llm/kernels/trtllmGenKernels/batchedGemm/trtllmGen_bmm_export/BatchedGemmInterface.h

Lines changed: 3 additions & 2 deletions
@@ -524,12 +524,13 @@ class BatchedGemmInterface
     // Returns true if the configuration of the cubin can be executed for the given params.
     bool isValidConfig(BatchedGemmConfig const& config, BatchedGemmData const& data) const;
 
+    // Creates GemmOptions from kernel and data.
+    BatchedGemmOptions getOptionsFromConfigAndData(BatchedGemmConfig const& config, BatchedGemmData const& data) const;
+
 private:
     // Aligns the pointer to the alignment
     template <typename Dtype>
     inline Dtype* alignPtr(Dtype* ptr, int64_t alignment) const;
-    // Creates GemmOptions from kernel and data.
-    BatchedGemmOptions getOptionsFromConfigAndData(BatchedGemmConfig const& config, BatchedGemmData const& data) const;
 
     // Returns the size of the workspace buffers in bytes
     std::vector<size_t> getWorkspaceSizesInBytes(BatchedGemmConfig const& config, BatchedGemmData const& data) const;

cpp/tensorrt_llm/thop/moeOp.cpp

Lines changed: 15 additions & 0 deletions
@@ -201,6 +201,21 @@ class FusedMoeRunner : public torch::CustomClassHolder
         }
         switch (mActivationDtype)
         {
+#ifdef ENABLE_FP8
+        case c10::ScalarType::Float8_e4m3fn:
+        {
+            if (isInt4Quant() and mUseW4GroupScaling)
+            {
+                mKernelRunner = std::make_unique<
+                    kernels::CutlassMoeFCRunner<__nv_fp8_e4m3, cutlass::uint4b_t, __nv_bfloat16, __nv_fp8_e4m3>>();
+            }
+            else
+            {
+                C10_THROW_ERROR_FORMATTED(Error, "FP8 activation type is not supported for non-W4A8 quantization");
+            }
+            break;
+        }
+#endif
         case c10::ScalarType::Half: mKernelRunner = create_weight_quant_runner<half>(); break;
         case c10::ScalarType::BFloat16: mKernelRunner = create_weight_quant_runner<__nv_bfloat16>(); break;
         default: C10_THROW_ERROR_FORMATTED(Error, "Unsupported activation type for int-type weight");
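
Note on the change above: FP8 activations are now accepted, but only on the W4A8 group-scaling path, where the constructor picks the FP8-input CutlassMoeFCRunner instantiation added in moe_kernels.cu; every other quantization mode still rejects FP8 activations. Below is a condensed sketch of that dispatch rule; the ActivationDtype enum, the pickWeightQuantRunner helper, and the returned strings are hypothetical stand-ins for the c10 scalar types and the templated runner instantiations.

// Condensed sketch of the new dispatch rule (stand-in enum and strings, not the real torch op).
#include <stdexcept>
#include <string>

enum class ActivationDtype { Half, BFloat16, Fp8E4m3 };

// Hypothetical helper mirroring the switch in FusedMoeRunner's constructor.
std::string pickWeightQuantRunner(ActivationDtype act, bool isInt4Quant, bool useW4GroupScaling)
{
    switch (act)
    {
    case ActivationDtype::Fp8E4m3:
        // FP8 activations are only accepted on the W4A8 group-scaling path, which maps to
        // CutlassMoeFCRunner<__nv_fp8_e4m3, cutlass::uint4b_t, __nv_bfloat16, __nv_fp8_e4m3>.
        if (isInt4Quant && useW4GroupScaling)
            return "w4a8_group_scaling_runner_with_fp8_input";
        throw std::invalid_argument("FP8 activation type is not supported for non-W4A8 quantization");
    case ActivationDtype::Half: return "weight_quant_runner<half>";
    case ActivationDtype::BFloat16: return "weight_quant_runner<bfloat16>";
    }
    throw std::invalid_argument("Unsupported activation type for int-type weight");
}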
(Three additional large files, 4.43 MB, 933 KB, and 255 KB, are not rendered in this view.)
