@@ -10,7 +10,7 @@
 import weakref
 from abc import ABC, abstractmethod
 from contextlib import contextmanager
-from typing import Any, Dict, List, Optional, Tuple
+from typing import Any, Callable, Dict, List, Optional, Tuple
 
 import torch
 import torch._dynamo.config
@@ -21,6 +21,7 @@
 from tensorrt_llm._torch.pyexecutor.sampler import SampleStateTensors
 from tensorrt_llm._torch.speculative import (
     get_num_extra_kv_tokens, update_spec_config_from_model_config)
+from tensorrt_llm._torch.speculative.drafting_loops import ChainDrafter
 from tensorrt_llm._torch.speculative.mtp import SampleStateTensorsMTP
 from tensorrt_llm._utils import (is_trace_enabled, nvtx_range, release_gc,
                                  str_dtype_to_torch, torch_dtype_to_str,
@@ -276,6 +277,8 @@ def __init__(
         spec_config: Optional["DecodingBaseConfig"] = None,
         lora_config: Optional[LoraConfig] = None,
         is_draft_model: bool = False,
+        drafting_loop_wrapper: Optional[Callable[[torch.nn.Module],
+                                                 torch.nn.Module]] = None,
     ):
         self.ub_buffers = None
         self.batch_size = batch_size
@@ -311,7 +314,8 @@ def __init__(
             max_num_tokens=max_num_tokens,
             moe_max_num_tokens=pytorch_backend_config.moe_max_num_tokens,
             moe_load_balancer=pytorch_backend_config.moe_load_balancer,
-            lora_config=lora_config)
+            lora_config=lora_config,
+            drafting_loop_wrapper=drafting_loop_wrapper)
         # In case that some tests use stub models and override `_load_model`.
         if not hasattr(self.model, 'extra_attrs'):
             self.model.extra_attrs = {}
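
Note on the new drafting_loop_wrapper argument: it is a callable that receives the freshly loaded torch.nn.Module and returns a wrapped module whose forward pass runs the drafting loop itself. The warmup changes further down assume the wrapper produces a ChainDrafter that exposes max_draft_len. Below is a minimal sketch of a module satisfying that contract; it is purely illustrative and is not the actual ChainDrafter implementation.

import torch


class ToyDraftingLoop(torch.nn.Module):
    """Illustrative stand-in for a fused drafting loop such as ChainDrafter."""

    def __init__(self, model: torch.nn.Module, max_draft_len: int):
        super().__init__()
        self.model = model
        # The warmup logic below reads max_draft_len to size dummy requests.
        self.max_draft_len = max_draft_len

    def forward(self, *args, **kwargs):
        # A real fused loop would call the base model autoregressively for
        # max_draft_len extra steps, feeding drafted tokens back in; here we
        # simply delegate so the sketch stays runnable.
        return self.model(*args, **kwargs)


# Passed to the engine as a module-to-module callable, e.g.:
# drafting_loop_wrapper=lambda m: ToyDraftingLoop(m, max_draft_len=3)

Note that get_num_extra_decoding_steps below only knows how to derive the extra step count for ChainDrafter; any other wrapper type currently trips its assert.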
@@ -403,7 +407,7 @@ def __init__(
                 dtype=torch.int,
                 device='cuda')
             self.without_logits = self.spec_config.spec_dec_mode.without_logits(
-            )
+            ) or self.model_is_wrapped
             self.max_draft_len = spec_config.max_draft_len
         else:
             self.without_logits = False
@@ -562,21 +566,33 @@ def warmup(self, resource_manager: ResourceManager) -> None:
         # Reset the global cuda graph dummy request to None in warmup.
         self.cuda_graph_runner.padding_dummy_request = None
 
+        def get_num_extra_decoding_steps():
+            if isinstance(self.model, ChainDrafter):
+                return self.model.max_draft_len
+            else:
+                assert not self.model_is_wrapped, (
+                    f"Please add logic to determine num_extra_decoding_steps for drafting loop {type(self.model)}"
+                )
+                return 0
+
         def get_cuda_graph_warmup_request(batch_size, draft_len):
             # Divide by max_beam_width to get an approximation of the number of requests that can be run in parallel.
             available_blocks = kv_cache_manager.get_num_free_blocks(
             ) // self.max_beam_width
             if available_blocks >= batch_size:
                 result = ScheduledRequests()
                 result.context_requests = []
+                num_extra_decoding_steps = get_num_extra_decoding_steps()
+
                 # Add (batch_size - 1) dummy requests with seq_len=1.
                 # Should only need one more page per request.
                 requests = kv_cache_manager.add_dummy_requests(
                     list(range(batch_size - 1)),
                     is_gen=True,
                     max_num_draft_tokens=draft_len,
                     use_mrope=use_mrope,
-                    max_beam_width=self.max_beam_width)
+                    max_beam_width=self.max_beam_width,
+                    num_extra_decoding_steps=num_extra_decoding_steps)
                 # Divide by max_beam_width to get an approximation of the number of tokens that can be added to the final request.
                 available_tokens = kv_cache_manager.get_num_available_tokens(
                     draft_len)
@@ -592,13 +608,20 @@ def get_cuda_graph_warmup_request(batch_size, draft_len):
                 if max_position_embeddings is not None:
                     token_num = min(token_num,
                                     max_position_embeddings - draft_len)
+
+                assert token_num > num_extra_decoding_steps, (
+                    "Cannot fuse drafting loop. We do not have enough KV cache space "
+                    "for all of the draft tokens.")
+                token_num -= num_extra_decoding_steps
+
                 max_seq_len_request = kv_cache_manager.add_dummy_requests(
                     request_ids=[batch_size - 1],
                     token_nums=[token_num],
                     is_gen=True,
                     max_num_draft_tokens=draft_len,
                     use_mrope=use_mrope,
-                    max_beam_width=self.max_beam_width)[0]
+                    max_beam_width=self.max_beam_width,
+                    num_extra_decoding_steps=num_extra_decoding_steps)[0]
                 # Add the longest request before all other seq_len=1 request to simulate the padding CUDA graph case.
                 # This batch contains both the longest request and the shortest requests,
                 # it also contains the maximum number of requests and the maximum token number,
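
For context on the token budgeting just added: a fused drafting loop performs num_extra_decoding_steps additional decoding steps inside a single forward call, and each step appends to the KV cache, so the largest warmup request must hand back that many tokens (and add_dummy_requests is told about them so block allocation matches). A standalone restatement of the arithmetic, for illustration only:

def budget_warmup_token_num(token_num: int, num_extra_decoding_steps: int) -> int:
    """Shrink a warmup request so the drafting loop's extra KV-cache writes fit."""
    assert token_num > num_extra_decoding_steps, (
        "Cannot fuse drafting loop. We do not have enough KV cache space "
        "for all of the draft tokens.")
    return token_num - num_extra_decoding_steps


# e.g. a 4096-token warmup request with 3 fused draft steps shrinks to 4093
# tokens, leaving 3 KV-cache slots for the tokens drafted inside the loop.
assert budget_warmup_token_num(4096, 3) == 4093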
@@ -620,6 +643,13 @@ def get_warmup_request(num_tokens: int, num_gen_tokens: int):
             if num_tokens > self.max_num_tokens or num_tokens > available_tokens:
                 return None
 
+            num_extra_decoding_steps = get_num_extra_decoding_steps()
+            if num_extra_decoding_steps > 0:
+                # Disable autotuning for fused drafting loops for now.
+                # There are a few bugs that can cause illegal memory accesses
+                # during warmup.
+                return None
+
             num_ctx_tokens = num_tokens - num_gen_tokens
             num_ctx_requests = 0
             ctx_requests = []
@@ -905,6 +935,8 @@ def _load_model(self,
                     moe_max_num_tokens: Optional[int] = None,
                     moe_load_balancer: Optional[MoeLoadBalancerConfig] = None,
                     lora_config: Optional[LoraConfig] = None,
+                    drafting_loop_wrapper: Optional[Callable[
+                        [torch.nn.Module], torch.nn.Module]] = None,
                     **kwargs) -> DecoderModelForCausalLM:
         config = checkpoint_loader.load_config(
             checkpoint_dir,
@@ -1008,6 +1040,13 @@ def init_meta_tensor(t: torch.Tensor):
             logger.info("moe_load_balancer finalize model done")
 
         torch.cuda.current_stream().synchronize()
+
+        if drafting_loop_wrapper is not None:
+            model = drafting_loop_wrapper(model)
+            self.model_is_wrapped = True
+        else:
+            self.model_is_wrapped = False
+
         return model
 
     def _call_load_weights(self, load_method, weights, weight_mapper):
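
Taken together: _load_model applies the wrapper after weights are loaded and records model_is_wrapped, and the earlier __init__ change folds that flag into without_logits, so a wrapped, token-emitting drafting loop bypasses the logits path. A self-contained stub showing just that control flow; every name here other than drafting_loop_wrapper, model_is_wrapped, and without_logits is invented for the example:

from typing import Callable, Optional

import torch


class EngineWrappingStub:
    """Minimal stand-in illustrating how the wrapping flag propagates."""

    def load(self,
             model: torch.nn.Module,
             drafting_loop_wrapper: Optional[Callable[[torch.nn.Module],
                                                      torch.nn.Module]] = None,
             spec_mode_without_logits: bool = False) -> torch.nn.Module:
        # Mirrors the tail of _load_model above.
        if drafting_loop_wrapper is not None:
            model = drafting_loop_wrapper(model)
            self.model_is_wrapped = True
        else:
            self.model_is_wrapped = False
        # Mirrors the __init__ change: a wrapped drafting loop implies the
        # engine should not expect logits from the forward pass.
        self.without_logits = spec_mode_without_logits or self.model_is_wrapped
        return model


stub = EngineWrappingStub()
stub.load(torch.nn.Identity(), drafting_loop_wrapper=lambda m: m)
assert stub.without_logits is True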