
Commit 66030ef

[TRTLLM-6452][feat]: Two-model engine KV cache reuse support (#6133)
Signed-off-by: ziyixiong-nv <[email protected]>
1 parent 82d3587 commit 66030ef

File tree: 6 files changed (+89, -22 lines)

cpp/include/tensorrt_llm/batch_manager/llmRequest.h

Lines changed: 4 additions & 1 deletion
@@ -826,6 +826,7 @@ class GenericLlmRequest
         mState = mEncoderTokens.has_value() || mEncoderInputFeatures ? LlmRequestState::kENCODER_INIT
                                                                      : LlmRequestState::kCONTEXT_INIT;
         mContextCurrentPosition = 0;
+        mPrepopulatedPromptLen = 0;
         mContextChunkSize = mPromptLen;
         mSeqSlot.reset();
     }
@@ -1564,7 +1565,9 @@ class GenericLlmRequest
     /// Returns whether the position is at the beginning of the context.
     [[nodiscard]] bool isFirstContextChunk() const noexcept
     {
-        return mContextCurrentPosition == 0;
+        // The number of cached tokens is already counted in mContextCurrentPosition,
+        // so the start position of the context is mPrepopulatedPromptLen.
+        return mContextCurrentPosition == mPrepopulatedPromptLen;
     }

     /// Move the cursor forward one chunk. When not chunked, move forward to the end of the context.
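
For illustration only, here is a minimal Python sketch of the bookkeeping this change relies on. The class and attribute names below mirror GenericLlmRequest's members but are hypothetical, not part of any API: tokens restored from reused KV cache blocks are already counted in the context position, so the first-chunk check must compare against the prepopulated prompt length rather than zero.

# Minimal sketch, assuming bookkeeping analogous to GenericLlmRequest; not the real class.
class RequestSketch:

    def __init__(self, prompt_len: int, prepopulated_prompt_len: int = 0):
        # Tokens restored from reused KV cache blocks count toward the current position.
        self.prompt_len = prompt_len
        self.prepopulated_prompt_len = prepopulated_prompt_len
        self.context_current_position = prepopulated_prompt_len

    def is_first_context_chunk(self) -> bool:
        # The old check was `self.context_current_position == 0`, which turns False as
        # soon as any blocks are reused, so the draft model's KV pages were never allocated.
        return self.context_current_position == self.prepopulated_prompt_len


# Example: 64 of 128 prompt tokens were recovered from reused blocks.
req = RequestSketch(prompt_len=128, prepopulated_prompt_len=64)
assert req.is_first_context_chunk()  # still the first chunk of new context work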

tensorrt_llm/_torch/pyexecutor/py_executor.py

Lines changed: 0 additions & 6 deletions
@@ -258,12 +258,6 @@ def __init__(self,
             ResourceManagerType.KV_CACHE_MANAGER)
         self.enable_kv_cache_events = self.kv_cache_manager is not None and self.kv_cache_manager.event_buffer_max_size > 0

-        if self.draft_model_engine is not None and self.kv_cache_manager is not None:
-            if self.kv_cache_manager.enable_block_reuse:
-                raise NotImplementedError(
-                    "Draft model engine + KV cache reuse is not supported yet. "
-                    "This will be fixed in the near future!")
-
         self.max_input_len = max_input_len
         # _executor_loop private data
         self.max_num_active_requests = model_engine.get_max_num_sequences()

tensorrt_llm/_torch/pyexecutor/py_executor_creator.py

Lines changed: 0 additions & 15 deletions
@@ -162,21 +162,6 @@ def _mangle_executor_config(executor_config: ExecutorConfig):
         )
         executor_config.kv_cache_config.enable_block_reuse = False

-    spec_config = executor_config.speculative_config
-    if spec_config is not None and spec_config.spec_dec_mode.has_draft_model():
-        # The draft and target models have different KV cache managers to support
-        # different head sizes, dtypes, etc. in the generic case.
-        # However, this line will set context_current_position > 0 if there are
-        # cached blocks: https://github.com/NVIDIA/TensorRT-LLM/blob/main/tensorrt_llm/_torch/pyexecutor/resource_manager.py#L310.
-        # It actually mutates the LLM request! As a result, when we try to allocate KV cache
-        # pages for the draft model, is_first_context_chunk returns False and
-        # no pages are allocated.
-        # We need to refactor LLMRequest to fix this. Disable block reuse for now.
-        logger.warning(
-            f"Disabling block reuse for speculation algorithm {spec_config.spec_dec_mode}"
-        )
-        executor_config.kv_cache_config.enable_block_reuse = False
-
     if pytorch_backend_config.attn_backend == "FLASHINFER_STAR_ATTENTION" and executor_config.enable_chunked_context:
         logger.warning(
             f"Disabling chunked context for {pytorch_backend_config.attn_backend} backend"

tests/integration/test_lists/test-db/l0_b200.yml

Lines changed: 2 additions & 0 deletions
@@ -57,6 +57,8 @@ l0_b200:
   - unittest/_torch/modeling -k "modeling_mixtral"
   - unittest/_torch/modeling -k "modeling_deepseek"
   - unittest/_torch/auto_deploy/unit/singlegpu
+  - unittest/_torch/speculative/test_eagle3.py
+  - unittest/_torch/speculative/test_kv_cache_reuse.py
 - condition:
     ranges:
       system_gpu_count:

tests/unittest/_torch/speculative/test_eagle3.py

Lines changed: 2 additions & 0 deletions
@@ -18,6 +18,8 @@
     [
         [True, "TRTLLM", True, False, False],
         [False, "TRTLLM", True, False, False],
+        [True, "TRTLLM", True, True, False],
+        [False, "TRTLLM", True, True, False],
         [True, "FLASHINFER", True, False, False],
         [False, "FLASHINFER", True, False, False],
         [False, "TRTLLM", False, True, True],
tests/unittest/_torch/speculative/test_kv_cache_reuse.py

Lines changed: 81 additions & 0 deletions
@@ -0,0 +1,81 @@
+import os
+import sys
+import unittest
+
+import pytest
+import torch
+from utils.llm_data import llm_models_root
+
+from tensorrt_llm import LLM, SamplingParams
+from tensorrt_llm.llmapi import (CudaGraphConfig, EagleDecodingConfig,
+                                 KvCacheConfig)
+
+sys.path.append(os.path.join(os.path.dirname(__file__), '..'))
+
+
+@pytest.mark.parametrize("use_cuda_graph,attn_backend", [
+    [True, "TRTLLM"],
+    [False, "TRTLLM"],
+])
+@pytest.mark.high_cuda_memory
+def test_kv_cache_reuse(use_cuda_graph: bool, attn_backend: str):
+    # Two-model Eagle3 with KV cache block reuse enabled.
+    total_mem_gb = torch.cuda.get_device_properties(0).total_memory / 1e9
+    if total_mem_gb < 35:
+        pytest.skip("Not enough memory to load target + draft model")
+
+    models_path = llm_models_root()
+    eagle_model_dir = f"{models_path}/EAGLE3-LLaMA3.1-Instruct-8B"
+    target_model_dir = f"{models_path}/llama-3.1-model/Llama-3.1-8B-Instruct"
+
+    # bs > 1 gives non-deterministic results with in-flight batching: there is a
+    # slight chance that the reference and speculative outputs do not match 100%.
+    max_batch_size = 1
+    max_draft_len = 4
+    kv_cache_config = KvCacheConfig(enable_block_reuse=True,
+                                    free_gpu_memory_fraction=0.5)
+    cuda_graph_config = CudaGraphConfig(
+        batch_sizes=[1]) if use_cuda_graph else None
+
+    llm_common_config = dict(
+        model=target_model_dir,
+        attn_backend=attn_backend,
+        disable_overlap_scheduler=True,
+        cuda_graph_config=cuda_graph_config,
+        max_batch_size=max_batch_size,
+        kv_cache_config=kv_cache_config,
+        # This max_seq_len is larger than the one specified
+        # in the llama 3 8B eagle's config. We want to make sure
+        # that the draft model won't go above its max in warmup
+        # in this test.
+        max_seq_len=8192,
+    )
+
+    spec_config = EagleDecodingConfig(
+        max_draft_len=max_draft_len,
+        speculative_model_dir=eagle_model_dir,
+        eagle3_one_model=False,
+    )
+
+    llm_spec = LLM(**llm_common_config, speculative_config=spec_config)
+
+    # Output tests
+    prompt = "The future of AI is"
+
+    sampling_params = SamplingParams(max_tokens=10, temperature=0)
+
+    # First run: populates the KV cache.
+    results = llm_spec.generate(prompt, sampling_params)
+    generated_text = results.outputs[0].text
+
+    # Second run: reuses the cached KV blocks and must produce the same output.
+    results_kv_cache = llm_spec.generate(prompt, sampling_params)
+    generated_text_kv_cache = results_kv_cache.outputs[0].text
+
+    llm_spec.shutdown()
+
+    assert generated_text == generated_text_kv_cache
+
+
+if __name__ == "__main__":
+    unittest.main()
