Commit b3c57a7

[TRTLLM-7353][feat] Implement capturable drafting loops for speculation (#7100)
Signed-off-by: Mike Iovine <[email protected]>
1 parent 01dfd3a commit b3c57a7

7 files changed: 260 additions & 24 deletions

tensorrt_llm/_torch/models/modeling_speculative.py

Lines changed: 1 addition & 13 deletions
@@ -1,4 +1,4 @@
-from typing import Any, Dict, Generic, Optional, Tuple
+from typing import Dict, Generic, Optional, Tuple
 
 import torch
 from torch import nn
@@ -293,18 +293,6 @@ def load_weights_from_target_model(self,
         if self.load_lm_head_from_target:
             self.lm_head = target_model.lm_head
 
-    # TODO: should input/position IDs be included in this? Keeping it implicit
-    # for now since the shapes/dtypes are the same across all models we have.
-    def get_warmup_extra_inputs(self, batch_size: int,
-                                num_tokens: int) -> Dict[str, Any]:
-
-        hidden_states = torch.empty(batch_size * num_tokens,
-                                    self.model.hidden_size,
-                                    dtype=self.model.dtype,
-                                    device='cuda')
-
-        return {'hidden_states': hidden_states}
-
     def apply_eagle3_fc(self, hidden_states: torch.Tensor) -> torch.Tensor:
         """
         Hack for eagle3. We might need to run a matmul to reduce

tensorrt_llm/_torch/pyexecutor/model_engine.py

Lines changed: 44 additions & 5 deletions
@@ -10,7 +10,7 @@
 import weakref
 from abc import ABC, abstractmethod
 from contextlib import contextmanager
-from typing import Any, Dict, List, Optional, Tuple
+from typing import Any, Callable, Dict, List, Optional, Tuple
 
 import torch
 import torch._dynamo.config
@@ -21,6 +21,7 @@
 from tensorrt_llm._torch.pyexecutor.sampler import SampleStateTensors
 from tensorrt_llm._torch.speculative import (
     get_num_extra_kv_tokens, update_spec_config_from_model_config)
+from tensorrt_llm._torch.speculative.drafting_loops import ChainDrafter
 from tensorrt_llm._torch.speculative.mtp import SampleStateTensorsMTP
 from tensorrt_llm._utils import (is_trace_enabled, nvtx_range, release_gc,
                                  str_dtype_to_torch, torch_dtype_to_str,
@@ -276,6 +277,8 @@ def __init__(
         spec_config: Optional["DecodingBaseConfig"] = None,
         lora_config: Optional[LoraConfig] = None,
         is_draft_model: bool = False,
+        drafting_loop_wrapper: Optional[Callable[[torch.nn.Module],
+                                                 torch.nn.Module]] = None,
     ):
         self.ub_buffers = None
         self.batch_size = batch_size
@@ -311,7 +314,8 @@ def __init__(
            max_num_tokens=max_num_tokens,
            moe_max_num_tokens=pytorch_backend_config.moe_max_num_tokens,
            moe_load_balancer=pytorch_backend_config.moe_load_balancer,
-           lora_config=lora_config)
+           lora_config=lora_config,
+           drafting_loop_wrapper=drafting_loop_wrapper)
        # In case that some tests use stub models and override `_load_model`.
        if not hasattr(self.model, 'extra_attrs'):
            self.model.extra_attrs = {}
@@ -403,7 +407,7 @@ def __init__(
                dtype=torch.int,
                device='cuda')
            self.without_logits = self.spec_config.spec_dec_mode.without_logits(
-           )
+           ) or self.model_is_wrapped
            self.max_draft_len = spec_config.max_draft_len
        else:
            self.without_logits = False
@@ -562,21 +566,33 @@ def warmup(self, resource_manager: ResourceManager) -> None:
        # Reset the global cuda graph dummy request to None in warmup.
        self.cuda_graph_runner.padding_dummy_request = None
 
+       def get_num_extra_decoding_steps():
+           if isinstance(self.model, ChainDrafter):
+               return self.model.max_draft_len
+           else:
+               assert not self.model_is_wrapped, (
+                   f"Please add logic to determine num_extra_decoding_steps for drafting loop {type(self.model)}"
+               )
+               return 0
+
        def get_cuda_graph_warmup_request(batch_size, draft_len):
            # Divide by max_beam_width to get an approximation of the number of requests that can be run in parallel.
            available_blocks = kv_cache_manager.get_num_free_blocks(
            ) // self.max_beam_width
            if available_blocks >= batch_size:
                result = ScheduledRequests()
                result.context_requests = []
+               num_extra_decoding_steps = get_num_extra_decoding_steps()
+
                # Add (batch_size - 1) dummy requests with seq_len=1.
                # Should only need one more page per request.
                requests = kv_cache_manager.add_dummy_requests(
                    list(range(batch_size - 1)),
                    is_gen=True,
                    max_num_draft_tokens=draft_len,
                    use_mrope=use_mrope,
-                   max_beam_width=self.max_beam_width)
+                   max_beam_width=self.max_beam_width,
+                   num_extra_decoding_steps=num_extra_decoding_steps)
                # Divide by max_beam_width to get an approximation of the number of tokens that can be added to the final request.
                available_tokens = kv_cache_manager.get_num_available_tokens(
                    draft_len)
@@ -592,13 +608,20 @@ def get_cuda_graph_warmup_request(batch_size, draft_len):
                if max_position_embeddings is not None:
                    token_num = min(token_num,
                                    max_position_embeddings - draft_len)
+
+               assert token_num > num_extra_decoding_steps, (
+                   "Cannot fuse drafting loop. We do not have enough KV cache space "
+                   "for all of the draft tokens.")
+               token_num -= num_extra_decoding_steps
+
                max_seq_len_request = kv_cache_manager.add_dummy_requests(
                    request_ids=[batch_size - 1],
                    token_nums=[token_num],
                    is_gen=True,
                    max_num_draft_tokens=draft_len,
                    use_mrope=use_mrope,
-                   max_beam_width=self.max_beam_width)[0]
+                   max_beam_width=self.max_beam_width,
+                   num_extra_decoding_steps=num_extra_decoding_steps)[0]
                # Add the longest request before all other seq_len=1 request to simulate the padding CUDA graph case.
                # This batch contains both the longest request and the shortest requests,
                # it also contains the maximum number of requests and the maximum token number,
@@ -620,6 +643,13 @@ def get_warmup_request(num_tokens: int, num_gen_tokens: int):
            if num_tokens > self.max_num_tokens or num_tokens > available_tokens:
                return None
 
+           num_extra_decoding_steps = get_num_extra_decoding_steps()
+           if num_extra_decoding_steps > 0:
+               # Disable autotuning for fused drafting loops for now.
+               # There are a few bugs that can cause illegal memory accesses
+               # during warmup.
+               return None
+
            num_ctx_tokens = num_tokens - num_gen_tokens
            num_ctx_requests = 0
            ctx_requests = []
@@ -905,6 +935,8 @@ def _load_model(self,
                    moe_max_num_tokens: Optional[int] = None,
                    moe_load_balancer: Optional[MoeLoadBalancerConfig] = None,
                    lora_config: Optional[LoraConfig] = None,
+                   drafting_loop_wrapper: Optional[Callable[
+                       [torch.nn.Module], torch.nn.Module]] = None,
                    **kwargs) -> DecoderModelForCausalLM:
        config = checkpoint_loader.load_config(
            checkpoint_dir,
@@ -1008,6 +1040,13 @@ def init_meta_tensor(t: torch.Tensor):
            logger.info("moe_load_balancer finalize model done")
 
        torch.cuda.current_stream().synchronize()
+
+       if drafting_loop_wrapper is not None:
+           model = drafting_loop_wrapper(model)
+           self.model_is_wrapped = True
+       else:
+           self.model_is_wrapped = False
+
        return model
 
    def _call_load_weights(self, load_method, weights, weight_mapper):
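
To make the new warmup bookkeeping concrete, the sketch below is a hedged, standalone illustration (the helper name warmup_token_budget and the numbers are hypothetical, not code from this diff) of the budget adjustment that get_cuda_graph_warmup_request now performs when a fused drafting loop is active:

def warmup_token_budget(token_num: int, num_extra_decoding_steps: int) -> int:
    # A fused drafting loop appends num_extra_decoding_steps tokens per sequence
    # inside a single forward call, so that many KV slots must remain free.
    assert token_num > num_extra_decoding_steps, (
        "Cannot fuse drafting loop. We do not have enough KV cache space "
        "for all of the draft tokens.")
    return token_num - num_extra_decoding_steps


# For example, a ChainDrafter with max_draft_len=3 reserves 3 extra slots:
print(warmup_token_budget(token_num=2048, num_extra_decoding_steps=3))  # 2045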

tensorrt_llm/_torch/pyexecutor/py_executor_creator.py

Lines changed: 20 additions & 3 deletions
@@ -260,13 +260,29 @@ def create_py_executor(
         with mem_monitor.observe_creation_stage(
                 _ExecutorCreationStage.MODEL_ENGINE_DRAFT):
             draft_spec_config = copy.copy(spec_config)
-            draft_pytorch_backend_config = copy.copy(pytorch_backend_config)
-            if spec_config.load_format == "dummy":
-                draft_pytorch_backend_config.load_format = LoadFormat.DUMMY
             # The draft model won't have any draft tokens attached to
             # generation requests when we invoke it autoregressively
             draft_spec_config.max_draft_len = 0
 
+            use_chain_drafter = (
+                executor_config.guided_decoding_config is None
+                and not pytorch_backend_config.enable_mixed_sampler
+                and pytorch_backend_config.attn_backend == "TRTLLM")
+
+            if use_chain_drafter:
+
+                def drafting_loop_wrapper(model):
+                    from tensorrt_llm._torch.speculative.drafting_loops import \
+                        ChainDrafter
+
+                    return ChainDrafter(spec_config.max_draft_len, model)
+            else:
+                drafting_loop_wrapper = None
+
+            draft_pytorch_backend_config = copy.copy(pytorch_backend_config)
+            if spec_config.load_format == "dummy":
+                draft_pytorch_backend_config.load_format = LoadFormat.DUMMY
+
             draft_model_engine = PyTorchModelEngine(
                 model_path=spec_config.speculative_model_dir,
                 pytorch_backend_config=draft_pytorch_backend_config,
@@ -282,6 +298,7 @@ def create_py_executor(
                 spec_config=draft_spec_config,
                 checkpoint_loader=executor_config.checkpoint_loader,
                 is_draft_model=True,
+                drafting_loop_wrapper=drafting_loop_wrapper,
             )
             draft_model_engine.kv_cache_manager_key = ResourceManagerType.DRAFT_KV_CACHE_MANAGER
             draft_model_engine.load_weights_from_target_model(
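
A condensed, hedged restatement of the wiring above (build_draft_wrapper is a hypothetical helper; the real code builds the closure inline inside create_py_executor): the closure is handed to the draft PyTorchModelEngine, which applies it in _load_model, so the draft engine's model becomes a ChainDrafter whenever use_chain_drafter is true.

def build_draft_wrapper(spec_config, pytorch_backend_config, executor_config):
    use_chain_drafter = (
        executor_config.guided_decoding_config is None
        and not pytorch_backend_config.enable_mixed_sampler
        and pytorch_backend_config.attn_backend == "TRTLLM")
    if not use_chain_drafter:
        # Fall back to the existing per-step drafting path: the draft model
        # is loaded unwrapped, exactly as before this change.
        return None

    def drafting_loop_wrapper(model):
        from tensorrt_llm._torch.speculative.drafting_loops import ChainDrafter
        return ChainDrafter(spec_config.max_draft_len, model)

    return drafting_loop_wrapper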

tensorrt_llm/_torch/pyexecutor/resource_manager.py

Lines changed: 10 additions & 0 deletions
@@ -469,6 +469,11 @@ def add_dummy_requests(
            max_num_draft_tokens: int = 0,
            use_mrope: bool = False,
            max_beam_width: int = 1,
+           # For capturable drafting loops. During normal inference, the draft model always
+           # has enough KV cache space to fit all of our draft tokens. During warmup, however,
+           # we need to make the KV cache manager aware that multiple autoregressive steps will
+           # occur.
+           num_extra_decoding_steps: int = 0,
    ):
        beam_width = max_beam_width
        requests = []
@@ -502,6 +507,10 @@
                self.impl.add_sequence(req_id, token_num, beam_width, req)
                for _ in range(self.num_extra_kv_tokens):
                    self.impl.add_token(req_id)
+
+               for _ in range(num_extra_decoding_steps):
+                   self.impl.add_token(req_id)
+
            if is_gen:
                req.state = LlmRequestState.GENERATION_IN_PROGRESS
                req.prompt_len = token_num - 1
@@ -510,6 +519,7 @@
            if prepare_resource:
                for _ in range(max_num_draft_tokens):
                    self.impl.add_token(req_id)
+
            requests.append(req)
        return requests
 
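
To illustrate the comment above, here is a hedged, self-contained toy (ToyKvCache and its methods are invented for this sketch; it is not the real KVCacheManager) showing what the extra add_token calls reserve during warmup:

class ToyKvCache:

    def __init__(self):
        self.tokens_per_request = {}

    def add_sequence(self, req_id: int, token_num: int):
        self.tokens_per_request[req_id] = token_num

    def add_token(self, req_id: int):
        self.tokens_per_request[req_id] += 1


def add_dummy_request(cache: ToyKvCache, req_id: int, token_num: int,
                      num_extra_kv_tokens: int, num_extra_decoding_steps: int):
    cache.add_sequence(req_id, token_num)
    for _ in range(num_extra_kv_tokens):
        cache.add_token(req_id)
    # The new reservation: one slot per autoregressive step that the fused
    # drafting loop will take inside a single forward call.
    for _ in range(num_extra_decoding_steps):
        cache.add_token(req_id)


cache = ToyKvCache()
add_dummy_request(cache, req_id=0, token_num=1,
                  num_extra_kv_tokens=0, num_extra_decoding_steps=3)
print(cache.tokens_per_request[0])  # 4 slots for a seq_len=1 warmup request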

tensorrt_llm/_torch/speculative/drafting_loops.py

Lines changed: 150 additions & 0 deletions
@@ -0,0 +1,150 @@
+"""
+This module contains capturable drafting loops for speculative decoding.
+
+These are torch modules that wrap another draft model. The wrapper module
+is supposed to invoke the draft model autoregressively and invoke
+a sampling algorithm to obtain draft tokens. By structuring the code
+like this, we are able to avoid host overhead: the entire drafting process
+for speculation can be launched as a single CUDA graph.
+"""
+
+from contextlib import contextmanager
+
+import torch
+
+from tensorrt_llm._torch.attention_backend.interface import AttentionMetadata
+from tensorrt_llm._torch.speculative.eagle3 import Eagle3SpecMetadata
+from tensorrt_llm._torch.speculative.interface import SpecMetadata
+
+
+@contextmanager
+def save_metadata_state(attn_metadata: AttentionMetadata,
+                        spec_metadata: SpecMetadata) -> None:
+    batch_size = attn_metadata.num_seqs
+
+    if attn_metadata.is_cuda_graph:
+        seq_len = attn_metadata._seq_lens[:batch_size].clone()
+        seq_len_cuda = attn_metadata._seq_lens_cuda[:batch_size].clone()
+        kv_lens = attn_metadata.kv_lens_cuda.clone()
+
+        assert spec_metadata.is_cuda_graph
+        num_tokens = spec_metadata.num_tokens
+        if isinstance(spec_metadata, Eagle3SpecMetadata):
+            read_indices = spec_metadata.hidden_states_read_indices[:batch_size].clone()
+            write_indices = spec_metadata.hidden_states_write_indices[:batch_size].clone()
+
+    try:
+        yield
+    finally:
+        if attn_metadata.is_cuda_graph:
+            attn_metadata._seq_lens[:batch_size].copy_(seq_len[:batch_size])
+            attn_metadata._seq_lens_cuda[:batch_size].copy_(
+                seq_len_cuda[:batch_size])
+            attn_metadata.kv_lens_cuda[:batch_size].copy_(kv_lens[:batch_size])
+
+            spec_metadata.num_tokens = num_tokens
+            if isinstance(spec_metadata, Eagle3SpecMetadata):
+                spec_metadata.hidden_states_read_indices[:batch_size].copy_(
+                    read_indices)
+                spec_metadata.hidden_states_write_indices[:batch_size].copy_(
+                    write_indices)
+
+        # This restore has to happen even if the spec_metadata is not being used
+        # for CUDA graphs. It won't be reset by spec_metadata.prepare().
+        if isinstance(spec_metadata, Eagle3SpecMetadata):
+            spec_metadata.is_first_draft = True
+            spec_metadata.eagle3_resource_manager.is_first_draft = True
+
+
+def prepare_for_generation(attn_metadata: AttentionMetadata,
+                           spec_metadata: SpecMetadata,
+                           last_tokens_idx: torch.Tensor) -> None:
+    batch_size = attn_metadata.num_seqs
+    attn_metadata._seq_lens[:batch_size].fill_(1)
+    attn_metadata._seq_lens_cuda[:batch_size].fill_(1)
+    attn_metadata.on_update()
+    attn_metadata.kv_lens_cuda[:batch_size] += 1
+
+    attn_metadata.host_request_types[:attn_metadata.num_contexts].fill_(1)
+    attn_metadata.num_contexts = 0
+
+    spec_metadata.num_tokens = batch_size
+
+    if isinstance(spec_metadata, Eagle3SpecMetadata):
+        spec_metadata.eagle3_resource_manager.is_first_draft = False
+        spec_metadata.is_first_draft = False
+
+        old_write_indices = spec_metadata.hidden_states_write_indices
+
+        spec_metadata.hidden_states_read_indices[:batch_size].copy_(
+            old_write_indices[last_tokens_idx])
+        spec_metadata.hidden_states_write_indices[:batch_size].copy_(
+            torch.arange(
+                batch_size,
+                dtype=spec_metadata.hidden_states_write_indices.dtype,
+                device=spec_metadata.hidden_states_write_indices.device))
+
+
+class ChainDrafter(torch.nn.Module):
+
+    def __init__(self, max_draft_len: int, draft_model: torch.nn.Module):
+        super().__init__()
+        self.draft_model = draft_model
+        self.config = self.draft_model.config
+        self.model_config = self.draft_model.model_config
+        self.max_draft_len = max_draft_len
+
+    def forward(self, input_ids: torch.Tensor, position_ids: torch.Tensor,
+                attn_metadata: AttentionMetadata,
+                spec_metadata: SpecMetadata, **kwargs) -> torch.Tensor:
+
+        logits = self.draft_model.forward(input_ids=input_ids,
+                                          position_ids=position_ids,
+                                          attn_metadata=attn_metadata,
+                                          spec_metadata=spec_metadata)
+
+        new_draft_tokens = [self.sample(logits)]
+
+        with save_metadata_state(attn_metadata, spec_metadata):
+            batch_size = attn_metadata.num_seqs
+            last_tokens_idx = torch.cumsum(
+                attn_metadata.seq_lens_cuda, dim=0, dtype=torch.long) - 1
+            new_position_ids = position_ids[0, last_tokens_idx] + 1
+
+            prepare_for_generation(attn_metadata, spec_metadata,
+                                   last_tokens_idx)
+
+            for i in range(self.max_draft_len - 1):
+                logits = self.draft_model.forward(
+                    input_ids=new_draft_tokens[-1],
+                    position_ids=new_position_ids,
+                    attn_metadata=attn_metadata,
+                    spec_metadata=spec_metadata)
+                new_draft_tokens.append(self.sample(logits))
+                new_position_ids += 1
+                attn_metadata.kv_lens_cuda[:batch_size] += 1
+                if i == 0 and isinstance(spec_metadata, Eagle3SpecMetadata):
+                    spec_metadata.hidden_states_read_indices[:batch_size].copy_(
+                        spec_metadata.hidden_states_write_indices[:batch_size])
+
+        return torch.stack(new_draft_tokens)
+
+    def sample(self, logits: torch.Tensor) -> torch.Tensor:
+        # TODO: inject the sampler here so we can support non-greedy
+        tokens = torch.argmax(logits, dim=-1)
+        if hasattr(self.draft_model.model, "d2t"):
+            d2t = self.draft_model.model.d2t.data
+            return tokens + d2t[tokens]
+
+        return tokens
+
+    def load_weights_from_target_model(self,
+                                       target_model: torch.nn.Module) -> None:
+        loader = getattr(self.draft_model, "load_weights_from_target_model",
+                         None)
+        if callable(loader):
+            self.draft_model.load_weights_from_target_model(target_model)
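
The docstring's point that the whole drafting pass can be launched as a single CUDA graph is the core of this change. Below is a hedged, self-contained sketch (ToyDrafter and all other names here are invented stand-ins, not TRT-LLM code) of how a fixed-trip-count drafting loop that lives entirely inside one forward() can be captured once and then replayed with a single graph launch, removing the per-draft-step host overhead:

import torch


class ToyDrafter(torch.nn.Module):
    # Stand-in for ChainDrafter: a greedy, fixed-length drafting loop in forward().

    def __init__(self, vocab_size: int, hidden: int, max_draft_len: int):
        super().__init__()
        self.embed = torch.nn.Embedding(vocab_size, hidden)
        self.head = torch.nn.Linear(hidden, vocab_size)
        self.max_draft_len = max_draft_len

    def forward(self, input_ids: torch.Tensor) -> torch.Tensor:
        tokens = []
        cur = input_ids
        for _ in range(self.max_draft_len):
            logits = self.head(self.embed(cur))
            cur = torch.argmax(logits, dim=-1)  # feed the sampled token back in
            tokens.append(cur)
        return torch.stack(tokens)


drafter = ToyDrafter(vocab_size=128, hidden=64, max_draft_len=4).cuda()
static_input = torch.zeros(8, dtype=torch.long, device="cuda")

# Warm up on a side stream (required before capture), then capture the loop.
side = torch.cuda.Stream()
side.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(side):
    drafter(static_input)
torch.cuda.current_stream().wait_stream(side)

graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(graph):
    static_output = drafter(static_input)

# Replay: copy fresh inputs into the static buffer and launch the entire
# multi-step drafting loop as one graph.
static_input.copy_(torch.randint(0, 128, (8,), device="cuda"))
graph.replay()
print(static_output.shape)  # torch.Size([4, 8]), i.e. (max_draft_len, batch)

The real ChainDrafter plays the same trick against the engine's attention and speculative metadata, which is why save_metadata_state snapshots and restores the CUDA-graph tensors that prepare_for_generation and the loop mutate in place.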
