@@ -66,6 +66,7 @@ def __init__(self, engine: "PyTorchModelEngine"):
         self.padding_enabled = config.cuda_graph_padding_enabled
         self.supported_batch_sizes = engine._cuda_graph_batch_sizes
         self.max_supported_batch_size = engine._max_cuda_graph_batch_size
+        self.max_num_tokens = engine.max_num_tokens

         # Low-level state, storing resources per batch size
         self.graphs: Dict[int, torch.cuda.CUDAGraph] = {}
@@ -106,8 +107,12 @@ def execute(self, batch: ScheduledRequests, inputs: Dict[str, Any],

         return self._run_graph(batch_size, inputs)

-    def _capture_graph(self, batch_size: int, forward_fn: Callable,
-                       initial_inputs: Dict[str, Any]):
+    def _capture_graph(self,
+                       batch_size: int,
+                       forward_fn: Callable,
+                       initial_inputs: Dict[str, Any],
+                       gather_ids: torch.Tensor,
+                       gather_context_logits: bool = False):
         """Captures the forward pass for a given batch size."""
         engine = self._get_engine()

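The two new parameters plumb logit gathering into graph capture: `gather_ids` appears to select the token positions whose logits are needed, while `gather_context_logits` keeps logits for all context tokens. A hedged sketch of what a `forward_fn` matching the new signature might do; the toy `lm_head` and the exact gathering semantics are assumptions, not this repo's implementation:

```python
import torch

hidden_size, vocab_size = 16, 32
lm_head = torch.nn.Linear(hidden_size, vocab_size).cuda()

def forward_fn(hidden: torch.Tensor, gather_ids: torch.Tensor,
               gather_context_logits: bool = False) -> torch.Tensor:
    """Return logits for all tokens, or only at the gathered positions."""
    if gather_context_logits:
        return lm_head(hidden)            # logits for every token
    return lm_head(hidden[gather_ids])    # logits only where sampling happens
```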
@@ -117,13 +122,13 @@ def _capture_graph(self, batch_size: int, forward_fn: Callable,

         static_tensors = {
             "input_ids":
-            torch.ones((batch_size * max_tokens_per_req, ),
+            torch.ones((self.max_num_tokens, ),
                        device="cuda",
                        dtype=torch.int32),
             "position_ids":
             torch.zeros((
                 1,
-                batch_size * max_tokens_per_req,
+                self.max_num_tokens,
             ),
                         device="cuda",
                         dtype=torch.int32),
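Context for the hunk above: a replayed CUDA graph only ever reads from the buffers it was captured on, so the static tensors must be sized for the engine-wide worst case (`self.max_num_tokens`) rather than per batch, with each step's real inputs copied into the front of those buffers before replay. A minimal sketch of that staging pattern; `MAX_NUM_TOKENS` and `stage_inputs` are illustrative names standing in for `engine.max_num_tokens` and whatever this runner does before replay:

```python
import torch

MAX_NUM_TOKENS = 8192  # assumed engine-wide cap, mirroring engine.max_num_tokens

# Worst-case-sized buffer the graph is captured on, like static_tensors["input_ids"].
static_input_ids = torch.ones((MAX_NUM_TOKENS, ), device="cuda", dtype=torch.int32)

def stage_inputs(input_ids: torch.Tensor) -> None:
    """Copy this step's tokens into the static buffer before graph replay."""
    num_tokens = input_ids.size(0)
    assert num_tokens <= MAX_NUM_TOKENS, "batch exceeds the captured buffer"
    static_input_ids[:num_tokens].copy_(input_ids, non_blocking=True)
```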
@@ -144,10 +149,11 @@ def _capture_graph(self, batch_size: int, forward_fn: Callable,
         graph = torch.cuda.CUDAGraph()
         with capturing_cuda_graph_context():
             for _ in range(self.WARMUP_STEPS):
-                forward_fn(capture_inputs)
+                forward_fn(capture_inputs, gather_ids, gather_context_logits)

             with torch.cuda.graph(graph, pool=self.memory_pool):
-                output = forward_fn(capture_inputs)
+                output = forward_fn(capture_inputs, gather_ids,
+                                    gather_context_logits)

         self.graphs[batch_size] = graph
         self.graph_outputs[batch_size] = make_weak_ref(output)
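For readers unfamiliar with the capture flow this hunk touches: the warmup iterations run first so one-time allocations and lazy initialization are not recorded into the graph, and the actual capture shares a memory pool across batch sizes. A standalone sketch of the same pattern using only the public `torch.cuda.graph` API; the toy model and sizes are illustrative, and `capturing_cuda_graph_context`/`forward_fn` from the diff are not reproduced:

```python
import torch

model = torch.nn.Linear(16, 16).cuda()
static_x = torch.zeros(8, 16, device="cuda")

# Warm up on a side stream so lazy init work is not captured
# (analogous to the WARMUP_STEPS loop above).
side = torch.cuda.Stream()
side.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(side):
    for _ in range(3):
        model(static_x)
torch.cuda.current_stream().wait_stream(side)

pool = torch.cuda.graph_pool_handle()  # shared pool, like self.memory_pool
graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(graph, pool=pool):
    static_y = model(static_x)  # recorded into the graph, not executed eagerly

# Replay: refresh the static input in place, then launch the captured kernels.
static_x.copy_(torch.randn(8, 16, device="cuda"))
graph.replay()
torch.cuda.synchronize()
print(static_y)
```

Storing only a weak reference to `output` (as `make_weak_ref` does above) presumably lets the graph's memory pool, rather than the runner, own the output storage.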