Commit 9a87476

Kv cache transfer support duplicate heads (#4929)
Signed-off-by: Chuang Zhu <[email protected]>
1 parent 947571c commit 9a87476

File tree

9 files changed: +226 -95 lines changed

cpp/tensorrt_llm/batch_manager/cacheFormatter.cpp

Lines changed: 121 additions & 40 deletions
Large diffs are not rendered by default.

cpp/tensorrt_llm/batch_manager/cacheFormatter.h

Lines changed: 5 additions & 0 deletions
@@ -96,6 +96,11 @@ class CacheFormatter final : public IOFormatter
         return mCacheManager;
     }
 
+    static bool needSendCache(CacheState const& selfConfig, CacheState const& destConfig, runtime::SizeType32 selfIdx);
+    std::vector<executor::kv_cache::Connection const*> pickRecvConnections(
+        std::vector<executor::kv_cache::Connection const*> const& connections, CacheState const& selfConfig,
+        SizeType32 selfIdx, CacheState const& destConfig) const override;
+
 private:
     BaseKVCacheManager* mCacheManager{};

cpp/tensorrt_llm/batch_manager/dataTransceiver.h

Lines changed: 5 additions & 0 deletions
@@ -66,6 +66,11 @@ class IOFormatter
         CacheState const& selfConfig, SizeType32 selfIdx, CacheState const& destConfig) const
         = 0;
 
+    [[nodiscard]] virtual std::vector<executor::kv_cache::Connection const*> pickRecvConnections(
+        std::vector<executor::kv_cache::Connection const*> const& connections, CacheState const& selfConfig,
+        SizeType32 selfIdx, CacheState const& destConfig) const
+        = 0;
+
     /// @brief Destructor.
     virtual ~IOFormatter() = default;
 };
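
The new pure virtual above lets each formatter decide which of the counterpart connections a receiving rank actually pulls cache from; the concrete policies live in cacheFormatter.cpp and mlaCacheFormatter.cpp (the former is not rendered in this commit view). Purely as an illustration of the contract, with simplified stand-in types rather than the TensorRT-LLM classes, a trivial "keep everything" override could look like the sketch below; real formatters may return only a subset:

#include <vector>

// Illustration only: stand-in types, not the TensorRT-LLM classes, and the trivial body
// below is not the selection logic used by CacheFormatter or MLACacheFormatter.
struct Connection
{
};

struct FormatterSketch
{
    virtual ~FormatterSketch() = default;

    // Contract: return the subset of `connections` this rank will actually receive from.
    [[nodiscard]] virtual std::vector<Connection const*> pickRecvConnections(
        std::vector<Connection const*> const& connections) const
    {
        return connections; // pass-through; a real formatter may drop duplicated counterparts
    }
};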

cpp/tensorrt_llm/batch_manager/dataTransceiverImpl.cpp

Lines changed: 12 additions & 2 deletions
@@ -185,18 +185,28 @@ void DataReceiverImpl::sendRequestInfo(LlmRequest const& llmRequest)
     }
     auto counterParts = mFormatter->getCounterparts(
         mSelfState.getCacheState().value(), mSelfState.getCommState().value().getSelfIdx(), destCacheState);
+
+    auto connections = mManager->getConnections(commState);
+    std::vector<executor::kv_cache::Connection const*> counterPartConnections;
     for (auto index : counterParts)
     {
-        auto const* connection = mManager->getConnections(commState).at(index);
+        auto const* connection = connections.at(index);
+        counterPartConnections.emplace_back(connection);
+    }
+    auto pickUpConnections = mFormatter->pickRecvConnections(counterPartConnections, mSelfState.getCacheState().value(),
+        mSelfState.getCommState().value().getSelfIdx(), destCacheState);
+    for (auto connection : counterPartConnections)
+    {
         // if Manager is agentConnectionManager, then send request info to agent
         auto* agentConnectionManager = dynamic_cast<executor::kv_cache::AgentConnectionManager*>(mManager);
         if (agentConnectionManager != nullptr)
         {
             // TODO: index -> validConnectionIdx conversion
+            auto valideConnectionIdx
+                = std::find(pickUpConnections.begin(), pickUpConnections.end(), connection) - pickUpConnections.begin();
             auto* agentConnection = dynamic_cast<executor::kv_cache::AgentConnection const*>(connection);
             TLLM_CHECK(agentConnection != nullptr);
             TLLM_CHECK(cacheBufferId.has_value());
-            int valideConnectionIdx = std::find(counterParts.begin(), counterParts.end(), index) - counterParts.begin();
             const_cast<executor::kv_cache::AgentConnection*>(agentConnection)
                 ->sendRequestAndBufferInfo(requestInfo, cacheBufferId, valideConnectionIdx);
         }
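
The valideConnectionIdx computation above is a plain find-the-position lookup: the connection's index within the picked subset rather than within the full counterpart list. A self-contained illustration of that pattern (ints stand in for Connection pointers; names are illustrative, not from the codebase):

#include <algorithm>
#include <cstddef>
#include <iostream>
#include <vector>

int main()
{
    std::vector<int> pickUpConnections{7, 3, 9};
    int connection = 9;
    // Position of `connection` within the picked set; equals pickUpConnections.size()
    // when the connection was not picked, so callers must ensure it was actually picked.
    std::ptrdiff_t validConnectionIdx
        = std::find(pickUpConnections.begin(), pickUpConnections.end(), connection) - pickUpConnections.begin();
    std::cout << validConnectionIdx << '\n'; // prints 2
    return 0;
}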

cpp/tensorrt_llm/batch_manager/mlaCacheFormatter.cpp

Lines changed: 13 additions & 3 deletions
@@ -41,7 +41,7 @@ namespace tensorrt_llm::batch_manager::kv_cache_manager
 // some context rank in connection
 std::vector<executor::kv_cache::Connection const*> MLACacheFormatter::pickRecvConnections(
     std::vector<executor::kv_cache::Connection const*> const& connections, CacheState const& selfConfig,
-    SizeType32 selfIdx, CacheState const& destConfig)
+    SizeType32 selfIdx, CacheState const& destConfig) const
 {
 
     TLLM_CHECK(!connections.empty());
@@ -469,16 +469,18 @@ void MLACacheFormatter::formatInput(LlmRequest const& llmRequest,
 {
     if (selfConfig.getDataType() != destConfig.getDataType())
     {
+        TLLM_LOG_WARNING("MLACacheFormatter::inquireSupport: only support same data type");
         return false;
     }
     if (selfConfig.getAttentionConfig().mAttentionType != CacheState::AttentionType::kMLA
         || destConfig.getAttentionConfig().mAttentionType != CacheState::AttentionType::kMLA)
     {
-
+        TLLM_LOG_WARNING("MLACacheFormatter::inquireSupport: only support MLA");
         return false;
     }
     if (selfConfig.getAttentionConfig().mKvFactor != destConfig.getAttentionConfig().mKvFactor)
     {
+        TLLM_LOG_WARNING("MLACacheFormatter::inquireSupport: only support same kv factor");
         return false;
     }
 
@@ -487,48 +489,56 @@ void MLACacheFormatter::formatInput(LlmRequest const& llmRequest,
 
     if (setVecSelf.size() != 1)
     {
+        TLLM_LOG_WARNING("MLACacheFormatter::inquireSupport: only support equal number of heads per layer");
         return false;
     }
     std::unordered_set<int> setVecDest{
         destConfig.getModelConfig().mNbKvHeadsPerLayer.begin(), destConfig.getModelConfig().mNbKvHeadsPerLayer.end()};
 
     if (setVecDest.size() != 1)
     {
+        TLLM_LOG_WARNING("MLACacheFormatter::inquireSupport: only support equal number of heads per layer");
         return false;
     }
     if (selfConfig.getModelConfig().mTokensPerBlock != destConfig.getModelConfig().mTokensPerBlock
         || selfConfig.getModelConfig().mSizePerHead != destConfig.getModelConfig().mSizePerHead)
     {
+        TLLM_LOG_WARNING("MLACacheFormatter::inquireSupport: only support same tokens per block and size per head");
         return false;
     }
     if (selfConfig.getModelConfig().mNbKvHeadsPerLayer.size() != destConfig.getModelConfig().mNbKvHeadsPerLayer.size())
     {
+        TLLM_LOG_WARNING("MLACacheFormatter::inquireSupport: only support same number of layers");
         return false;
     }
     if ((selfConfig.getModelConfig().mNbKvHeadsPerLayer.at(0) != 1)
         || (selfConfig.getModelConfig().mNbKvHeadsPerLayer.at(0) != 1))
     {
+        TLLM_LOG_WARNING("MLACacheFormatter::inquireSupport: only support MLA");
         return false;
     }
 
     if (selfConfig.getAttentionConfig().mKvFactor != destConfig.getAttentionConfig().mKvFactor)
     {
+        TLLM_LOG_WARNING("MLACacheFormatter::inquireSupport: only support same kv factor");
         return false;
     }
     if (selfConfig.getParallelConfig().mEnableAttentionDP
         && (selfConfig.getParallelConfig().mTensorParallelism % selfConfig.getParallelConfig().mDPsize != 0))
     {
-
+        TLLM_LOG_WARNING("MLACacheFormatter::inquireSupport: TP size must be divisible by DP size");
         return false;
     }
     if (destConfig.getParallelConfig().mEnableAttentionDP
         && (destConfig.getParallelConfig().mTensorParallelism % destConfig.getParallelConfig().mDPsize != 0))
     {
+        TLLM_LOG_WARNING("MLACacheFormatter::inquireSupport: TP size must be divisible by DP size");
         return false;
     }
     if ((destConfig.getParallelConfig().mEnableAttentionDP)
         && (destConfig.getParallelConfig().mTensorParallelism != destConfig.getParallelConfig().mDPsize))
     {
+        TLLM_LOG_WARNING("MLACacheFormatter::inquireSupport: TP size must be equal to DP size");
         return false;
     }

cpp/tensorrt_llm/batch_manager/mlaCacheFormatter.h

Lines changed: 2 additions & 2 deletions
@@ -72,9 +72,9 @@ class MLACacheFormatter final : public IOFormatter
     }
 
     static bool needSendCache(CacheState const& selfConfig, CacheState const& destConfig, runtime::SizeType32 selfIdx);
-    static std::vector<executor::kv_cache::Connection const*> pickRecvConnections(
+    std::vector<executor::kv_cache::Connection const*> pickRecvConnections(
        std::vector<executor::kv_cache::Connection const*> const& connections, CacheState const& selfConfig,
-        SizeType32 selfIdx, CacheState const& destConfig);
+        SizeType32 selfIdx, CacheState const& destConfig) const override;
 
 private:
     BaseKVCacheManager* mCacheManager{};

cpp/tensorrt_llm/executor/cache_transmission/cacheConcatenate.cu

Lines changed: 29 additions & 42 deletions
@@ -85,6 +85,8 @@ TargetRanksInfo TargetRanksInfoForDP(
         ? peerCacheState.getParallelConfig().mTensorParallelism / peerCacheState.getParallelConfig().mDPsize
         : peerTPNum;
 
+    int selfNbHeadsPerLayer = selfCacheState.getModelConfig().mNbKvHeadsPerLayer[0];
+    int peerNbHeadsPerLayer = peerCacheState.getModelConfig().mNbKvHeadsPerLayer[0];
     int selfTPrankInDPGroup = selfTpRank % selfTPSizeOneDPGroup;
 
     {
@@ -112,51 +114,26 @@ TargetRanksInfo TargetRanksInfoForDP(
             retRanks.push_back(irank);
         }
     }
-    return {mDomainPPSize, mDomainTPSize, std::move(retRanks)};
+    int mDuplicateHeadFactor = 1;
+    int mPeerDuplicateHeadFactor = 1;
+    if (selfNbHeadsPerLayer * selfTPSizeOneDPGroup > peerNbHeadsPerLayer * peerTPSizeOneDPGroup)
+    {
+        mDuplicateHeadFactor
+            = (selfNbHeadsPerLayer * selfTPSizeOneDPGroup) / (peerNbHeadsPerLayer * peerTPSizeOneDPGroup);
+    }
+    if (peerNbHeadsPerLayer * peerTPSizeOneDPGroup > selfNbHeadsPerLayer * selfTPSizeOneDPGroup)
+    {
+        mPeerDuplicateHeadFactor
+            = (peerNbHeadsPerLayer * peerTPSizeOneDPGroup) / (selfNbHeadsPerLayer * selfTPSizeOneDPGroup);
+    }
+
+    return {mDomainPPSize, mDomainTPSize, std::move(retRanks), mDuplicateHeadFactor, mPeerDuplicateHeadFactor};
 }
 
 TargetRanksInfo targetIRanks(
     kv_cache::CacheState const& peerCacheState, kv_cache::CacheState const& selfCacheState, int selfRank)
 {
-    if (selfCacheState.getAttentionConfig().mAttentionType == CacheState::AttentionType::kMLA
-        || selfCacheState.getParallelConfig().mEnableAttentionDP
-        || peerCacheState.getParallelConfig().mEnableAttentionDP)
-    {
-        return TargetRanksInfoForDP(peerCacheState, selfCacheState, selfRank);
-    }
-    int iPPNum = peerCacheState.getParallelConfig().mPipelineParallelism; // TODO:
-    int oPPNum = selfCacheState.getParallelConfig().mPipelineParallelism;
-    int oNbKvHeads = selfCacheState.getModelConfig().mNbKvHeadsPerLayer[0];
-    int oNbLayers = selfCacheState.getModelConfig().mNbKvHeadsPerLayer.size() / oPPNum;
-    int iNbKvHeads = peerCacheState.getModelConfig().mNbKvHeadsPerLayer[0];
-    int iNbLayers = peerCacheState.getModelConfig().mNbKvHeadsPerLayer.size() / iPPNum;
-    int oTpRank = selfRank % selfCacheState.getParallelConfig().mTensorParallelism;
-    int oPpRank = selfRank / selfCacheState.getParallelConfig().mTensorParallelism;
-    int startHeadId = oTpRank * oNbKvHeads;
-    int endHeadId = (oTpRank + 1) * oNbKvHeads;
-    int startLayerId = oPpRank * oNbLayers;
-    int endLayerId = (oPpRank + 1) * oNbLayers;
-    int iTpRankStart = startHeadId / iNbKvHeads;
-    int iTpRankEndInclude = (endHeadId - 1) / iNbKvHeads;
-    int iPpRankStart = startLayerId / iNbLayers;
-    int iPpRankEndInclude = (endLayerId - 1) / iNbLayers;
-
-    int iTPNum = peerCacheState.getParallelConfig().mTensorParallelism;
-    std::vector<int> retRanks;
-
-    for (int i = iTpRankStart; i <= iTpRankEndInclude; i++)
-    {
-        for (int j = iPpRankStart; j <= iPpRankEndInclude; j++)
-        {
-            int irank = j * iTPNum + i;
-            retRanks.push_back(irank);
-        }
-    }
-    // [tp ,pp] order
-    int mDomainPPSize = iPpRankEndInclude - iPpRankStart + 1;
-    int mDomainTPSize = iTpRankEndInclude - iTpRankStart + 1;
-    TLLM_CHECK(!retRanks.empty());
-    return {mDomainPPSize, mDomainTPSize, std::move(retRanks)};
+    return TargetRanksInfoForDP(peerCacheState, selfCacheState, selfRank);
 }
 
 template <typename T>
@@ -791,6 +768,10 @@ void splitKVCache(std::vector<runtime::ITensor::SharedPtr> const& kVCacheBlocks,
     {
         outputCacheNum = targetRankInfo.mDomainPPSize;
     }
+    else
+    {
+        outputCacheNum = outputCacheNum / targetRankInfo.mPeerDuplicateHeadFactor;
+    }
     TLLM_CHECK(outputCacheNum == outputSplitBlocks.size());
     TLLM_CHECK(inputBlockNum > 0);
     auto cacheBlockSize = kVCacheBlocks.at(0)->getSize();
@@ -840,7 +821,8 @@ void splitKVCache(std::vector<runtime::ITensor::SharedPtr> const& kVCacheBlocks,
     int iTPNum = destCacheState.getParallelConfig().mTensorParallelism;
     int oTPNum = selfCacheState.getParallelConfig().mTensorParallelism;
     int layerNumDomainPP = numLayers / DomainPPSize;
-    int headNumDomainTP = headNum / DomainTPSize;
+    int headNumDomainTP
+        = headNum / (DomainTPSize / targetRankInfo.mPeerDuplicateHeadFactor); // TODO: duplicate head factor
     int kvFactor = selfCacheState.getAttentionConfig().mKvFactor;
     bool isMLA = selfCacheState.getAttentionConfig().mAttentionType == CacheState::AttentionType::kMLA;
     constexpr int mlaSubWarpSize = 16;
@@ -1017,6 +999,10 @@ void concatenateKVCache(std::vector<runtime::ITensor::SharedPtr> const& inputSpl
     {
         inputCacheNum = targetRankInfo.mDomainPPSize;
     }
+    else
+    {
+        inputCacheNum = inputCacheNum / targetRankInfo.mPeerDuplicateHeadFactor;
+    }
     TLLM_CHECK(inputCacheNum == inputSplitBlocks.size());
     TLLM_CHECK(outputBlockNum > 0);
     auto cacheBlockSize = outputKvCacheBlocks.at(0)->getSize();
@@ -1064,7 +1050,8 @@ void concatenateKVCache(std::vector<runtime::ITensor::SharedPtr> const& inputSpl
     int iTPNum = destCacheState.getParallelConfig().mTensorParallelism;
     int oTPNum = selfCacheState.getParallelConfig().mTensorParallelism;
     int layerNumDomainPP = numLayers / DomainPPSize;
-    int headNumDomainTP = headNum / DomainTPSize;
+    int headNumDomainTP
+        = headNum / (DomainTPSize / targetRankInfo.mPeerDuplicateHeadFactor); // TODO: duplicate head factor
     int kvFactor = selfCacheState.getAttentionConfig().mKvFactor;
 
     bool isMLA = selfCacheState.getAttentionConfig().mAttentionType == CacheState::AttentionType::kMLA;
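
As a worked example of the factor computation in TargetRanksInfoForDP above (a standalone sketch; the numbers are chosen for illustration, not taken from a real config): if the self side has 1 KV head per layer across a TP group of 4 while the peer side has 2 KV heads per layer across a TP group of 1, the self side holds 4 head slots for only 2 distinct heads, so each head is duplicated on 2 self ranks:

#include <iostream>

int main()
{
    // Illustrative numbers only.
    int selfNbHeadsPerLayer = 1, selfTPSizeOneDPGroup = 4;
    int peerNbHeadsPerLayer = 2, peerTPSizeOneDPGroup = 1;

    int duplicateHeadFactor = 1;     // how many self ranks hold a copy of the same head
    int peerDuplicateHeadFactor = 1; // how many peer ranks hold a copy of the same head
    if (selfNbHeadsPerLayer * selfTPSizeOneDPGroup > peerNbHeadsPerLayer * peerTPSizeOneDPGroup)
    {
        duplicateHeadFactor
            = (selfNbHeadsPerLayer * selfTPSizeOneDPGroup) / (peerNbHeadsPerLayer * peerTPSizeOneDPGroup);
    }
    if (peerNbHeadsPerLayer * peerTPSizeOneDPGroup > selfNbHeadsPerLayer * selfTPSizeOneDPGroup)
    {
        peerDuplicateHeadFactor
            = (peerNbHeadsPerLayer * peerTPSizeOneDPGroup) / (selfNbHeadsPerLayer * selfTPSizeOneDPGroup);
    }
    std::cout << duplicateHeadFactor << ' ' << peerDuplicateHeadFactor << '\n'; // prints "2 1"
    return 0;
}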

cpp/tensorrt_llm/executor/cache_transmission/cacheConcatenate.h

Lines changed: 2 additions & 0 deletions
@@ -36,6 +36,8 @@ struct TargetRanksInfo
     int mDomainPPSize;
     int mDomainTPSize;
     std::vector<int> mIRanks;
+    int mDuplicateHeadFactor;
+    int mPeerDuplicateHeadFactor;
 };
 
 TargetRanksInfo targetIRanks(
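
TargetRanksInfo is an aggregate, so the braced return statement in TargetRanksInfoForDP fills the members in declaration order, with the two new factors coming last. A minimal standalone sketch of that initialization (a simplified local copy of the struct for illustration; the values are made up):

#include <vector>

struct TargetRanksInfo
{
    int mDomainPPSize;
    int mDomainTPSize;
    std::vector<int> mIRanks;
    int mDuplicateHeadFactor;
    int mPeerDuplicateHeadFactor;
};

TargetRanksInfo makeExample()
{
    // Members are filled positionally: PP domain, TP domain, ranks, then the two new factors.
    return {1, 2, {0, 1}, 2, 1};
}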

cpp/tests/batch_manager/cacheTransceiverTest.cpp

Lines changed: 37 additions & 6 deletions
@@ -606,16 +606,24 @@ class AsymmetricalCacheTest : public ::testing::TestWithParam<AsymmetricTestPara
         ASSERT_EQ(numLayers % mPpSize, 0);
         if (!isMLA)
         {
-            ASSERT_EQ(numHeads % mTpSize, 0);
+            // ASSERT_EQ(numHeads % mTpSize , 0);
+            ASSERT_TRUE(numHeads % mTpSize == 0 || mTpSize % numHeads == 0);
         }
         else
         {
             ASSERT_EQ(numHeads, 1);
         }
-        int numHeadsPerRank = numHeads / mTpSize;
+        int numHeadsPerRank = (numHeads + mTpSize - 1) / mTpSize;
+        mDuplicateHeadFactor = 1;
+        if (mTpSize > numHeads)
+        {
+            mDuplicateHeadFactor = mTpSize / numHeads;
+            ASSERT_EQ(numHeadsPerRank, 1);
+        }
         if (isMLA || enableDPAttention)
         {
             numHeadsPerRank = numHeads;
+            mDuplicateHeadFactor = 1;
         }
         auto hiddenSize = numHeadsPerRank * sizePerHead;
         auto maxBlocksPerSeq = 10;
@@ -656,7 +664,7 @@ class AsymmetricalCacheTest : public ::testing::TestWithParam<AsymmetricTestPara
             DPsize = mTpSize;
         }
 
-        int numHeadsPerRankForContext = numHeads / mContextTpSize;
+        int numHeadsPerRankForContext = (numHeads + mContextTpSize - 1) / mContextTpSize;
         if (isMLA || mContextDP)
         {
             numHeadsPerRankForContext = numHeads;
@@ -806,7 +814,7 @@ class AsymmetricalCacheTest : public ::testing::TestWithParam<AsymmetricTestPara
         }
         else
         {
-            TLLM_CHECK(false);
+            TLLM_CHECK_WITH_INFO(false, "Please set at least one cache transfer backend");
         }
     }
 
@@ -906,7 +914,7 @@ class AsymmetricalCacheTest : public ::testing::TestWithParam<AsymmetricTestPara
         int layerSizePerRank = mCacheState->getModelConfig().mNbKvHeadsPerLayer.size() / mPpSize;
         int startLayerId = layerSizePerRank * mPpRank;
         int headSizePerRank = mCacheState->getModelConfig().mNbKvHeadsPerLayer.at(0);
-        int startHeadId = headSizePerRank * mTpRank;
+        int startHeadId = headSizePerRank * (mTpRank / mDuplicateHeadFactor);
         bool enableDP = mCacheState->getParallelConfig().mEnableAttentionDP;
         if (mIsMLA || enableDP)
         {
@@ -970,7 +978,7 @@ class AsymmetricalCacheTest : public ::testing::TestWithParam<AsymmetricTestPara
         int layerSizePerRank = mCacheState->getModelConfig().mNbKvHeadsPerLayer.size() / mPpSize;
         int startLayerId = layerSizePerRank * mPpRank;
         int headSizePerRank = mCacheState->getModelConfig().mNbKvHeadsPerLayer.at(0);
-        int startHeadId = headSizePerRank * mTpRank;
+        int startHeadId = headSizePerRank * (mTpRank / mDuplicateHeadFactor);
         bool enableDP = mCacheState->getParallelConfig().mEnableAttentionDP;
         if (mIsMLA || enableDP)
         {
@@ -1063,6 +1071,7 @@ class AsymmetricalCacheTest : public ::testing::TestWithParam<AsymmetricTestPara
     bool mContextDP{false};
     bool mGenerationDP{false};
    bool mIsMLA{false};
+    int mDuplicateHeadFactor{1};
     SizeType32 mMaxNumSequences{};
     std::unique_ptr<KVCacheManager> mManager;
     std::unique_ptr<CacheTransBufferManager> mCacheTransBufferManager;
@@ -1343,6 +1352,28 @@ INSTANTIATE_TEST_CASE_P(AsymmetricCaseTestWithDPForNoMLA2, AsymmetricalCacheTest
     testing::Values(4), testing::Values(4), testing::Values(4), testing::Values(16),
     testing::Values(nvinfer1::DataType::kFLOAT, nvinfer1::DataType::kINT8), testing::Values(2),
     testing::Values(false), testing::Values(false), testing::Values(true)));
+INSTANTIATE_TEST_CASE_P(AsymmetricCaseTestWithDPForNoMLADuplicate0, AsymmetricalCacheTestWithDP,
+    testing::Combine(testing::Values(1, 2), testing::Values(1, 2), testing::Values(4), testing::Values(1),
+        testing::Values(4), testing::Values(2), testing::Values(4), testing::Values(16),
+        testing::Values(nvinfer1::DataType::kFLOAT, nvinfer1::DataType::kINT8), testing::Values(2),
+        testing::Values(false), testing::Values(true, false), testing::Values(false)));
+
+INSTANTIATE_TEST_CASE_P(AsymmetricCaseTestWithDPForNoMLADuplicate1, AsymmetricalCacheTestWithDP,
+    testing::Combine(testing::Values(1, 2), testing::Values(1, 2), testing::Values(2), testing::Values(2),
+        testing::Values(4), testing::Values(1), testing::Values(4), testing::Values(16),
+        testing::Values(nvinfer1::DataType::kFLOAT, nvinfer1::DataType::kINT8), testing::Values(2),
+        testing::Values(false), testing::Values(true, false), testing::Values(false)));
+INSTANTIATE_TEST_CASE_P(AsymmetricCaseTestWithDPForNoMLADuplicate2, AsymmetricalCacheTestWithDP,
+    testing::Combine(testing::Values(4), testing::Values(1), testing::Values(4, 2), testing::Values(1),
+        testing::Values(4), testing::Values(2), testing::Values(4), testing::Values(16),
+        testing::Values(nvinfer1::DataType::kFLOAT, nvinfer1::DataType::kINT8), testing::Values(2),
+        testing::Values(false), testing::Values(false), testing::Values(false)));
+INSTANTIATE_TEST_CASE_P(AsymmetricCaseTestWithDPForNoMLADuplicate4, AsymmetricalCacheTestWithDP,
+    testing::Combine(testing::Values(4), testing::Values(1), testing::Values(1, 2), testing::Values(2),
+        testing::Values(4), testing::Values(1, 2), testing::Values(4), testing::Values(16),
+        testing::Values(nvinfer1::DataType::kFLOAT, nvinfer1::DataType::kINT8), testing::Values(2),
+        testing::Values(false), testing::Values(false), testing::Values(false)));
+
 #endif
 
 TEST(targetTest, CacheStateNODP)
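
The test now uses a ceiling division for numHeadsPerRank and divides the TP rank by mDuplicateHeadFactor when computing startHeadId, so several TP ranks map onto the same head. A standalone sketch with illustrative numbers (2 KV heads, TP size 4; headSizePerRank equals numHeadsPerRank in this non-DP case):

#include <iostream>

int main()
{
    // Illustrative configuration: 2 KV heads split across a TP size of 4.
    int numHeads = 2, tpSize = 4;
    int numHeadsPerRank = (numHeads + tpSize - 1) / tpSize;                 // ceiling division -> 1
    int duplicateHeadFactor = tpSize > numHeads ? tpSize / numHeads : 1;    // -> 2

    for (int tpRank = 0; tpRank < tpSize; ++tpRank)
    {
        int startHeadId = numHeadsPerRank * (tpRank / duplicateHeadFactor);
        std::cout << "rank " << tpRank << " -> head " << startHeadId << '\n';
    }
    // Prints: rank 0 -> head 0, rank 1 -> head 0, rank 2 -> head 1, rank 3 -> head 1
    return 0;
}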
