@@ -21,7 +21,8 @@
 #include "tensorrt_llm/common/envUtils.h"
 #include "trtllmGen_bmm_export/BatchedGemmInterface.h"
 #include "trtllmGen_bmm_export/trtllm/gen/DtypeDecl.h"
-// DO NOT include logger.h before BatchedGemmInterface.h as it #undef TLLM_LOG_INFO and co.
+// DO NOT include cudaUtils.h and logger.h before BatchedGemmInterface.h as it #undef TLLM_LOG_INFO and co.
+#include "tensorrt_llm/common/cudaUtils.h"
 #include "tensorrt_llm/common/logger.h"

 namespace tensorrt_llm
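A note on the include-order comment above: BatchedGemmInterface.h #undefs TLLM_LOG_INFO and the related logging macros, so the common headers that provide them (logger.h, and now cudaUtils.h, which presumably pulls the logging macros in transitively) must be included after it. A minimal single-file sketch of the hazard, with stand-in macros rather than the real headers:

#include <cstdio>

// Stand-in for logger.h: defines the logging macro.
#define TLLM_LOG_INFO(...) std::printf(__VA_ARGS__)

// Stand-in for BatchedGemmInterface.h: the exported kernel header #undefs the
// macro, clobbering the definition for the rest of the translation unit.
#undef TLLM_LOG_INFO

int main()
{
    // TLLM_LOG_INFO("hello\n"); // would fail to compile here: the macro is gone.
    // Including the logger *after* the interface header leaves the macro intact.
    return 0;
}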
@@ -306,6 +307,8 @@ std::vector<int64_t> TrtllmGenBatchedGemmRunner::getValidConfigIndices(int32_t m
     auto const bmm = BatchedGemmInterface();
     auto const configs = bmm.getBatchedGemmConfigs();

+    int32_t multiProcessorCount = tensorrt_llm::common::getMultiProcessorCount();
+
     BatchedGemmData gemmData;
     // Dims
     gemmData.mProblemDimensions.mNumBatches = numBatches;
@@ -319,73 +322,68 @@ std::vector<int64_t> TrtllmGenBatchedGemmRunner::getValidConfigIndices(int32_t m
     gemmData.mProblemDimensions.mRank = 0;
     gemmData.mProblemDimensions.mWorldSize = 1;
     gemmData.mProblemDimensions.mMaxNumCtasInTokenDim = maxNumCtasInBatchDim;
-    // Tier 0: K < tileK, prefer higher efficiency.
-    auto cmpTier0 = [&configs, &gemmData](int64_t idx0, int64_t idx1)
+    auto cmpFunc = [&configs, &gemmData, &bmm, &multiProcessorCount](int64_t idx0, int64_t idx1)
     {
         auto const& optionsA = configs[idx0].mOptions;
         auto const& optionsB = configs[idx1].mOptions;
         int32_t sizeK = gemmData.mProblemDimensions.mK;
-        // Both waste computation, prefer higher efficiency.
-        if (sizeK <= optionsA.mTileK && sizeK <= optionsB.mTileK)
-        {
-            double eff_a = (double) sizeK / optionsA.mTileK;
-            double eff_b = (double) sizeK / optionsB.mTileK;
-            return eff_a > eff_b;
-        }
-        // If either can be utilized, sort by tileK.
-        else
+
+        // Tier 0: K < tileK, prefer higher efficiency.
+        if (optionsA.mTileK != optionsB.mTileK)
         {
-            return optionsA.mTileK > optionsB.mTileK;
+            // Both waste computation, prefer higher efficiency.
+            if (sizeK <= optionsA.mTileK && sizeK <= optionsB.mTileK)
+            {
+                double eff_a = (double) sizeK / optionsA.mTileK;
+                double eff_b = (double) sizeK / optionsB.mTileK;
+                return eff_a > eff_b;
+            }
+            // If either can be utilized, sort by tileK.
+            else
+            {
+                return optionsA.mTileK > optionsB.mTileK;
+            }
         }
-    };
-    // Tier 1: When tileK is the same, prefer unroll loop 2x for mma.
-    auto cmpTier1 = [&configs](int64_t idx0, int64_t idx1)
-    {
-        auto const& optionsA = configs[idx0].mOptions;
-        auto const& optionsB = configs[idx1].mOptions;
-        if (optionsA.mTileK == optionsB.mTileK)
+
+        // Tier 1: When tileK is the same, prefer unroll loop 2x for mma.
+        if (optionsA.mUseUnrollLoop2xForMma != optionsB.mUseUnrollLoop2xForMma)
         {
             return optionsA.mUseUnrollLoop2xForMma;
         }
-        return false;
-    };
-    // Tier 2+: When previous comparators are the same, prefer higher tileM.
-    auto cmpTier2 = [&configs](int64_t idx0, int64_t idx1)
-    {
-        auto const& optionsA = configs[idx0].mOptions;
-        auto const& optionsB = configs[idx1].mOptions;
-        if (optionsA.mTileK == optionsB.mTileK && optionsA.mUseUnrollLoop2xForMma == optionsB.mUseUnrollLoop2xForMma)
+
+        // Tier 2+: When previous comparators are the same, prefer higher tileM.
+        if (optionsA.mTileM != optionsB.mTileM)
         {
             return optionsA.mTileM > optionsB.mTileM;
         }
-        return false;
-    };
-    // Tier 2+: When previous comparators are the same, and when number of estimated CTAs is on the larger side, prefer
-    // persistent tile scheduler. The threshold is hardcoded as >148 CTAs at the moment.
-    auto cmpTier3 = [&configs, &gemmData](int64_t idx0, int64_t idx1)
-    {
-        int32_t sizeM = gemmData.mProblemDimensions.mM;
-        int32_t sizeN = gemmData.mProblemDimensions.mN;
-        auto const& optionsA = configs[idx0].mOptions;
-        auto const& optionsB = configs[idx1].mOptions;
-        if (optionsA.mTileK == optionsB.mTileK && optionsA.mUseUnrollLoop2xForMma == optionsB.mUseUnrollLoop2xForMma
-            && optionsA.mTileM == optionsB.mTileM)
+
+        // Tier 2+: When previous comparators are the same, prefer higher tileN.
+        if (optionsA.mTileN != optionsB.mTileN)
+        {
+            return optionsA.mTileN > optionsB.mTileN;
+        }
+
+        // Tier 2+: When previous comparators are the same, and when the number of estimated CTAs is on the larger side,
+        // prefer persistent tile scheduler.
+        if (optionsA.mTileScheduler != optionsB.mTileScheduler)
         {
-            int64_t numTilesM = divUp(sizeM, optionsA.mTileM);
-            int64_t numTilesN = divUp(sizeN, optionsA.mTileN);
-            if (numTilesM * numTilesN > 148)
+            auto options = bmm.getOptionsFromConfigAndData(configs[idx0], gemmData);
+            auto numCtas = bmm.getNumCtas(options, gemmData.mProblemDimensions.mMaxNumCtasInTokenDim);
+            if (numCtas > multiProcessorCount)
             {
                 return optionsA.mTileScheduler == batchedGemm::gemm::TileScheduler::Persistent;
             }
+            else
+            {
+                return optionsB.mTileScheduler == batchedGemm::gemm::TileScheduler::Persistent;
+            }
         }
+
         return false;
     };
     // Sort configs by options.
     std::vector<int64_t> sortedIndices = mPassingConfigIndices;
-    std::sort(sortedIndices.begin(), sortedIndices.end(), cmpTier0);
-    std::sort(sortedIndices.begin(), sortedIndices.end(), cmpTier1);
-    std::sort(sortedIndices.begin(), sortedIndices.end(), cmpTier2);
-    std::sort(sortedIndices.begin(), sortedIndices.end(), cmpTier3);
+    std::sort(sortedIndices.begin(), sortedIndices.end(), cmpFunc);

     // Special rules for corner cases, if applicable.
     std::vector<int64_t> prioritizedIndices = prioritizePredefinedConfigs(m, n, k, sortedIndices, configs);
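A note on the refactor above: the old code sorted `sortedIndices` four times with independent per-tier comparators, but `std::sort` is not stable, so each later pass may scramble the ordering established by the earlier ones. Merging the tiers into one lexicographic comparator, where each tier decides the order only when every higher-priority field compares equal, makes a single sort realize the full priority chain. A minimal sketch of the pattern, using a hypothetical `Cfg` struct rather than the real `BatchedGemmConfig`:

#include <algorithm>
#include <cstdint>
#include <vector>

// Hypothetical stand-in for the real config options; the fields mirror the
// tie-breaking tiers in the diff (tileK, the 2x-unroll flag, tileM, tileN).
struct Cfg
{
    int32_t tileK;
    bool unroll2x;
    int32_t tileM;
    int32_t tileN;
};

// One comparator, evaluated tier by tier. Equal elements fall through to
// `false`, which keeps this a strict weak ordering as std::sort requires.
bool cmpTiers(Cfg const& a, Cfg const& b)
{
    if (a.tileK != b.tileK)
    {
        return a.tileK > b.tileK; // Tier 0: larger tileK first.
    }
    if (a.unroll2x != b.unroll2x)
    {
        return a.unroll2x; // Tier 1: unrolled configs first.
    }
    if (a.tileM != b.tileM)
    {
        return a.tileM > b.tileM; // Tier 2: larger tileM first.
    }
    return a.tileN > b.tileN; // Tier 3: larger tileN first; equal -> false.
}

int main()
{
    std::vector<Cfg> cfgs{{128, false, 64, 64}, {128, true, 64, 128}, {256, false, 128, 64}};
    std::sort(cfgs.begin(), cfgs.end(), cmpTiers);
    // cfgs is now ordered by (tileK desc, unroll2x first, tileM desc, tileN desc).
    return 0;
}

The scheduler tier also changes behavior: instead of the hardcoded `> 148` CTA threshold, the comparator now asks the library for the estimated CTA count and compares it against the device's actual SM count from `getMultiProcessorCount()`, so the persistent scheduler is preferred exactly when the estimated grid oversubscribes the GPU.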