Commit 0640458

fix

Signed-off-by: junq <[email protected]>
1 parent 2eb3030

2 files changed: +14 -6 lines changed

tensorrt_llm/_torch/pyexecutor/cuda_graph_model_engine.py

Lines changed: 12 additions & 6 deletions
@@ -66,6 +66,7 @@ def __init__(self, engine: "PyTorchModelEngine"):
         self.padding_enabled = config.cuda_graph_padding_enabled
         self.supported_batch_sizes = engine._cuda_graph_batch_sizes
         self.max_supported_batch_size = engine._max_cuda_graph_batch_size
+        self.max_num_tokens = engine.max_num_tokens

         # Low-level state, storing resources per batch size
         self.graphs: Dict[int, torch.cuda.CUDAGraph] = {}
@@ -106,8 +107,12 @@ def execute(self, batch: ScheduledRequests, inputs: Dict[str, Any],

         return self._run_graph(batch_size, inputs)

-    def _capture_graph(self, batch_size: int, forward_fn: Callable,
-                       initial_inputs: Dict[str, Any]):
+    def _capture_graph(self,
+                       batch_size: int,
+                       forward_fn: Callable,
+                       initial_inputs: Dict[str, Any],
+                       gather_ids: torch.Tensor,
+                       gather_context_logits: bool = False):
         """Captures the forward pass for a given batch size."""
         engine = self._get_engine()

@@ -117,13 +122,13 @@ def _capture_graph(self, batch_size: int, forward_fn: Callable,

         static_tensors = {
             "input_ids":
-            torch.ones((batch_size * max_tokens_per_req, ),
+            torch.ones((self.max_num_tokens, ),
                        device="cuda",
                        dtype=torch.int32),
             "position_ids":
             torch.zeros((
                 1,
-                batch_size * max_tokens_per_req,
+                self.max_num_tokens,
             ),
                         device="cuda",
                         dtype=torch.int32),
@@ -144,10 +149,11 @@ def _capture_graph(self, batch_size: int, forward_fn: Callable,
         graph = torch.cuda.CUDAGraph()
         with capturing_cuda_graph_context():
             for _ in range(self.WARMUP_STEPS):
-                forward_fn(capture_inputs)
+                forward_fn(capture_inputs, gather_ids, gather_context_logits)

         with torch.cuda.graph(graph, pool=self.memory_pool):
-            output = forward_fn(capture_inputs)
+            output = forward_fn(capture_inputs, gather_ids,
+                                gather_context_logits)

         self.graphs[batch_size] = graph
         self.graph_outputs[batch_size] = make_weak_ref(output)
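
For context, the capture path above follows the standard two-phase PyTorch CUDA Graphs recipe: run a few eager warmup iterations so one-time allocations happen outside the capture, then record a single forward pass into a torch.cuda.CUDAGraph that can share a memory pool across batch sizes. A minimal sketch of that pattern (capture_graph, static_inputs, and WARMUP_STEPS are illustrative names, not this repository's API):

import torch

WARMUP_STEPS = 3  # a few eager runs before capture

def capture_graph(forward_fn, static_inputs, memory_pool=None):
    # Warm up on a side stream so lazy initialization (workspaces,
    # autotuned kernels) is not recorded into the graph.
    side_stream = torch.cuda.Stream()
    side_stream.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(side_stream):
        for _ in range(WARMUP_STEPS):
            forward_fn(static_inputs)
    torch.cuda.current_stream().wait_stream(side_stream)

    # Record one forward pass; replays reuse these exact buffers, so
    # later inputs must be copied into static_inputs in place.
    graph = torch.cuda.CUDAGraph()
    with torch.cuda.graph(graph, pool=memory_pool):
        output = forward_fn(static_inputs)
    return graph, output

Sizing the static input_ids/position_ids buffers to self.max_num_tokens rather than batch_size * max_tokens_per_req appears to bound every captured graph by the engine-wide token limit, so one buffer size covers any padded batch.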

tensorrt_llm/_torch/pyexecutor/model_engine.py

Lines changed: 2 additions & 0 deletions
@@ -2018,6 +2018,8 @@ def forward(
             graph_output = self.cuda_graph_model_engine.execute(
                 batch=padded_requests,
                 inputs=inputs,
+                gather_ids=gather_ids,
+                gather_context_logits=gather_context_logits,
                 forward_fn=self._forward_step)

             if graph_output is not None:
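
On the replay side (not shown in this diff), the usual pattern behind a helper like _run_graph is to copy live inputs into the buffers the graph was captured with and then replay the recorded kernels. A hypothetical sketch, assuming the static_tensors layout captured above:

def run_graph(graph, static_tensors, inputs):
    # Copy live inputs into the captured buffers in place; replay()
    # re-executes the recorded kernels against those same buffers.
    num_tokens = inputs["input_ids"].numel()
    static_tensors["input_ids"][:num_tokens].copy_(inputs["input_ids"])
    static_tensors["position_ids"][:, :num_tokens].copy_(inputs["position_ids"])
    graph.replay()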
