@@ -213,14 +213,14 @@ def _capture_graph(self, batch_size: int, forward_fn: Callable,
213213 def _run_graph (self , batch_size : int ,
214214 current_inputs : Dict [str , Any ]) -> Optional [torch .Tensor ]:
215215 """Replays a previously captured graph."""
216- (batch_size , self .draft_len )
217- stored_meta = self .graph_metadata [batch_size ]
216+ key = (batch_size , self .draft_len )
217+ stored_meta = self .graph_metadata [key ]
218218 assert current_inputs ["attn_metadata" ] is stored_meta ["attn_metadata" ]
219219 if stored_meta ["spec_metadata" ] is not None :
220220 assert current_inputs .get (
221221 "spec_metadata" ) is stored_meta ["spec_metadata" ]
222222
223- static_tensors = self .static_inputs [batch_size ]
223+ static_tensors = self .static_inputs [key ]
224224
225225 input_ids = current_inputs ["input_ids" ]
226226 seqlen = input_ids .shape [0 ]
@@ -233,8 +233,8 @@ def _run_graph(self, batch_size: int,
233233 static_tensors ["mrope_position_deltas" ].copy_ (
234234 current_inputs ["mrope_position_deltas" ])
235235
236- self .graphs [batch_size ].replay ()
237- output_ref = self .graph_outputs [batch_size ]
236+ self .graphs [key ].replay ()
237+ output_ref = self .graph_outputs [key ]
238238
239239 return output_ref
240240
0 commit comments