Skip to content
Closed
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
65 changes: 47 additions & 18 deletions vllm/v1/worker/gpu_model_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from collections.abc import Iterator
from contextlib import contextmanager
from copy import deepcopy
from typing import TYPE_CHECKING, Any, Optional, Union, cast
from typing import Any, cast, Optional, Tuple, TYPE_CHECKING, Union

Check failure on line 11 in vllm/v1/worker/gpu_model_runner.py

View workflow job for this annotation

GitHub Actions / pre-commit

Ruff (UP035)

vllm/v1/worker/gpu_model_runner.py:11:1: UP035 `typing.Tuple` is deprecated, use `tuple` instead

import numpy as np
import torch
Expand Down Expand Up @@ -1452,6 +1452,39 @@
kv_connector_output=kv_connector_output,
)

def _forward(
self,
attn_metadata: dict[str, Any],
num_input_tokens: int,
num_tokens_across_dp: int,
cudagraph_runtime_mode: CUDAGraphMode,
batch_descriptor: BatchDescriptor,
scheduler_output: "SchedulerOutput",
input_ids: torch.Tensor,
positions: torch.Tensor,
intermediate_tensors: IntermediateTensors,
inputs_embeds: list[torch.Tensor],
model_kwargs: dict[str, Any],
) -> Tuple[torch.Tensor, Optional[KVConnectorOutput]]:

Check failure on line 1468 in vllm/v1/worker/gpu_model_runner.py

View workflow job for this annotation

GitHub Actions / pre-commit

Ruff (UP006)

vllm/v1/worker/gpu_model_runner.py:1468:10: UP006 Use `tuple` instead of `Tuple` for type annotation
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

The return type hint for the model output is torch.Tensor, but the model can also return a tuple of tensors (e.g., when use_aux_hidden_state_outputs is true). This incorrect type hint can cause issues with static type checkers and mislead developers who might extend this class, especially since this method is designed to be overridden by plugins. Using Any will make the type hint correct for all possible return types from the model.

Suggested change
) -> Tuple[torch.Tensor, Optional[KVConnectorOutput]]:
) -> Tuple[Any, Optional[KVConnectorOutput]]:

with set_forward_context(
attn_metadata,
self.vllm_config,
num_tokens=num_input_tokens,
num_tokens_across_dp=num_tokens_across_dp,
cudagraph_runtime_mode=cudagraph_runtime_mode,
batch_descriptor=batch_descriptor,
), self.maybe_get_kv_connector_output(scheduler_output) as kv_connector_output:
return (
self.model(
input_ids=input_ids,
positions=positions,
intermediate_tensors=intermediate_tensors,
inputs_embeds=inputs_embeds,
**model_kwargs,
),
kv_connector_output,
)

@torch.inference_mode()
def execute_model(
self,
Expand Down Expand Up @@ -1557,23 +1590,19 @@

# Run the model.
# Use persistent buffers for CUDA graphs.
with set_forward_context(
attn_metadata,
self.vllm_config,
num_tokens=num_input_tokens,
num_tokens_across_dp=num_tokens_across_dp,
cudagraph_runtime_mode=cudagraph_runtime_mode,
batch_descriptor=batch_descriptor,
), self.maybe_get_kv_connector_output(
scheduler_output) as kv_connector_output:

model_output = self.model(
input_ids=input_ids,
positions=positions,
intermediate_tensors=intermediate_tensors,
inputs_embeds=inputs_embeds,
**model_kwargs,
)
model_output, kv_connector_output = self._forward(
attn_metadata,
num_input_tokens=num_input_tokens,
num_tokens_across_dp=num_tokens_across_dp,
cudagraph_runtime_mode=cudagraph_runtime_mode,

Check failure on line 1597 in vllm/v1/worker/gpu_model_runner.py

View workflow job for this annotation

GitHub Actions / pre-commit

Argument "num_tokens_across_dp" to "_forward" of "GPUModelRunner" has incompatible type "Any | None"; expected "int" [arg-type]

Check failure on line 1597 in vllm/v1/worker/gpu_model_runner.py

View workflow job for this annotation

GitHub Actions / pre-commit

Argument "num_tokens_across_dp" to "_forward" of "GPUModelRunner" has incompatible type "Any | None"; expected "int" [arg-type]

Check failure on line 1597 in vllm/v1/worker/gpu_model_runner.py

View workflow job for this annotation

GitHub Actions / pre-commit

Argument "num_tokens_across_dp" to "_forward" of "GPUModelRunner" has incompatible type "Any | None"; expected "int" [arg-type]

Check failure on line 1597 in vllm/v1/worker/gpu_model_runner.py

View workflow job for this annotation

GitHub Actions / pre-commit

Argument "num_tokens_across_dp" to "_forward" of "GPUModelRunner" has incompatible type "Any | None"; expected "int" [arg-type]

Check failure on line 1597 in vllm/v1/worker/gpu_model_runner.py

View workflow job for this annotation

GitHub Actions / pre-commit

Argument "num_tokens_across_dp" to "_forward" of "GPUModelRunner" has incompatible type "Optional[Any]"; expected "int" [arg-type]

Check failure on line 1597 in vllm/v1/worker/gpu_model_runner.py

View workflow job for this annotation

GitHub Actions / pre-commit

Argument "num_tokens_across_dp" to "_forward" of "GPUModelRunner" has incompatible type "Optional[Any]"; expected "int" [arg-type]

Check failure on line 1597 in vllm/v1/worker/gpu_model_runner.py

View workflow job for this annotation

GitHub Actions / pre-commit

Argument "num_tokens_across_dp" to "_forward" of "GPUModelRunner" has incompatible type "Optional[Any]"; expected "int" [arg-type]

Check failure on line 1597 in vllm/v1/worker/gpu_model_runner.py

View workflow job for this annotation

GitHub Actions / pre-commit

Argument "num_tokens_across_dp" to "_forward" of "GPUModelRunner" has incompatible type "Optional[Any]"; expected "int" [arg-type]
batch_descriptor=batch_descriptor,
scheduler_output=scheduler_output,
input_ids=input_ids,
positions=positions,
intermediate_tensors=intermediate_tensors,
inputs_embeds=inputs_embeds,
model_kwargs=model_kwargs,
)

if self.use_aux_hidden_state_outputs:
hidden_states, aux_hidden_states = model_output
Expand Down