29 changes: 17 additions & 12 deletions vllm_gaudi/v1/worker/hpu_model_runner.py
@@ -3007,19 +3007,24 @@ def execute_model(

             self.event_start = self.profiler.get_timestamp_us()
             self.profiler.start("internal", "prefill")
-            # Align behavior of incomplete prompt with gpu_model_runner
-            # If logits_indices is smaller than req_id,
-            # add the last token position
+            # NOTE(tianmu-li): Align behavior of incomplete prompt with gpu_model_runner
+            # If logits_indices is smaller than req_id, the last request is a chunked prompt request that
+            # hasn't finished in this step. We add the last token position to logits_indices to ensure
+            # the last token of the chunk is sampled. This sampled token will be discarded later
             if logits_indices.shape[0] < len(req_id):
-                if structured_output:
-                    logits_append = torch.tensor([torch.sum(prompt_len) - 1],
-                                                 device=token_ids.device,
-                                                 dtype=torch.int32)
-                    logits_indices = torch.cat([logits_indices, logits_append])
-                elif self.use_async_scheduling:
-                    # Discard partial prefill logits for async scheduling
+                if structured_output or self.use_async_scheduling:
+                    # When there are multiple requests in the batch (e.g. self.use_merged_prefill=True),
+                    # the last token position is the sum of all prompt lengths - 1
+                    # This logic also holds when there is only one request in the batch
+                    logits_indices_append = torch.tensor([torch.sum(prompt_len) - 1],
+                                                         device=token_ids.device,
+                                                         dtype=torch.int32)
+                    logits_indices = torch.cat([logits_indices, logits_indices_append])
+                if self.use_async_scheduling:
+                    # Discard partial prefill logit for async scheduling
                     # Depends on 1 decode token/batch
-                    invalid_req_indices.append(num_decodes + idx)
+                    prefill_start_idx = num_decodes
+                    invalid_req_indices.append(prefill_start_idx + idx)
             htorch.core.mark_step()
             non_flattened_hidden_states, aux_hidden_states, \
                 sample_hidden_states, logits_device = \
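
The appended-index logic in this hunk can be exercised in isolation. The sketch below restates it with assumed toy values (three merged prompts whose last request is a chunked prefill that has not finished this step, CPU tensors for brevity); it is not the runner's actual state.

```python
import torch

# Toy stand-ins for the runner state (assumed values).
prompt_len = torch.tensor([4, 3, 5])      # per-request prompt lengths in the merged batch
req_id = ["r0", "r1", "r2"]               # three requests in this step
logits_indices = torch.tensor([3, 6], dtype=torch.int32)  # one short: last prompt is unfinished

if logits_indices.shape[0] < len(req_id):
    # Last token position of the merged prefill is sum(prompt_len) - 1;
    # this also holds when the batch contains a single request.
    logits_indices_append = torch.tensor([torch.sum(prompt_len) - 1],
                                         dtype=torch.int32)
    logits_indices = torch.cat([logits_indices, logits_indices_append])

print(logits_indices)  # tensor([ 3,  6, 11], dtype=torch.int32)
```

The token sampled at that extra position is a placeholder; under async scheduling the request's batch slot is recorded in invalid_req_indices so the sample can be dropped downstream.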
@@ -3321,7 +3326,7 @@ def execute_model(
             return AsyncHPUModelRunnerOutput(
                 model_runner_output=model_runner_output,
                 sampled_token_ids=sampled_token_ids,
-                invalid_req_indices=[],
+                invalid_req_indices=invalid_req_indices,
                 async_output_copy_stream=self.async_output_copy_stream,
             )
         model_runner_output = ModelRunnerOutput(
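
This second hunk forwards the collected invalid_req_indices into AsyncHPUModelRunnerOutput instead of a hard-coded empty list. As a rough illustration of why that matters, the hedged sketch below shows one way a consumer could drop the placeholder samples; the real handling lives inside the async output machinery and may differ.

```python
import torch

# Assumed toy output: one sampled token per request, in batch-slot order.
sampled_token_ids = torch.tensor([[101], [205], [999]])
invalid_req_indices = [2]  # slot 2 holds the discarded chunked-prefill sample

# Keep only samples from slots that completed a real sampling step.
valid_token_ids = [row.tolist() for slot, row in enumerate(sampled_token_ids)
                   if slot not in invalid_req_indices]
print(valid_token_ids)  # [[101], [205]]
```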