2 changes: 2 additions & 0 deletions vllm/outputs.py
@@ -119,6 +119,7 @@ def __init__(
*,
multi_modal_placeholders: Optional[MultiModalPlaceholderDict] = None,
kv_transfer_params: Optional[dict[str, Any]] = None,
prompt_hidden_states: Optional[torch.Tensor] = None,
# Forward compatibility, code that uses args added in new release can
# still run with older versions of vLLM without breaking.
**kwargs: Any,
@@ -139,6 +140,7 @@ def __init__(
self.encoder_prompt_token_ids = encoder_prompt_token_ids
self.num_cached_tokens = num_cached_tokens
self.kv_transfer_params = kv_transfer_params
self.prompt_hidden_states = prompt_hidden_states

def add(self, next_output: "RequestOutput", aggregate: bool) -> None:
"""Merge subsequent RequestOutput into this one"""
2 changes: 2 additions & 0 deletions vllm/sampling_params.py
@@ -166,6 +166,8 @@ class SamplingParams(
response. When set to -1, return all `vocab_size` log probabilities."""
prompt_logprobs: Optional[int] = None
"""Number of log probabilities to return per prompt token."""
return_prompt_hidden_states: bool = False
"""Whether to return the hidden states of the prompt tokens."""

# NOTE: This parameter is only exposed at the engine level for now.
# It is not exposed in the OpenAI API server, as the OpenAI API does
# not support returning only a list of token IDs.
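A minimal usage sketch of the new `return_prompt_hidden_states` flag, assuming the engine-level plumbing in this PR is reachable from the offline `LLM` API; the model name and the hidden-state shape are illustrative assumptions, not something this diff confirms:

```python
# Hypothetical usage of SamplingParams.return_prompt_hidden_states.
# Assumes the flag propagates through LLM.generate(); not confirmed by this diff.
from vllm import LLM, SamplingParams

llm = LLM(model="facebook/opt-125m")  # any model; chosen only for illustration
params = SamplingParams(max_tokens=8, return_prompt_hidden_states=True)

for output in llm.generate(["The capital of France is"], params):
    hs = output.prompt_hidden_states  # new field on RequestOutput
    # Expected to be a tensor covering the prompt tokens when the flag is set
    # (roughly [num_prompt_tokens, hidden_size]), otherwise None.
    print(None if hs is None else tuple(hs.shape))
```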
3 changes: 3 additions & 0 deletions vllm/v1/core/sched/scheduler.py
@@ -845,6 +845,7 @@
sampled_token_ids = model_runner_output.sampled_token_ids
logprobs = model_runner_output.logprobs
prompt_logprobs_dict = model_runner_output.prompt_logprobs_dict
prompt_hidden_states_dict = (
model_runner_output.prompt_hidden_states_dict)

num_scheduled_tokens = scheduler_output.num_scheduled_tokens
pooler_outputs = model_runner_output.pooler_output
num_nans_in_logits = model_runner_output.num_nans_in_logits
@@ -932,6 +933,7 @@

# Get prompt logprobs for this request.
prompt_logprobs_tensors = prompt_logprobs_dict.get(req_id)
prompt_hidden_states = prompt_hidden_states_dict.get(req_id)
if new_token_ids or pooler_output is not None \
or kv_transfer_params:

@@ -943,6 +945,7 @@
finish_reason=request.get_finished_reason(),
new_logprobs=new_logprobs,
new_prompt_logprobs_tensors=prompt_logprobs_tensors,
prompt_hidden_states=prompt_hidden_states,
pooling_output=pooler_output,
stop_reason=request.stop_reason,
events=request.take_events(),
1 change: 1 addition & 0 deletions vllm/v1/engine/__init__.py
@@ -104,6 +104,7 @@ class EngineCoreOutput(
new_logprobs: Optional[LogprobsLists] = None
new_prompt_logprobs_tensors: Optional[LogprobsTensors] = None

prompt_hidden_states: Optional[torch.Tensor] = None
pooling_output: Optional[torch.Tensor] = None

finish_reason: Optional[FinishReason] = None
1 change: 1 addition & 0 deletions vllm/v1/engine/core.py
@@ -291,6 +291,7 @@ def step(self) -> tuple[dict[int, EngineCoreOutputs], bool]:
model_output = self.execute_model_with_error_logging(
self.model_executor.execute_model, # type: ignore
scheduler_output)
print("lxy model_output to enginecoreoutput")
engine_core_outputs = self.scheduler.update_from_output(
scheduler_output, model_output) # type: ignore

66 changes: 66 additions & 0 deletions vllm/v1/engine/hidden_states.py
@@ -0,0 +1,66 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import itertools
from dataclasses import dataclass
from typing import Optional

import torch

from vllm.logger import init_logger
from vllm.v1.engine import EngineCoreOutput

logger = init_logger(__name__)

NONES = itertools.repeat(None)


@dataclass
class HiddenStatesProcessor:
prompt_hidden_states: Optional[torch.Tensor]

@classmethod
def from_new_request(cls) -> "HiddenStatesProcessor":
return cls(prompt_hidden_states=None)

def _set_prompt_hidden_states(
self,
prompt_hidden_states_tensor: torch.Tensor,
) -> None:
"""Update with prompt logprobs from EngineCore.

Args:
prompt_logprobs_tensors: tuple containing the prompt logprobs
tensors.

"""

# We only need to set the prompt hidden states once.
# TODO: check logprobs
assert self.prompt_hidden_states is None

self.prompt_hidden_states = prompt_hidden_states_tensor

def pop_prompt_hidden_states(self) -> Optional[torch.Tensor]:
"""Pop and return the request's prompt hidden states.

The hidden states processor stores the prompt hidden states once
they arrive from EngineCore. This method returns them and then
forgets them, ensuring correct RequestOutputKind.DELTA semantics
wherein the prompt hidden states are returned once, at the end
of prefill.

Returns:
None if prompt hidden states are disabled for this request.
The prompt hidden states tensor, otherwise.
"""
hidden_states = self.prompt_hidden_states
if hidden_states is not None:
self.prompt_hidden_states = None
return hidden_states

def update_from_output(self, output: EngineCoreOutput) -> None:
if output.prompt_hidden_states is not None:
print("lxy update_from_output")
self._set_prompt_hidden_states(output.prompt_hidden_states)
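A small illustration (not part of the diff) of the set-then-pop lifecycle this processor implements; the tensor shape is an assumption chosen for the example:

```python
# Hypothetical lifecycle of HiddenStatesProcessor (illustration only).
import torch

from vllm.v1.engine.hidden_states import HiddenStatesProcessor

processor = HiddenStatesProcessor.from_new_request()

# update_from_output() stores the tensor once via _set_prompt_hidden_states();
# the private setter is called directly here just to keep the example small.
prompt_hidden_states = torch.zeros(4, 8)  # assumed [prompt_len, hidden_size]
processor._set_prompt_hidden_states(prompt_hidden_states)

# Under RequestOutputKind.DELTA the tensor is handed out exactly once ...
assert processor.pop_prompt_hidden_states() is prompt_hidden_states
# ... and later pops return None.
assert processor.pop_prompt_hidden_states() is None
```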
1 change: 1 addition & 0 deletions vllm/v1/engine/llm_engine.py
@@ -242,6 +242,7 @@ def step(self) -> Union[list[RequestOutput], list[PoolingRequestOutput]]:

# 2) Process EngineCoreOutputs.
iteration_stats = IterationStats() if self.log_stats else None
print("lxy call process_outputs")
processed_outputs = self.output_processor.process_outputs(
outputs.outputs,
engine_core_timestamp=outputs.timestamp,
19 changes: 18 additions & 1 deletion vllm/v1/engine/output_processor.py
@@ -15,6 +15,7 @@
from vllm.transformers_utils.tokenizer_group import TokenizerGroup
from vllm.v1.engine import EngineCoreOutput, EngineCoreRequest, FinishReason
from vllm.v1.engine.detokenizer import IncrementalDetokenizer
from vllm.v1.engine.hidden_states import HiddenStatesProcessor
from vllm.v1.engine.logprobs import LogprobsProcessor
from vllm.v1.engine.parallel_sampling import ParentRequest
from vllm.v1.metrics.stats import (IterationStats, LoRARequestStates,
@@ -93,6 +94,7 @@
arrival_time: float,
queue: Optional[RequestOutputCollector],
log_stats: bool,
hidden_states_processor: Optional[HiddenStatesProcessor],
):
self.request_id = request_id
self.parent_req = parent_req
@@ -111,6 +113,7 @@

self.stats = RequestStateStats(
arrival_time=arrival_time) if log_stats else None
self.hidden_states_processor = hidden_states_processor

@classmethod
def from_new_request(
@@ -137,10 +140,12 @@
request=request,
)
max_tokens_param = sampling_params.max_tokens
hidden_states_processor = HiddenStatesProcessor.from_new_request()
else:
logprobs_processor = None
detokenizer = None
max_tokens_param = None
hidden_states_processor = None
assert request.pooling_params is not None
output_kind = request.pooling_params.output_kind

@@ -159,6 +164,7 @@
arrival_time=request.arrival_time,
queue=queue,
log_stats=log_stats,
hidden_states_processor=hidden_states_processor,
)

def make_request_output(
@@ -204,7 +210,7 @@
finished: bool,
kv_transfer_params: Optional[dict[str, Any]] = None,
) -> Union[RequestOutput, PoolingRequestOutput]:

# Build the final RequestOutput or PoolingRequestOutput here.
first_output = outputs[0]
if isinstance(first_output, PoolingOutput):
assert len(outputs) == 1
@@ -215,17 +221,23 @@
finished=finished,
)
assert self.logprobs_processor is not None
assert self.hidden_states_processor is not None
if self.output_kind == RequestOutputKind.DELTA:
# Side effect: logprobs processor forgets prompt logprobs
prompt_logprobs = self.logprobs_processor.pop_prompt_logprobs()
prompt_hidden_states = (
self.hidden_states_processor.pop_prompt_hidden_states())
else:
prompt_logprobs = self.logprobs_processor.prompt_logprobs
prompt_hidden_states = (
self.hidden_states_processor.prompt_hidden_states)

# Prompt logprobs and prompt hidden states are attached here.
return RequestOutput(
request_id=request_id,
prompt=self.prompt,
prompt_token_ids=self.prompt_token_ids,
prompt_logprobs=prompt_logprobs,
prompt_hidden_states=prompt_hidden_states,
outputs=cast(list[CompletionOutput], outputs),
finished=finished,
kv_transfer_params=kv_transfer_params,
@@ -399,6 +411,7 @@
kv_transfer_params = engine_core_output.kv_transfer_params
req_state.num_cached_tokens = engine_core_output.num_cached_tokens
req_state.is_prefilling = False
prompt_hidden_states = engine_core_output.prompt_hidden_states

if pooling_output is None:
assert req_state.detokenizer is not None
@@ -414,8 +427,12 @@
# if required.
req_state.logprobs_processor.update_from_output(
engine_core_output)
assert req_state.hidden_states_processor is not None
req_state.hidden_states_processor.update_from_output(
engine_core_output)

# 4) Create and handle RequestOutput objects.
print("lxy here make_request_output", prompt_hidden_states is None)
if request_output := req_state.make_request_output(
new_token_ids, pooling_output, finish_reason, stop_reason,
kv_transfer_params):
4 changes: 4 additions & 0 deletions vllm/v1/outputs.py
@@ -105,6 +105,9 @@ class ModelRunnerOutput:
# [prompt_len]
prompt_logprobs_dict: dict[str, Optional[LogprobsTensors]]

# req_id -> prompt hidden states tensor
prompt_hidden_states_dict: dict[str, Optional[torch.Tensor]]

# [num_reqs, hidden_size]
pooler_output: list[Optional[torch.Tensor]]

@@ -128,5 +131,6 @@ class DraftTokenIds:
sampled_token_ids=[],
logprobs=None,
prompt_logprobs_dict={},
prompt_hidden_states_dict={},
pooler_output=[],
num_nans_in_logits=None)
8 changes: 7 additions & 1 deletion vllm/v1/worker/gpu_input_batch.py
@@ -217,10 +217,12 @@ def __init__(
# NOTE(rob): num_prompt_logprobs only includes reqs
# that are currently in the prefill phase.
self.num_prompt_logprobs: dict[str, int] = {}

# To accumulate prompt logprobs tensor chunks across prefill steps.
self.in_progress_prompt_logprobs_cpu: dict[str, LogprobsTensors] = {}

self.return_prompt_hidden_states_reqs: set[str] = set()
self.in_progress_prompt_hidden_states_cpu: dict[str, torch.Tensor] = {}

# Internal representation of per-step batch state changes, used for
# reordering persistent batch and generating logitsprocs batch state
# updates. Should reset each step.
@@ -358,6 +360,9 @@ def add_request(
self.num_prompt_logprobs[
req_id] = sampling_params.prompt_logprobs

if sampling_params.return_prompt_hidden_states:
self.return_prompt_hidden_states_reqs.add(req_id)

if sampling_params.allowed_token_ids:
self.has_allowed_token_ids.add(req_id)
if self.allowed_token_ids_mask_cpu_tensor is None:
@@ -447,6 +452,7 @@ def remove_request(self, req_id: str) -> Optional[int]:
self.num_logprobs.pop(req_id, None)
self.num_prompt_logprobs.pop(req_id, None)
self.in_progress_prompt_logprobs_cpu.pop(req_id, None)
self.in_progress_prompt_hidden_states_cpu.pop(req_id, None)

self.has_allowed_token_ids.discard(req_id)
if self.allowed_token_ids_mask_cpu_tensor is not None: