@@ -55,11 +55,12 @@ TrtllmGenBatchedGemmRunner::TrtllmGenBatchedGemmRunner(TrtllmGenBatchedGemmRunne
5555 }
5656 }
5757
58- TLLM_CHECK_WITH_INFO (mPassingConfigIndices .size () != 0 , " No kernel found for the given output type " );
58+ TLLM_CHECK_WITH_INFO (! mPassingConfigIndices .empty () , " No kernel found for the given options " );
5959}
6060
6161size_t TrtllmGenBatchedGemmRunner::getWorkspaceSizeInBytes (int32_t m, int32_t n, int32_t k,
62- std::vector<int32_t > const & batchedTokens, int32_t numTokens, int32_t numBatches, int32_t maxNumCtasInBatchDim)
62+ std::vector<int32_t > const & batchedTokens, int32_t numTokens, int32_t numBatches, int32_t maxNumCtasInBatchDim,
63+ std::optional<int32_t > configIndex)
6364{
6465 BatchedGemmData gemmData;
6566 gemmData.mProblemDimensions .mNumBatches = numBatches;
@@ -74,13 +75,18 @@ size_t TrtllmGenBatchedGemmRunner::getWorkspaceSizeInBytes(int32_t m, int32_t n,
7475 gemmData.mProblemDimensions .mWorldSize = 1 ;
7576 gemmData.mProblemDimensions .mMaxNumCtasInTokenDim = maxNumCtasInBatchDim;
7677
77- selectGemmConfig (m, n, k, batchedTokens, numTokens, numBatches, maxNumCtasInBatchDim);
78-
7978 auto bmm = BatchedGemmInterface ();
79+
8080 auto const configs = bmm.getBatchedGemmConfigs ();
81- TLLM_CHECK_WITH_INFO (
82- mSelectedConfigIndex .has_value (), " No valid kernel found for given param config and problem size" );
83- auto const & config = configs[mSelectedConfigIndex .value ()];
81+
82+ if (!configIndex.has_value ())
83+ {
84+ mSelectedConfigIndex
85+ = getDefaultValidConfigIndex (m, n, k, batchedTokens, numTokens, numBatches, maxNumCtasInBatchDim);
86+ configIndex = mSelectedConfigIndex ;
87+ }
88+
89+ auto const & config = configs[configIndex.value ()];
8490 return bmm.getWorkspaceSizeInBytes (config, gemmData);
8591}
8692
@@ -89,16 +95,22 @@ void TrtllmGenBatchedGemmRunner::run(int32_t m, int32_t n, int32_t k, std::vecto
8995 void const * sfB, void const * perTokensSfA, void const * perTokensSfB, float const * scaleC, float const * scaleGateC,
9096 void * c, void * outSfC, int32_t const * routeMap, int32_t const * totalNumPaddedTokens,
9197 int32_t const * ctaIdxXyToBatchIdx, int32_t const * ctaIdxXyToMnLimit, int32_t const * numNonExitingCtas,
92- void * workspace, CUstream stream, int device)
98+ void * workspace, CUstream stream, int device, std::optional< int32_t > configIndex )
9399{
94100 auto bmm = BatchedGemmInterface ();
95101
96102 BatchedGemmData gemmData;
97103
98104 auto const configs = bmm.getBatchedGemmConfigs ();
99- TLLM_CHECK_WITH_INFO (
100- mSelectedConfigIndex .has_value (), " No valid kernel found for given param config and problem size" );
101- auto const & config = configs[mSelectedConfigIndex .value ()];
105+
106+ if (!configIndex.has_value ())
107+ {
108+ TLLM_CHECK_WITH_INFO (mSelectedConfigIndex .has_value (), " Tried to use default config index but none was set" );
109+
110+ configIndex = mSelectedConfigIndex ;
111+ }
112+
113+ auto const & config = configs[configIndex.value ()];
102114
103115 TLLM_CHECK_WITH_INFO (numBatches > 0 , " Batched GEMM requires numBatches > 0" );
104116 if (!mOptions .staticBatch )
@@ -170,32 +182,33 @@ void TrtllmGenBatchedGemmRunner::run(int32_t m, int32_t n, int32_t k, std::vecto
170182
// Convenience overload of run() for callers that supply sfA / sfB and outSfC.
// It forwards to the full run() overload with:
//   - numTokens = 0 and maxNumCtasInBatchDim = 0,
//   - numBatches = batchedTokens.size(),
//   - nullptr for perTokensSfA / perTokensSfB, scaleC / scaleGateC, routeMap, totalNumPaddedTokens,
//     ctaIdxXyToBatchIdx, ctaIdxXyToMnLimit and numNonExitingCtas.
// configIndex is passed through unchanged; when it is empty, the full overload falls back to
// mSelectedConfigIndex and fails its check if that is unset as well.
// Diff note: the '-' lines are the previous signature/call; the '+' lines add the optional configIndex
// parameter and forward it.
171183void TrtllmGenBatchedGemmRunner::run (int32_t m, int32_t n, int32_t k, std::vector<int32_t > const & batchedTokens,
172184 void const * a, void const * sfA, void const * b, void const * sfB, void * c, void * outSfC, void * workspace,
173- CUstream stream, int device)
185+ CUstream stream, int device, std::optional< int32_t > configIndex )
174186{
175187 // Dispatch with block scaling factors and with static batching.
176188 run (m, n, k, batchedTokens, /* numTokens */ 0 , batchedTokens.size (), /* maxNumCtasInBatchDim */ 0 , a, sfA, b, sfB,
177189 /* perTokensSfA */ nullptr , /* perTokensSfB */ nullptr ,
178190 /* scaleC */ nullptr , /* scaleGateC */ nullptr , c, outSfC,
179191 /* routeMap */ nullptr , /* totalNumPaddedTokens */ nullptr ,
180192 /* ctaIdxXyToBatchIdx */ nullptr , /* ctaIdxXyToMnLimit */ nullptr ,
181- /* numNonExitingCtas */ nullptr , workspace, stream, device);
193+ /* numNonExitingCtas */ nullptr , workspace, stream, device, configIndex );
182194}
183195
// Convenience overload of run() for callers that supply scaleC / scaleGateC instead of sfA / sfB.
// It forwards to the full run() overload with:
//   - numTokens = 0 and maxNumCtasInBatchDim = 0,
//   - numBatches = batchedTokens.size(),
//   - nullptr for sfA / sfB, perTokensSfA / perTokensSfB, outSfC, routeMap, totalNumPaddedTokens,
//     ctaIdxXyToBatchIdx, ctaIdxXyToMnLimit and numNonExitingCtas.
// configIndex is passed through unchanged; when it is empty, the full overload falls back to
// mSelectedConfigIndex and fails its check if that is unset as well.
// Diff note: the '-' lines are the previous signature/call; the '+' lines add the optional configIndex
// parameter and forward it.
184197void TrtllmGenBatchedGemmRunner::run (int32_t m, int32_t n, int32_t k, std::vector<int32_t > const & batchedTokens,
185198 void const * a, void const * b, float const * scaleC, float const * scaleGateC, void * c, void * workspace,
186- CUstream stream, int device)
198+ CUstream stream, int device, std::optional< int32_t > configIndex )
187199{
188200 // Dispatch without block scaling factors (sfA / sfB are null; scaleC / scaleGateC are forwarded) and with static batching.
189201 run (m, n, k, batchedTokens, /* numTokens */ 0 , batchedTokens.size (), /* maxNumCtasInBatchDim */ 0 , a,
190202 /* sfA */ nullptr , b, /* sfB */ nullptr , /* perTokensSfA */ nullptr , /* perTokensSfB */ nullptr , scaleC,
191203 scaleGateC, c, /* outSfC */ nullptr ,
192204 /* routeMap */ nullptr , /* totalNumPaddedTokens */ nullptr ,
193205 /* ctaIdxXyToBatchIdx */ nullptr , /* ctaIdxXyToMnLimit */ nullptr ,
194- /* numNonExitingCtas */ nullptr , workspace, stream, device);
206+ /* numNonExitingCtas */ nullptr , workspace, stream, device, configIndex );
195207}
196208
197- void TrtllmGenBatchedGemmRunner::selectGemmConfig (int32_t m, int32_t n, int32_t k,
198- std::vector<int32_t > const & batchedTokens, int32_t numTokens, int32_t numBatches, int32_t maxNumCtasInBatchDim)
209+ std::vector<int64_t > TrtllmGenBatchedGemmRunner::getValidConfigIndices (int32_t m, int32_t n, int32_t k,
210+ std::vector<int32_t > const & batchedTokens, int32_t numTokens, int32_t numBatches,
211+ int32_t maxNumCtasInBatchDim) const
199212{
200213 auto const bmm = BatchedGemmInterface ();
201214 auto const configs = bmm.getBatchedGemmConfigs ();
@@ -242,16 +255,30 @@ void TrtllmGenBatchedGemmRunner::selectGemmConfig(int32_t m, int32_t n, int32_t
242255 return optionsA.mTileM > optionsB.mTileM ;
243256 });
244257
258+ std::vector<int64_t > validConfigIndices;
245259 for (auto const & configIndex : sortedIndices)
246260 {
247261 auto const & config = configs[configIndex];
248262 auto isValidConfig = bmm.isValidConfig (config, gemmData);
249263 if (isValidConfig)
250264 {
251- mSelectedConfigIndex = configIndex;
252- return ;
265+ validConfigIndices.push_back (configIndex);
253266 }
254267 }
268+
269+ TLLM_CHECK_WITH_INFO (!validConfigIndices.empty (), " No valid config found for the given problem shape" );
270+
271+ return validConfigIndices;
272+ }
273+
// Returns the default config index for the given problem: the first entry of the list produced by
// getValidConfigIndices() for the same arguments. All parameters are forwarded to that call unchanged.
// Indexing [0] relies on getValidConfigIndices() running TLLM_CHECK_WITH_INFO on a non-empty result
// before it returns (presumably that macro throws/aborts on failure -- confirm).
// NOTE(review): this returns int64_t, while the configIndex parameters of run() and
// getWorkspaceSizeInBytes() are std::optional<int32_t>; confirm the differing index widths are intended.
// Diff note: this function is newly added by the change (all of its lines but the closing brace are '+' lines).
274+ int64_t TrtllmGenBatchedGemmRunner::getDefaultValidConfigIndex (int32_t m, int32_t n, int32_t k,
275+ std::vector<int32_t > const & batchedTokens, int32_t numTokens, int32_t numBatches,
276+ int32_t maxNumCtasInBatchDim) const
277+ {
278+ auto const validConfigIndices
279+ = getValidConfigIndices (m, n, k, batchedTokens, numTokens, numBatches, maxNumCtasInBatchDim);
280+ 
// First entry in the order getValidConfigIndices() produced.
281+ return validConfigIndices[0 ];
255282}
256283
257284} // namespace kernels
0 commit comments