Skip to content

Commit e20fbab

Browse files
committed
fix
Signed-off-by: junq <[email protected]>
1 parent c13d9fb commit e20fbab

File tree

1 file changed

+5
-5
lines changed

1 file changed

+5
-5
lines changed

tensorrt_llm/_torch/pyexecutor/cuda_graph_runner.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -213,14 +213,14 @@ def _capture_graph(self, batch_size: int, forward_fn: Callable,
213213
def _run_graph(self, batch_size: int,
214214
current_inputs: Dict[str, Any]) -> Optional[torch.Tensor]:
215215
"""Replays a previously captured graph."""
216-
(batch_size, self.draft_len)
217-
stored_meta = self.graph_metadata[batch_size]
216+
key = (batch_size, self.draft_len)
217+
stored_meta = self.graph_metadata[key]
218218
assert current_inputs["attn_metadata"] is stored_meta["attn_metadata"]
219219
if stored_meta["spec_metadata"] is not None:
220220
assert current_inputs.get(
221221
"spec_metadata") is stored_meta["spec_metadata"]
222222

223-
static_tensors = self.static_inputs[batch_size]
223+
static_tensors = self.static_inputs[key]
224224

225225
input_ids = current_inputs["input_ids"]
226226
seqlen = input_ids.shape[0]
@@ -233,8 +233,8 @@ def _run_graph(self, batch_size: int,
233233
static_tensors["mrope_position_deltas"].copy_(
234234
current_inputs["mrope_position_deltas"])
235235

236-
self.graphs[batch_size].replay()
237-
output_ref = self.graph_outputs[batch_size]
236+
self.graphs[key].replay()
237+
output_ref = self.graph_outputs[key]
238238

239239
return output_ref
240240

0 commit comments

Comments
 (0)