diff --git a/vllm/outputs.py b/vllm/outputs.py
index 64bcfd472f2a..97b8a8b8a92e 100644
--- a/vllm/outputs.py
+++ b/vllm/outputs.py
@@ -119,6 +119,7 @@ def __init__(
         *,
         multi_modal_placeholders: Optional[MultiModalPlaceholderDict] = None,
         kv_transfer_params: Optional[dict[str, Any]] = None,
+        prompt_hidden_states: Optional[torch.Tensor] = None,
         # Forward compatibility, code that uses args added in new release can
         # still run with older versions of vLLM without breaking.
         **kwargs: Any,
@@ -139,6 +140,7 @@ def __init__(
         self.encoder_prompt_token_ids = encoder_prompt_token_ids
         self.num_cached_tokens = num_cached_tokens
         self.kv_transfer_params = kv_transfer_params
+        self.prompt_hidden_states = prompt_hidden_states
 
     def add(self, next_output: "RequestOutput", aggregate: bool) -> None:
         """Merge subsequent RequestOutput into this one"""
diff --git a/vllm/sampling_params.py b/vllm/sampling_params.py
index c7b4ba34c602..9a21d8cdbeea 100644
--- a/vllm/sampling_params.py
+++ b/vllm/sampling_params.py
@@ -166,6 +166,8 @@ class SamplingParams(
     response. When set to -1, return all `vocab_size` log probabilities."""
     prompt_logprobs: Optional[int] = None
     """Number of log probabilities to return per prompt token."""
+    return_prompt_hidden_states: bool = False
+    """Whether to return the hidden states of the prompt tokens."""
     # NOTE: This parameter is only exposed at the engine level for now.
     # It is not exposed in the OpenAI API server, as the OpenAI API does
     # not support returning only a list of token IDs.
diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py
index 8322fa7335b6..25482be7747f 100644
--- a/vllm/v1/core/sched/scheduler.py
+++ b/vllm/v1/core/sched/scheduler.py
@@ -845,6 +845,7 @@ def update_from_output(
         sampled_token_ids = model_runner_output.sampled_token_ids
         logprobs = model_runner_output.logprobs
         prompt_logprobs_dict = model_runner_output.prompt_logprobs_dict
+        prompt_hidden_states_dict = model_runner_output.prompt_hidden_states_dict
         num_scheduled_tokens = scheduler_output.num_scheduled_tokens
         pooler_outputs = model_runner_output.pooler_output
         num_nans_in_logits = model_runner_output.num_nans_in_logits
@@ -932,6 +933,7 @@ def update_from_output(
 
             # Get prompt logprobs for this request.
             prompt_logprobs_tensors = prompt_logprobs_dict.get(req_id)
+            prompt_hidden_states = prompt_hidden_states_dict.get(req_id)
 
             if new_token_ids or pooler_output is not None \
                 or kv_transfer_params:
@@ -943,6 +945,7 @@ def update_from_output(
                         finish_reason=request.get_finished_reason(),
                         new_logprobs=new_logprobs,
                         new_prompt_logprobs_tensors=prompt_logprobs_tensors,
+                        prompt_hidden_states=prompt_hidden_states,
                         pooling_output=pooler_output,
                         stop_reason=request.stop_reason,
                         events=request.take_events(),
diff --git a/vllm/v1/engine/__init__.py b/vllm/v1/engine/__init__.py
index 5d8959a3cd3f..238e3f955734 100644
--- a/vllm/v1/engine/__init__.py
+++ b/vllm/v1/engine/__init__.py
@@ -104,6 +104,7 @@ class EngineCoreOutput(
 
     new_logprobs: Optional[LogprobsLists] = None
     new_prompt_logprobs_tensors: Optional[LogprobsTensors] = None
+    prompt_hidden_states: Optional[torch.Tensor] = None
     pooling_output: Optional[torch.Tensor] = None
 
     finish_reason: Optional[FinishReason] = None
diff --git a/vllm/v1/engine/hidden_states.py b/vllm/v1/engine/hidden_states.py
new file mode 100644
index 000000000000..7b0e15604e61
--- /dev/null
+++ b/vllm/v1/engine/hidden_states.py
@@ -0,0 +1,52 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+from dataclasses import dataclass
+from typing import Optional
+
+import torch
+
+from vllm.logger import init_logger
+from vllm.v1.engine import EngineCoreOutput
+
+logger = init_logger(__name__)
+
+
+@dataclass
+class HiddenStatesProcessor:
+    prompt_hidden_states: Optional[torch.Tensor]
+
+    @classmethod
+    def from_new_request(cls) -> "HiddenStatesProcessor":
+        return cls(prompt_hidden_states=None)
+
+    def _set_prompt_hidden_states(
+        self,
+        prompt_hidden_states_tensor: torch.Tensor,
+    ) -> None:
+        # We only need to set the prompt hidden states once.
+        assert self.prompt_hidden_states is None
+
+        self.prompt_hidden_states = prompt_hidden_states_tensor
+
+    def pop_prompt_hidden_states(self) -> Optional[torch.Tensor]:
+        """Pop and return all request prompt hidden states.
+
+        The hidden states processor aggregates prompt chunk hidden states
+        over one or more prefill chunks. This method returns
+        all prompt hidden states at once and then forgets them.
+        Ensures correct RequestOutputKind.DELTA semantics
+        wherein all prompt hidden states are returned at once at
+        the end of prefill.
+
+        Returns:
+            None if prompt hidden states are disabled for this request.
+            Tensor of all prompt hidden states, otherwise.
+ """ + plp = self.prompt_hidden_states + if plp: + self.prompt_hidden_states = None + return plp + + def update_from_output(self, output: EngineCoreOutput) -> None: + if output.prompt_hidden_states is not None: + self._set_prompt_hidden_states(output.prompt_hidden_states) diff --git a/vllm/v1/engine/llm_engine.py b/vllm/v1/engine/llm_engine.py index 7130f666ef19..b3da15229871 100644 --- a/vllm/v1/engine/llm_engine.py +++ b/vllm/v1/engine/llm_engine.py @@ -242,6 +242,7 @@ def step(self) -> Union[list[RequestOutput], list[PoolingRequestOutput]]: # 2) Process EngineCoreOutputs. iteration_stats = IterationStats() if self.log_stats else None + print("lxy call process_outputs") processed_outputs = self.output_processor.process_outputs( outputs.outputs, engine_core_timestamp=outputs.timestamp, diff --git a/vllm/v1/engine/output_processor.py b/vllm/v1/engine/output_processor.py index 2ee55b585da6..e5ac028859d0 100644 --- a/vllm/v1/engine/output_processor.py +++ b/vllm/v1/engine/output_processor.py @@ -15,6 +15,7 @@ from vllm.transformers_utils.tokenizer_group import TokenizerGroup from vllm.v1.engine import EngineCoreOutput, EngineCoreRequest, FinishReason from vllm.v1.engine.detokenizer import IncrementalDetokenizer +from vllm.v1.engine.hidden_states import HiddenStatesProcessor from vllm.v1.engine.logprobs import LogprobsProcessor from vllm.v1.engine.parallel_sampling import ParentRequest from vllm.v1.metrics.stats import (IterationStats, LoRARequestStates, @@ -93,6 +94,7 @@ def __init__( arrival_time: float, queue: Optional[RequestOutputCollector], log_stats: bool, + hidden_states_processor: Optional[HiddenStatesProcessor], ): self.request_id = request_id self.parent_req = parent_req @@ -111,6 +113,7 @@ def __init__( self.stats = RequestStateStats( arrival_time=arrival_time) if log_stats else None + self.hidden_states_processor = hidden_states_processor @classmethod def from_new_request( @@ -137,10 +140,12 @@ def from_new_request( request=request, ) max_tokens_param = sampling_params.max_tokens + hidden_states_processor = HiddenStatesProcessor.from_new_request() else: logprobs_processor = None detokenizer = None max_tokens_param = None + hidden_states_processor = None assert request.pooling_params is not None output_kind = request.pooling_params.output_kind @@ -159,6 +164,7 @@ def from_new_request( arrival_time=request.arrival_time, queue=queue, log_stats=log_stats, + hidden_states_processor=hidden_states_processor, ) def make_request_output( @@ -204,7 +210,7 @@ def _new_request_output( finished: bool, kv_transfer_params: Optional[dict[str, Any]] = None, ) -> Union[RequestOutput, PoolingRequestOutput]: - + # Seeems here to process outputs first_output = outputs[0] if isinstance(first_output, PoolingOutput): assert len(outputs) == 1 @@ -215,17 +221,23 @@ def _new_request_output( finished=finished, ) assert self.logprobs_processor is not None + assert self.hidden_states_processor is not None if self.output_kind == RequestOutputKind.DELTA: # Side effect: logprobs processor forgets prompt logprobs prompt_logprobs = self.logprobs_processor.pop_prompt_logprobs() + prompt_hidden_states = self.hidden_states_processor.pop_prompt_hidden_states( + ) else: prompt_logprobs = self.logprobs_processor.prompt_logprobs + prompt_hidden_states = self.hidden_states_processor.prompt_hidden_states + # prompt logprobs is added here return RequestOutput( request_id=request_id, prompt=self.prompt, prompt_token_ids=self.prompt_token_ids, prompt_logprobs=prompt_logprobs, + prompt_hidden_states=prompt_hidden_states, 
             outputs=cast(list[CompletionOutput], outputs),
             finished=finished,
             kv_transfer_params=kv_transfer_params,
@@ -414,8 +426,11 @@ def process_outputs(
                 # if required.
                 req_state.logprobs_processor.update_from_output(
                     engine_core_output)
+                assert req_state.hidden_states_processor is not None
+                req_state.hidden_states_processor.update_from_output(
+                    engine_core_output)
 
             # 4) Create and handle RequestOutput objects.
             if request_output := req_state.make_request_output(
                     new_token_ids, pooling_output, finish_reason, stop_reason,
                     kv_transfer_params):
diff --git a/vllm/v1/outputs.py b/vllm/v1/outputs.py
index f8d6b24702f3..0447d0d14ecc 100644
--- a/vllm/v1/outputs.py
+++ b/vllm/v1/outputs.py
@@ -105,6 +105,9 @@ class ModelRunnerOutput:
     # [prompt_len]
     prompt_logprobs_dict: dict[str, Optional[LogprobsTensors]]
 
+    # req_id -> [num_prompt_tokens - 1, hidden_size]
+    prompt_hidden_states_dict: dict[str, Optional[torch.Tensor]]
+
     # [num_reqs, hidden_size]
     pooler_output: list[Optional[torch.Tensor]]
 
@@ -128,5 +131,6 @@ class DraftTokenIds:
     sampled_token_ids=[],
     logprobs=None,
     prompt_logprobs_dict={},
+    prompt_hidden_states_dict={},
     pooler_output=[],
     num_nans_in_logits=None)
diff --git a/vllm/v1/worker/gpu_input_batch.py b/vllm/v1/worker/gpu_input_batch.py
index ad70d9efaaaa..fabd2afc331d 100644
--- a/vllm/v1/worker/gpu_input_batch.py
+++ b/vllm/v1/worker/gpu_input_batch.py
@@ -217,10 +217,13 @@ def __init__(
         # NOTE(rob): num_prompt_logprobs only includes reqs
         # that are currently in the prefill phase.
         self.num_prompt_logprobs: dict[str, int] = {}
 
         # To accumulate prompt logprobs tensor chunks across prefill steps.
         self.in_progress_prompt_logprobs_cpu: dict[str, LogprobsTensors] = {}
 
+        self.return_prompt_hidden_states_reqs: set[str] = set()
+        self.in_progress_prompt_hidden_states_cpu: dict[str, torch.Tensor] = {}
+
         # Internal representation of per-step batch state changes, used for
         # reordering persistent batch and generating logitsprocs batch state
         # updates. Should reset each step.
@@ -358,6 +361,9 @@ def add_request(
             self.num_prompt_logprobs[
                 req_id] = sampling_params.prompt_logprobs
 
+        if sampling_params.return_prompt_hidden_states:
+            self.return_prompt_hidden_states_reqs.add(req_id)
+
         if sampling_params.allowed_token_ids:
             self.has_allowed_token_ids.add(req_id)
             if self.allowed_token_ids_mask_cpu_tensor is None:
@@ -447,6 +453,8 @@ def remove_request(self, req_id: str) -> Optional[int]:
         self.num_logprobs.pop(req_id, None)
         self.num_prompt_logprobs.pop(req_id, None)
         self.in_progress_prompt_logprobs_cpu.pop(req_id, None)
+        self.in_progress_prompt_hidden_states_cpu.pop(req_id, None)
+        self.return_prompt_hidden_states_reqs.discard(req_id)
 
         self.has_allowed_token_ids.discard(req_id)
         if self.allowed_token_ids_mask_cpu_tensor is not None:
diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py
index 4556a51b809d..eac68b29370c 100644
--- a/vllm/v1/worker/gpu_model_runner.py
+++ b/vllm/v1/worker/gpu_model_runner.py
@@ -1448,6 +1448,7 @@ def _pool(
             sampled_token_ids=[],
             logprobs=None,
             prompt_logprobs_dict={},
+            prompt_hidden_states_dict={},
             pooler_output=pooler_output,
             kv_connector_output=kv_connector_output,
         )
@@ -1683,6 +1684,10 @@ def execute_model(
             hidden_states[:num_scheduled_tokens],
             scheduler_output.num_scheduled_tokens,
         )
+        prompt_hidden_states_dict = self._get_prompt_hidden_states_dict(
+            hidden_states[:num_scheduled_tokens],
+            scheduler_output.num_scheduled_tokens,
+        )
 
         # Get the valid generated tokens.
         sampled_token_ids = sampler_output.sampled_token_ids
@@ -1746,6 +1751,7 @@ def execute_model(
             sampled_token_ids=valid_sampled_token_ids,
             logprobs=logprobs_lists,
             prompt_logprobs_dict=prompt_logprobs_dict,
+            prompt_hidden_states_dict=prompt_hidden_states_dict,
             pooler_output=[],
             kv_connector_output=kv_connector_output,
             num_nans_in_logits=num_nans_in_logits,
@@ -2123,6 +2129,90 @@ def _get_prompt_logprobs_dict(
 
         return prompt_logprobs_dict
 
+    def _get_prompt_hidden_states_dict(
+        self,
+        hidden_states: torch.Tensor,
+        num_scheduled_tokens: dict[str, int],
+    ) -> dict[str, Optional[torch.Tensor]]:
+        """Like _get_prompt_logprobs_dict, but for prompt hidden states."""
+
+        return_prompt_hidden_states_reqs = (
+            self.input_batch.return_prompt_hidden_states_reqs)
+        if not return_prompt_hidden_states_reqs:
+            return {}
+
+        in_progress_dict = (
+            self.input_batch.in_progress_prompt_hidden_states_cpu)
+        prompt_hidden_states_dict: dict[str, Optional[torch.Tensor]] = {}
+
+        # Since prompt hidden states are a rare feature, prioritize a simple,
+        # maintainable loop over optimal performance.
+        completed_prefill_reqs = []
+        for req_id in return_prompt_hidden_states_reqs:
+            num_tokens = num_scheduled_tokens[req_id]
+
+            # Get metadata for this request.
+            request = self.requests[req_id]
+            num_prompt_tokens = len(request.prompt_token_ids)
+
+            # Set up the target CPU tensor for this request.
+            hidden_states_tensors = in_progress_dict.get(req_id)
+            if hidden_states_tensors is None:
+                # Create an empty CPU tensor for the entire prompt.
+                # If chunked, we'll copy in slice by slice.
+                hidden_states_tensors = torch.empty(
+                    (num_prompt_tokens - 1, self.hidden_size),
+                    dtype=hidden_states.dtype,
+                    device="cpu")
+                in_progress_dict[req_id] = hidden_states_tensors
+
+            # Determine the number of hidden states to retrieve.
+            start_idx = request.num_computed_tokens
+            start_tok = start_idx + 1
+            num_remaining_tokens = num_prompt_tokens - start_tok
+            if num_tokens <= num_remaining_tokens:
+                # This is a chunk, more tokens remain.
+                # In the == case, there are no more prompt hidden states to
+                # produce, but we want to defer returning them to the next
+                # step where we have new generated tokens to return.
+                num_states = num_tokens
+            else:
+                # This is the last chunk of prompt tokens to return.
+                num_states = num_remaining_tokens
+                completed_prefill_reqs.append(req_id)
+                prompt_hidden_states_dict[req_id] = hidden_states_tensors
+
+            if num_states <= 0:
+                # This can happen for the final chunk if we prefilled exactly
+                # (num_prompt_tokens - 1) tokens for this request in the prior
+                # step. There are no more prompt hidden states to produce.
+                continue
+
+            # Get the hidden states corresponding to this req's prompt tokens.
+            # If this is a partial request (i.e. chunked prefill), then prompt
+            # hidden states are generated for each scheduled index.
+            req_idx = self.input_batch.req_id_to_index[req_id]
+            offset = self.query_start_loc.np[req_idx].item()
+            prompt_hidden_states = hidden_states[offset:offset + num_states]
+
+            # Transfer GPU->CPU async.
+            chunk_slice = slice(start_idx, start_idx + num_states)
+            hidden_states_tensors[chunk_slice].copy_(prompt_hidden_states,
+                                                     non_blocking=True)
+
+        # Remove requests that have completed prefill from the tracking set
+        # and the in-progress dict.
+        for req_id in completed_prefill_reqs:
+            return_prompt_hidden_states_reqs.remove(req_id)
+            del in_progress_dict[req_id]
+
+        # Must synchronize the non-blocking GPU->CPU transfers.
+        if prompt_hidden_states_dict:
+            self._sync_device()
+
+        # Only requests that completed prefill in this step have an entry.
+        return prompt_hidden_states_dict
+
     def _get_nans_in_logits(
         self,
         logits: Optional[torch.Tensor],
diff --git a/vllm/v1/worker/tpu_model_runner.py b/vllm/v1/worker/tpu_model_runner.py
index 985d5ba58c49..c58ec588b1ee 100644
--- a/vllm/v1/worker/tpu_model_runner.py
+++ b/vllm/v1/worker/tpu_model_runner.py
@@ -1072,8 +1072,10 @@ def concat_lists(input_lists):
         req_ids = cast(list[str], self.input_batch.req_ids[:num_reqs])
 
         prompt_logprobs_dict: dict[str, Optional[LogprobsTensors]] = {}
+        prompt_hidden_states_dict: dict[str, Optional[torch.Tensor]] = {}
         for req_id in self.input_batch.req_ids[:num_reqs]:
             prompt_logprobs_dict[req_id] = None
+            prompt_hidden_states_dict[req_id] = None
 
         max_gen_len = selected_token_ids.shape[-1]
         if max_gen_len == 1:
@@ -1119,6 +1121,7 @@ def concat_lists(input_lists):
             sampled_token_ids=valid_sampled_token_ids,
             logprobs=logprobs_lists,
             prompt_logprobs_dict=prompt_logprobs_dict,
+            prompt_hidden_states_dict=prompt_hidden_states_dict,
             pooler_output=[],
             kv_connector_output=kv_connector_output,
         )
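Example usage (an illustrative sketch, not a tested example): `return_prompt_hidden_states` is only wired at the engine level in this change, so the snippet below assumes the offline `LLM` entrypoint forwards `SamplingParams` through unchanged. The model name and prompt are placeholders, and the expected output shape follows the GPU runner implementation above.

```python
# Illustrative only: exercises the new engine-level flag added in this diff.
# `return_prompt_hidden_states` and `RequestOutput.prompt_hidden_states` come
# from this change; everything else is the existing offline inference API.
from vllm import LLM, SamplingParams

llm = LLM(model="facebook/opt-125m")  # any model; the name is just an example

params = SamplingParams(
    max_tokens=8,
    return_prompt_hidden_states=True,  # new flag from this diff
)

outputs = llm.generate(["The capital of France is"], params)

for output in outputs:
    hs = output.prompt_hidden_states
    # Expected per the GPU runner implementation:
    # shape (num_prompt_tokens - 1, hidden_size), on CPU, in the model dtype.
    if hs is not None:
        print(hs.shape, hs.dtype, hs.device)
```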