diff --git a/vllm_gaudi/v1/worker/hpu_model_runner.py b/vllm_gaudi/v1/worker/hpu_model_runner.py
index 4049b09a..ddb64fee 100644
--- a/vllm_gaudi/v1/worker/hpu_model_runner.py
+++ b/vllm_gaudi/v1/worker/hpu_model_runner.py
@@ -3007,19 +3007,24 @@ def execute_model(
                 self.event_start = self.profiler.get_timestamp_us()
                 self.profiler.start("internal", "prefill")

-                # Align behavior of incomplete prompt with gpu_model_runner
-                # If logits_indices is smaller than req_id,
-                # add the last token position
+                # NOTE(tianmu-li): Align behavior of incomplete prompt with gpu_model_runner
+                # If logits_indices is smaller than req_id, the last request is a chunked prompt request that
+                # hasn't finished in this step. We add the last token position to logits_indices to ensure
+                # the last token of the chunk is sampled. This sampled token will be discarded later
                 if logits_indices.shape[0] < len(req_id):
-                    if structured_output:
-                        logits_append = torch.tensor([torch.sum(prompt_len) - 1],
-                                                     device=token_ids.device,
-                                                     dtype=torch.int32)
-                        logits_indices = torch.cat([logits_indices, logits_append])
-                    elif self.use_async_scheduling:
-                        # Discard partial prefill logits for async scheduling
+                    if structured_output or self.use_async_scheduling:
+                        # When there are multiple requests in the batch (e.g. self.use_merged_prefill=True),
+                        # the last token position is the sum of all prompt lengths - 1
+                        # This logic also holds when there is only one request in the batch
+                        logits_indices_append = torch.tensor([torch.sum(prompt_len) - 1],
+                                                             device=token_ids.device,
+                                                             dtype=torch.int32)
+                        logits_indices = torch.cat([logits_indices, logits_indices_append])
+                    if self.use_async_scheduling:
+                        # Discard partial prefill logit for async scheduling
                         # Depends on 1 decode token/batch
-                        invalid_req_indices.append(num_decodes + idx)
+                        prefill_start_idx = num_decodes
+                        invalid_req_indices.append(prefill_start_idx + idx)
                 htorch.core.mark_step()
                 non_flattened_hidden_states, aux_hidden_states, \
                     sample_hidden_states, logits_device = \
@@ -3321,7 +3326,7 @@ def execute_model(
             return AsyncHPUModelRunnerOutput(
                 model_runner_output=model_runner_output,
                 sampled_token_ids=sampled_token_ids,
-                invalid_req_indices=[],
+                invalid_req_indices=invalid_req_indices,
                 async_output_copy_stream=self.async_output_copy_stream,
             )
         model_runner_output = ModelRunnerOutput(
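
For context, the sketch below replays the bookkeeping the first hunk performs, outside the runner. It is a toy illustration under stated assumptions, not runner code: the request ids, prompt lengths, `num_decodes`, and `idx` values are made up, `device=` is dropped so it runs on CPU, and a plain `use_async_scheduling` flag stands in for `self.use_async_scheduling`.

```python
import torch

# Three prompts merged into one prefill batch; the last one is a chunk of a
# longer prompt that does not finish in this step, so only the first two
# requests received a sampling position in logits_indices.
req_id = ["req-0", "req-1", "req-2"]
prompt_len = torch.tensor([4, 3, 5], dtype=torch.int32)
logits_indices = torch.tensor([3, 6], dtype=torch.int32)

structured_output = False
use_async_scheduling = True   # stand-in for self.use_async_scheduling
num_decodes = 2               # decode requests scheduled ahead of the prefills (made up)
idx = 0                       # position of this prefill batch in the prefill loop (made up)
invalid_req_indices = []

if logits_indices.shape[0] < len(req_id):
    if structured_output or use_async_scheduling:
        # Last token position of the (merged) prefill: sum of all prompt lengths - 1.
        logits_indices_append = torch.tensor([torch.sum(prompt_len) - 1],
                                             dtype=torch.int32)
        logits_indices = torch.cat([logits_indices, logits_indices_append])
    if use_async_scheduling:
        # Record which output slot holds the throwaway token so it can be
        # discarded later (assumes one output slot per prefill batch after the
        # num_decodes decode slots, per the "1 decode token/batch" note above).
        prefill_start_idx = num_decodes
        invalid_req_indices.append(prefill_start_idx + idx)

print(logits_indices)        # tensor([ 3,  6, 11], dtype=torch.int32)
print(invalid_req_indices)   # [2]
```

The appended position only exists so the sampler has a logit for the unfinished chunk; the token sampled there is then invalidated through `invalid_req_indices`, which the second hunk now actually forwards to `AsyncHPUModelRunnerOutput` instead of passing an empty list.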