Skip to content

Commit b7d2390

Browse files
author
zijiansh
committed
[Refactor] Extract the model forward logic into a separate method so plug-ins can override it
1 parent 36c260d commit b7d2390

File tree

1 file changed

+47
-18
lines changed

vllm/v1/worker/gpu_model_runner.py

Lines changed: 47 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
from collections.abc import Iterator
99
from contextlib import contextmanager
1010
from copy import deepcopy
11-
from typing import TYPE_CHECKING, Any, Optional, Union, cast
11+
from typing import Any, cast, Optional, Tuple, TYPE_CHECKING, Union
1212

1313
import numpy as np
1414
import torch
@@ -1452,6 +1452,39 @@ def _pool(
14521452
kv_connector_output=kv_connector_output,
14531453
)
14541454

1455+
def _forward(
    self,
    attn_metadata: dict[str, Any],
    num_input_tokens: int,
    num_tokens_across_dp: int,
    cudagraph_runtime_mode: CUDAGraphMode,
    batch_descriptor: BatchDescriptor,
    scheduler_output: "SchedulerOutput",
    input_ids: torch.Tensor,
    positions: torch.Tensor,
    intermediate_tensors: IntermediateTensors,
    inputs_embeds: list[torch.Tensor],
    model_kwargs: dict[str, Any],
) -> tuple[torch.Tensor, Optional[KVConnectorOutput]]:
    """Run one model forward pass inside the proper forward context.

    Extracted from ``execute_model`` so that plug-ins can override just
    the forward invocation while the runner keeps all surrounding
    bookkeeping (sampling, aux hidden states, etc.) unchanged.

    Args:
        attn_metadata: Per-layer attention metadata for this step.
        num_input_tokens: Token count used to size the forward context
            (includes any CUDA-graph padding chosen by the caller).
        num_tokens_across_dp: Token counts across data-parallel ranks,
            forwarded verbatim to ``set_forward_context``.
            NOTE(review): annotated ``int`` in the original; it is only
            passed through here, so the true type may be a tensor —
            confirm against ``set_forward_context``.
        cudagraph_runtime_mode: Which CUDA-graph mode to run under.
        batch_descriptor: Descriptor identifying the padded batch shape.
        scheduler_output: Scheduler output for this step; used only to
            open the KV-connector scope.
        input_ids: Flattened input token IDs.
        positions: Position ids/embedding positions for the tokens.
        intermediate_tensors: Pipeline-parallel intermediate tensors
            received from the previous stage (if any).
        inputs_embeds: Precomputed input embeddings, when used instead
            of ``input_ids``.
        model_kwargs: Extra keyword arguments forwarded to the model.

    Returns:
        A ``(model_output, kv_connector_output)`` tuple, where
        ``kv_connector_output`` is whatever the KV-connector scope
        yields (``None``-able when no connector is configured).
    """
    # Enter the forward context (attention metadata, DP token counts,
    # CUDA-graph mode) together with the optional KV-connector scope so
    # both are active for the duration of the model call.
    with set_forward_context(
        attn_metadata,
        self.vllm_config,
        num_tokens=num_input_tokens,
        num_tokens_across_dp=num_tokens_across_dp,
        cudagraph_runtime_mode=cudagraph_runtime_mode,
        batch_descriptor=batch_descriptor,
    ), self.maybe_get_kv_connector_output(
            scheduler_output) as kv_connector_output:
        return (
            self.model(
                input_ids=input_ids,
                positions=positions,
                intermediate_tensors=intermediate_tensors,
                inputs_embeds=inputs_embeds,
                **model_kwargs,
            ),
            kv_connector_output,
        )
1487+
14551488
@torch.inference_mode()
14561489
def execute_model(
14571490
self,
@@ -1557,23 +1590,19 @@ def execute_model(
15571590

15581591
# Run the model.
15591592
# Use persistent buffers for CUDA graphs.
1560-
with set_forward_context(
1561-
attn_metadata,
1562-
self.vllm_config,
1563-
num_tokens=num_input_tokens,
1564-
num_tokens_across_dp=num_tokens_across_dp,
1565-
cudagraph_runtime_mode=cudagraph_runtime_mode,
1566-
batch_descriptor=batch_descriptor,
1567-
), self.maybe_get_kv_connector_output(
1568-
scheduler_output) as kv_connector_output:
1569-
1570-
model_output = self.model(
1571-
input_ids=input_ids,
1572-
positions=positions,
1573-
intermediate_tensors=intermediate_tensors,
1574-
inputs_embeds=inputs_embeds,
1575-
**model_kwargs,
1576-
)
1593+
model_output, kv_connector_output = self._forward(
1594+
attn_metadata,
1595+
num_input_tokens=num_input_tokens,
1596+
num_tokens_across_dp=num_tokens_across_dp,
1597+
cudagraph_runtime_mode=cudagraph_runtime_mode,
1598+
batch_descriptor=batch_descriptor,
1599+
scheduler_output=scheduler_output,
1600+
input_ids=input_ids,
1601+
positions=positions,
1602+
intermediate_tensors=intermediate_tensors,
1603+
inputs_embeds=inputs_embeds,
1604+
model_kwargs=model_kwargs,
1605+
)
15771606

15781607
if self.use_aux_hidden_state_outputs:
15791608
hidden_states, aux_hidden_states = model_output

0 commit comments

Comments (0)