Commit 1c11f68

fix: unbatch removals of requests from input_batch (#511)
# Description

Calls to `input_batch.remove_request` adjust the internal state of the batch and change the mapping from request ids to indices. If multiple removals are batched, the resulting `batch_update` can contain repeated indices, which leads to bugs.

The change here forces the input batch metadata update to happen after each removal, one at a time, completely unbatching the `batch_update`s.

## Related Issues

Fixes #508
See also #492 (comment)

---------

Signed-off-by: Travis Johnson <[email protected]>
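The failure mode described above can be sketched with a toy model. This is only an illustration under stated assumptions: `ToyBatch`, `remove_batched`, and `remove_unbatched` are hypothetical names, and the sketch assumes that each removal compacts the batch immediately so that later requests shift down by one index. The real `SamplingInputBatch` may track its state differently.

```python
class ToyBatch:
    """Hypothetical stand-in for an input batch; not the vllm-spyre class."""

    def __init__(self, req_ids):
        self.req_ids = list(req_ids)
        self.removed = []  # indices recorded since the last refresh

    def remove_request(self, req_id):
        # Assumption: removal compacts the batch right away, so every
        # request after `index` moves down by one position.
        index = self.req_ids.index(req_id)
        self.req_ids.pop(index)
        self.removed.append(index)

    def refresh_metadata(self):
        # Hand the pending removals to consumers and reset the record.
        update, self.removed = self.removed, []
        return update


def remove_batched(batch, finished):
    # Old behavior in this toy model: one refresh after all removals.
    for req_id in finished:
        batch.remove_request(req_id)
    return [batch.refresh_metadata()]


def remove_unbatched(batch, finished):
    # New behavior in this toy model: one refresh per removal.
    updates = []
    for req_id in finished:
        batch.remove_request(req_id)
        updates.append(batch.refresh_metadata())
    return updates


batched = remove_batched(ToyBatch(["a", "b", "c"]), ["a", "b"])
unbatched = remove_unbatched(ToyBatch(["a", "b", "c"]), ["a", "b"])
print(batched)    # [[0, 0]]: index 0 appears twice in a single update
print(unbatched)  # [[0], [0]]: each update holds one index
```

In this model, a consumer that treats the indices in one update as distinct slots would misread `[0, 0]`, whereas the per-removal updates are each unambiguous.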
1 parent 0c9b971 commit 1c11f68

2 files changed: +12 -1 lines changed

tests/v1/worker/test_spyre_input_batch.py

Lines changed: 8 additions & 0 deletions
@@ -36,6 +36,14 @@ def _remove_requests(input_batch: SamplingInputBatch, batch_size: int,
     for index in req_indices_to_remove:
         input_batch.remove_request(reqs[index].req_id)
         req_ids_to_remove.add(reqs[index].req_id)
+
+    # FIXME: it is a bug in the current implementation that removed indices may
+    # be duplicated, which can break logitsprocs tracking. Once fixed we should
+    # add this assert.
+    # see also: https://github.com/vllm-project/vllm-spyre/issues/508
+    # removed = input_batch.batch_update_builder.removed
+    # assert len(set(removed)) == len(removed), "Duplicate removed indices"
+
     return req_ids_to_remove

vllm_spyre/v1/worker/spyre_model_runner.py

Lines changed: 4 additions & 1 deletion
@@ -394,7 +394,10 @@ def update_states(self, scheduler_output: SchedulerOutput):
         for req_id in scheduler_output.finished_req_ids:
             self.input_batch.remove_request(req_id)
             self.requests.pop(req_id, None)
-        self.input_batch.refresh_metadata()
+            # TODO: Processing multiple removals at once can break alignment
+            # of logitprocs. Refactor so that we can batch removals to the
+            # `input_batch`
+            self.input_batch.refresh_metadata()

     def _get_prompt_logprobs_dict(
         self,
