Skip to content

Commit 1d4f748

Browse files
authored
[fix] Fix illegal mem access and possible accuracy loss. Cherry-pick … (#5017)
Signed-off-by: Jin Li <[email protected]>
1 parent f45aff2 commit 1d4f748

File tree

5 files changed

+19
-37
lines changed

5 files changed

+19
-37
lines changed

cpp/include/tensorrt_llm/kernels/kvCachePartialCopy.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ namespace tensorrt_llm
2323
namespace kernels
2424
{
2525
void kvCacheBlockPartialCopy(IBuffer& dst, IBuffer const& src, unsigned int numLayers, unsigned int numHeads,
26-
unsigned int tokensPerBlock, unsigned int numHidden, unsigned int numTokensToCopy, cudaStream_t stream);
26+
unsigned int tokensPerBlock, unsigned int numHidden, unsigned int numTokensToCopy, int kvFactor,
27+
cudaStream_t stream);
2728
} // namespace kernels
2829
} // namespace tensorrt_llm

cpp/tensorrt_llm/batch_manager/kvCacheTransferManager.cpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -141,15 +141,16 @@ void KVCacheTransferManager::copyBlock(BlockPtr const& src, BlockPtr const& dst,
141141
{
142142
auto stream = (isOffload ? mOffloadManager : mOnboardManager).getStream().get();
143143
int const numLayers = pools[poolIdx].numLayers;
144+
int const kvFactor = pools[poolIdx].kvFactor;
144145
int const numHeads = pools[poolIdx].numKvHeads;
145146
int const sizePerHead = pools[poolIdx].sizePerHead;
146147
auto shape = srcPtr->getShape();
147148

148149
TLLM_CHECK_WITH_INFO(
149150
shape.nbDims == 4, "Expected KVCache block to have 4 dims, got %d", shape.nbDims);
150151

151-
tk::kvCacheBlockPartialCopy(
152-
*dstPtr, *srcPtr, numLayers, numHeads, tokensPerBlock, sizePerHead, numTokensToCopy, stream);
152+
tk::kvCacheBlockPartialCopy(*dstPtr, *srcPtr, numLayers, numHeads, tokensPerBlock, sizePerHead,
153+
numTokensToCopy, kvFactor, stream);
153154
}
154155
}
155156
}

cpp/tensorrt_llm/kernels/kvCachePartialCopy.cu

Lines changed: 14 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,8 @@ unsigned int ipow2(unsigned int v)
5858

5959
template <typename T>
6060
void hostKVCacheBlockPartialCopy(IBuffer& dst, IBuffer const& src, unsigned int numLayers, unsigned int numHeads,
61-
unsigned int tokensPerBlock, unsigned int numHidden, unsigned int numTokensToCopy, cudaStream_t stream)
61+
unsigned int tokensPerBlock, unsigned int numHidden, unsigned int numTokensToCopy, int kvFactor,
62+
cudaStream_t stream)
6263
{
6364
unsigned int blockX = ipow2(numHidden); // ensure block shape is a power of 2
6465
blockX = std::min(blockX, 32u); // blockX should not exceed warp size
@@ -75,55 +76,56 @@ void hostKVCacheBlockPartialCopy(IBuffer& dst, IBuffer const& src, unsigned int
7576
auto srcData = bufferCast<T>(src);
7677
auto dstData = bufferCast<T>(dst);
7778
cuKVCacheBlockPartialCopy<<<grid, block, 0, stream>>>(
78-
dstData, srcData, 2 * numLayers, numHeads, tokensPerBlock, numHidden, numTokensToCopy);
79+
dstData, srcData, numLayers * kvFactor, numHeads, tokensPerBlock, numHidden, numTokensToCopy);
7980
}
8081
} // namespace
8182

8283
void kvCacheBlockPartialCopy(IBuffer& dst, IBuffer const& src, unsigned int numLayers, unsigned int numHeads,
83-
unsigned int tokensPerBlock, unsigned int numHidden, unsigned int numTokensToCopy, cudaStream_t stream)
84+
unsigned int tokensPerBlock, unsigned int numHidden, unsigned int numTokensToCopy, int kvFactor,
85+
cudaStream_t stream)
8486
{
8587
auto dataType = src.getDataType();
8688
TLLM_CHECK_WITH_INFO(dataType == dst.getDataType(), "src and dst dataType does not match");
8789
switch (dataType)
8890
{
8991
case nvinfer1::DataType::kINT64:
9092
hostKVCacheBlockPartialCopy<SizeType64>(
91-
dst, src, numLayers, numHeads, tokensPerBlock, numHidden, numTokensToCopy, stream);
93+
dst, src, numLayers, numHeads, tokensPerBlock, numHidden, numTokensToCopy, kvFactor, stream);
9294
break;
9395
case nvinfer1::DataType::kINT32:
9496
hostKVCacheBlockPartialCopy<std::int32_t>(
95-
dst, src, numLayers, numHeads, tokensPerBlock, numHidden, numTokensToCopy, stream);
97+
dst, src, numLayers, numHeads, tokensPerBlock, numHidden, numTokensToCopy, kvFactor, stream);
9698
break;
9799
case nvinfer1::DataType::kFLOAT:
98100
hostKVCacheBlockPartialCopy<float>(
99-
dst, src, numLayers, numHeads, tokensPerBlock, numHidden, numTokensToCopy, stream);
101+
dst, src, numLayers, numHeads, tokensPerBlock, numHidden, numTokensToCopy, kvFactor, stream);
100102
break;
101103
#ifdef ENABLE_BF16
102104
case nvinfer1::DataType::kBF16:
103105
hostKVCacheBlockPartialCopy<__nv_bfloat16>(
104-
dst, src, numLayers, numHeads, tokensPerBlock, numHidden, numTokensToCopy, stream);
106+
dst, src, numLayers, numHeads, tokensPerBlock, numHidden, numTokensToCopy, kvFactor, stream);
105107
break;
106108
#endif
107109
case nvinfer1::DataType::kHALF:
108110
hostKVCacheBlockPartialCopy<half>(
109-
dst, src, numLayers, numHeads, tokensPerBlock, numHidden, numTokensToCopy, stream);
111+
dst, src, numLayers, numHeads, tokensPerBlock, numHidden, numTokensToCopy, kvFactor, stream);
110112
break;
111113
case nvinfer1::DataType::kBOOL:
112114
hostKVCacheBlockPartialCopy<bool>(
113-
dst, src, numLayers, numHeads, tokensPerBlock, numHidden, numTokensToCopy, stream);
115+
dst, src, numLayers, numHeads, tokensPerBlock, numHidden, numTokensToCopy, kvFactor, stream);
114116
break;
115117
case nvinfer1::DataType::kUINT8:
116118
hostKVCacheBlockPartialCopy<std::uint8_t>(
117-
dst, src, numLayers, numHeads, tokensPerBlock, numHidden, numTokensToCopy, stream);
119+
dst, src, numLayers, numHeads, tokensPerBlock, numHidden, numTokensToCopy, kvFactor, stream);
118120
break;
119121
case nvinfer1::DataType::kINT8:
120122
hostKVCacheBlockPartialCopy<std::int8_t>(
121-
dst, src, numLayers, numHeads, tokensPerBlock, numHidden, numTokensToCopy, stream);
123+
dst, src, numLayers, numHeads, tokensPerBlock, numHidden, numTokensToCopy, kvFactor, stream);
122124
break;
123125
#ifdef ENABLE_FP8
124126
case nvinfer1::DataType::kFP8:
125127
hostKVCacheBlockPartialCopy<__nv_fp8_e4m3>(
126-
dst, src, numLayers, numHeads, tokensPerBlock, numHidden, numTokensToCopy, stream);
128+
dst, src, numLayers, numHeads, tokensPerBlock, numHidden, numTokensToCopy, kvFactor, stream);
127129
break;
128130
#endif
129131
default: TLLM_THROW("Unknown data type");

tests/integration/defs/accuracy/test_llm_api_pytorch.py

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -503,8 +503,6 @@ class TestDeepSeekV3Lite(LlmapiAccuracyTestHarness):
503503
[0, pytest.param(2, marks=skip_pre_hopper)])
504504
def test_bfloat16(self, mtp_nextn, attention_dp, cuda_graph,
505505
overlap_scheduler, torch_compile):
506-
if torch_compile:
507-
pytest.skip("https://nvbugs/5292037")
508506
if torch_compile and mtp_nextn > 0:
509507
pytest.skip("https://nvbugs/5252313")
510508
if torch_compile and attention_dp:
@@ -547,8 +545,6 @@ def test_bfloat16(self, mtp_nextn, attention_dp, cuda_graph,
547545
def test_bfloat16_4gpus(self, tp_size, pp_size, ep_size, mtp_nextn,
548546
attention_dp, cuda_graph, overlap_scheduler,
549547
torch_compile):
550-
if torch_compile:
551-
pytest.skip("https://nvbugs/5292037")
552548
if torch_compile and mtp_nextn > 0:
553549
pytest.skip("https://nvbugs/5252313")
554550
if torch_compile and attention_dp:
@@ -593,8 +589,6 @@ def test_bfloat16_4gpus(self, tp_size, pp_size, ep_size, mtp_nextn,
593589
@parametrize_with_ids("mtp_nextn", [0, 2])
594590
def test_fp8_block_scales(self, mtp_nextn, fp8kv, attention_dp, cuda_graph,
595591
overlap_scheduler, torch_compile):
596-
if torch_compile:
597-
pytest.skip("https://nvbugs/5292037")
598592
if torch_compile and mtp_nextn > 0:
599593
pytest.skip("https://nvbugs/5252313")
600594
if torch_compile and attention_dp:
@@ -712,8 +706,6 @@ def test_fp8_block_scales_cuda_graph_padding_4gpus(self, mtp_nextn,
712706
def test_fp8_block_scales_4gpus(self, tp_size, pp_size, ep_size, mtp_nextn,
713707
fp8kv, attention_dp, cuda_graph,
714708
overlap_scheduler, torch_compile):
715-
if torch_compile:
716-
pytest.skip("https://nvbugs/5292037")
717709
if torch_compile and mtp_nextn > 0:
718710
pytest.skip("https://nvbugs/5252313")
719711
if torch_compile and attention_dp:

tests/integration/test_lists/waives.txt

Lines changed: 0 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -358,7 +358,6 @@ accuracy/test_cli_flow.py::TestLlama3_2_1B::test_cyclic_kv_cache SKIP (https://n
358358
accuracy/test_cli_flow.py::TestSantacoder::test_auto_dtype SKIP (https://nvbugs/5231468)
359359
accuracy/test_cli_flow.py::TestLlama3_2_1B::test_cyclic_kv_cache SKIP (https://nvbugs/5231310)
360360
test_e2e.py::test_ptp_quickstart_multimodal[NVILA-8B-FP16-vila/NVILA-8B-image] SKIP (https://nvbugs/5233423)
361-
accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales_4gpus[tp4-mtp_nextn=0-fp8kv=False-attention_dp=False-cuda_graph=False-overlap_scheduler=False-torch_compile=False] SKIP (https://nvbugs/5294983)
362361
examples/test_gemma.py::test_llm_hf_gemma_quantization_1gpu[gemma-2-27b-it-fp8-bfloat16-8] SKIP (https://nvbugs/5234164)
363362
examples/test_bert.py::test_llm_bert_general[compare_hf-enable_remove_input_padding-disable_attention_plugin-disable_context_fmha-tp:1-pp:1-float16-RobertaForSequenceClassification-bert/twitter-roberta-base-emotion] SKIP (https://nvbugs/5234058)
364363
examples/test_bert.py::test_llm_bert_general[compare_hf-enable_remove_input_padding-disable_attention_plugin-disable_context_fmha-tp:2-pp:1-float16-RobertaForSequenceClassification-bert/twitter-roberta-base-emotion] SKIP (https://nvbugs/5234058)
@@ -382,17 +381,6 @@ triton_server/test_triton.py::test_gpt_speculative_decoding[gpt-speculative-deco
382381
triton_server/test_triton.py::test_qwen2_vl[qwen2_vl] SKIP
383382
triton_server/test_triton.py::test_gpt_ib_speculative_decoding_bls[gpt-ib-speculative-decoding-bls] SKIP
384383
triton_server/test_triton_llm.py::test_mistral_v1_multi_models[False-1---False-True-False-0-128-enableDecoupleMode-inflight_fused_batching-disableTrtOverlap--max_utilization-4096--1-1-1-False-ensemble] SKIP
385-
accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_bfloat16[mtp_nextn=0-attention_dp=False-cuda_graph=False-overlap_scheduler=False-torch_compile=False] SKIP (https://nvbugs/5285965)
386-
accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_bfloat16[mtp_nextn=0-attention_dp=True-cuda_graph=False-overlap_scheduler=False-torch_compile=False] SKIP (https://nvbugs/5285965)
387-
accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_bfloat16[mtp_nextn=0-attention_dp=False-cuda_graph=True-overlap_scheduler=False-torch_compile=False] SKIP (https://nvbugs/5285965)
388-
accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_bfloat16[mtp_nextn=0-attention_dp=False-cuda_graph=False-overlap_scheduler=True-torch_compile=False] SKIP (https://nvbugs/5285965)
389-
accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_bfloat16[mtp_nextn=0-attention_dp=True-cuda_graph=True-overlap_scheduler=True-torch_compile=False] SKIP (https://nvbugs/5285965)
390-
accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_bfloat16[mtp_nextn=2-attention_dp=False-cuda_graph=False-overlap_scheduler=False-torch_compile=False] SKIP (https://nvbugs/5285965)
391-
accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_bfloat16[mtp_nextn=2-attention_dp=True-cuda_graph=False-overlap_scheduler=False-torch_compile=False] SKIP (https://nvbugs/5285965)
392-
accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_bfloat16[mtp_nextn=2-attention_dp=False-cuda_graph=True-overlap_scheduler=False-torch_compile=False] SKIP (https://nvbugs/5285965)
393-
accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_bfloat16[mtp_nextn=2-attention_dp=False-cuda_graph=False-overlap_scheduler=True-torch_compile=False] SKIP (https://nvbugs/5285965)
394-
accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_bfloat16[mtp_nextn=2-attention_dp=True-cuda_graph=True-overlap_scheduler=True-torch_compile=False] SKIP (https://nvbugs/5285965)
395-
accuracy/test_llm_api_pytorch.py::TestDeepSeekR1::test_fp8_blockscale[latency] SKIP (https://nvbugs/5285965)
396384
accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_ngram SKIP (https://nvbugspro.nvidia.com/bug/5324239)
397385
examples/test_gpt.py::test_llm_gpt2_starcoder_weight_only[starcoder2-int4-float16] SKIP (https://nvbugs/5289523)
398386
examples/test_gpt.py::test_llm_gpt2_starcoder_weight_only[starcoder2-int8-float16] SKIP (https://nvbugs/5289523)
@@ -438,8 +426,6 @@ accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_auto_dtype
438426
test_e2e.py::test_openai_multi_chat_example SKIP (https://nvbugs/5236980)
439427
test_e2e.py::test_ptp_quickstart_advanced[Llama3.1-70B-FP8-llama-3.1-model/Llama-3.1-70B-Instruct-FP8] SKIP (https://nvbugs/5318059)
440428
test_e2e.py::test_ptq_quickstart_advanced_ngram[Llama-3.1-8B-Instruct-llama-3.1-model/Llama-3.1-8B-Instruct] SKIP (https://nvbugspro.nvidia.com/bug/5324239)
441-
accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales[mtp_nextn=0-fp8kv=False-attention_dp=True-cuda_graph=False-overlap_scheduler=False-torch_compile=False] SKIP (https://nvbugs/5318087)
442-
accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales[mtp_nextn=0-fp8kv=False-attention_dp=False-cuda_graph=False-overlap_scheduler=False-torch_compile=False] SKIP (https://nvbugs/5318087)
443429
unittest/_torch/auto_deploy/integration/test_ad_build.py SKIP (https://nvbugs/5318103)
444430
accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_bfloat16_4gpus[tp2pp2-attn_backend=TRTLLM-torch_compile=False] SKIP (https://nvbugs/5318143)
445431
accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_bfloat16_4gpus[tp2pp2-attn_backend=TRTLLM-torch_compile=True] SKIP (https://nvbugs/5318143)

0 commit comments

Comments
 (0)