@@ -58,7 +58,8 @@ unsigned int ipow2(unsigned int v)
5858
5959template <typename T>
6060void hostKVCacheBlockPartialCopy (IBuffer& dst, IBuffer const & src, unsigned int numLayers, unsigned int numHeads,
61- unsigned int tokensPerBlock, unsigned int numHidden, unsigned int numTokensToCopy, cudaStream_t stream)
61+ unsigned int tokensPerBlock, unsigned int numHidden, unsigned int numTokensToCopy, int kvFactor,
62+ cudaStream_t stream)
6263{
6364 unsigned int blockX = ipow2 (numHidden); // ensure block shape is a power of 2
6465 blockX = std::min (blockX, 32u ); // blockX should not exceed warp size
@@ -75,55 +76,56 @@ void hostKVCacheBlockPartialCopy(IBuffer& dst, IBuffer const& src, unsigned int
7576 auto srcData = bufferCast<T>(src);
7677 auto dstData = bufferCast<T>(dst);
7778 cuKVCacheBlockPartialCopy<<<grid, block, 0 , stream>>> (
78- dstData, srcData, 2 * numLayers , numHeads, tokensPerBlock, numHidden, numTokensToCopy);
79+ dstData, srcData, numLayers * kvFactor , numHeads, tokensPerBlock, numHidden, numTokensToCopy);
7980}
8081} // namespace
8182
8283void kvCacheBlockPartialCopy (IBuffer& dst, IBuffer const & src, unsigned int numLayers, unsigned int numHeads,
83- unsigned int tokensPerBlock, unsigned int numHidden, unsigned int numTokensToCopy, cudaStream_t stream)
84+ unsigned int tokensPerBlock, unsigned int numHidden, unsigned int numTokensToCopy, int kvFactor,
85+ cudaStream_t stream)
8486{
8587 auto dataType = src.getDataType ();
8688 TLLM_CHECK_WITH_INFO (dataType == dst.getDataType (), " src and dst dataType does not match" );
8789 switch (dataType)
8890 {
8991 case nvinfer1::DataType::kINT64 :
9092 hostKVCacheBlockPartialCopy<SizeType64>(
91- dst, src, numLayers, numHeads, tokensPerBlock, numHidden, numTokensToCopy, stream);
93+ dst, src, numLayers, numHeads, tokensPerBlock, numHidden, numTokensToCopy, kvFactor, stream);
9294 break ;
9395 case nvinfer1::DataType::kINT32 :
9496 hostKVCacheBlockPartialCopy<std::int32_t >(
95- dst, src, numLayers, numHeads, tokensPerBlock, numHidden, numTokensToCopy, stream);
97+ dst, src, numLayers, numHeads, tokensPerBlock, numHidden, numTokensToCopy, kvFactor, stream);
9698 break ;
9799 case nvinfer1::DataType::kFLOAT :
98100 hostKVCacheBlockPartialCopy<float >(
99- dst, src, numLayers, numHeads, tokensPerBlock, numHidden, numTokensToCopy, stream);
101+ dst, src, numLayers, numHeads, tokensPerBlock, numHidden, numTokensToCopy, kvFactor, stream);
100102 break ;
101103#ifdef ENABLE_BF16
102104 case nvinfer1::DataType::kBF16 :
103105 hostKVCacheBlockPartialCopy<__nv_bfloat16>(
104- dst, src, numLayers, numHeads, tokensPerBlock, numHidden, numTokensToCopy, stream);
106+ dst, src, numLayers, numHeads, tokensPerBlock, numHidden, numTokensToCopy, kvFactor, stream);
105107 break ;
106108#endif
107109 case nvinfer1::DataType::kHALF :
108110 hostKVCacheBlockPartialCopy<half>(
109- dst, src, numLayers, numHeads, tokensPerBlock, numHidden, numTokensToCopy, stream);
111+ dst, src, numLayers, numHeads, tokensPerBlock, numHidden, numTokensToCopy, kvFactor, stream);
110112 break ;
111113 case nvinfer1::DataType::kBOOL :
112114 hostKVCacheBlockPartialCopy<bool >(
113- dst, src, numLayers, numHeads, tokensPerBlock, numHidden, numTokensToCopy, stream);
115+ dst, src, numLayers, numHeads, tokensPerBlock, numHidden, numTokensToCopy, kvFactor, stream);
114116 break ;
115117 case nvinfer1::DataType::kUINT8 :
116118 hostKVCacheBlockPartialCopy<std::uint8_t >(
117- dst, src, numLayers, numHeads, tokensPerBlock, numHidden, numTokensToCopy, stream);
119+ dst, src, numLayers, numHeads, tokensPerBlock, numHidden, numTokensToCopy, kvFactor, stream);
118120 break ;
119121 case nvinfer1::DataType::kINT8 :
120122 hostKVCacheBlockPartialCopy<std::int8_t >(
121- dst, src, numLayers, numHeads, tokensPerBlock, numHidden, numTokensToCopy, stream);
123+ dst, src, numLayers, numHeads, tokensPerBlock, numHidden, numTokensToCopy, kvFactor, stream);
122124 break ;
123125#ifdef ENABLE_FP8
124126 case nvinfer1::DataType::kFP8 :
125127 hostKVCacheBlockPartialCopy<__nv_fp8_e4m3>(
126- dst, src, numLayers, numHeads, tokensPerBlock, numHidden, numTokensToCopy, stream);
128+ dst, src, numLayers, numHeads, tokensPerBlock, numHidden, numTokensToCopy, kvFactor, stream);
127129 break ;
128130#endif
129131 default : TLLM_THROW (" Unknown data type" );
0 commit comments