vllm_spyre/v1/core/scheduler.py (22 changes: 11 additions & 11 deletions)
@@ -626,13 +626,13 @@ def _satisfies_last_chunk_constraints(self, request: Request) -> bool:
         """Last chunked prefill can be scheduled only if there is enough space
         in the decode batch, and if all the other spyre-related conditions
         are satisfied."""

+        decoding_requests = [
+            r for r in self.running if r not in self.ongoing_prefills
+        ]
         max_context_len = self.scheduler_config.max_model_len

         # check that there is space in the current decode batch
-        num_running = len(self.running)
-        if request in self.running:
-            num_running -= 1
+        num_running = len(decoding_requests)
         cond1 = num_running + len(self.waiting) < self.max_num_running_reqs

         # calculate new max tkv of the batch given the new sequence joins
@@ -649,7 +649,7 @@ def _satisfies_last_chunk_constraints(self, request: Request) -> bool:
         # note that the -1 comes from the token we generate during prefill
         cond2 = request.max_tokens - 1 <= (max_context_len - new_req_tkv)
         # check cond2 for all other sequences in the current decode batch
-        for req in self.running:
+        for req in decoding_requests:
             # current tkv of the (left aligned) decode sequence
             dec_req_tkv = n_blocks * self.block_size + \
                 req.num_computed_tokens % self.block_size
@@ -667,12 +667,12 @@ def _satisfies_last_chunk_constraints(self, request: Request) -> bool:
         # check that batch size x tkv is smaller than the max supported number
         # Note: using max_tkv is a conservative upper bound here. For the
         # optimal check we need model runner to return per sequence tkvs
-        cond3 = lambda: self.check_batch_tkv_limit_cp(request=request,
-                                                      new_req_tkv=new_req_tkv,
-                                                      n_blocks=n_blocks,
-                                                      running=self.running,
-                                                      max_batch_tkv_limit=self.
-                                                      max_batch_tkv_limit)
+        cond3 = lambda: self.check_batch_tkv_limit_cp(
+            request=request,
+            new_req_tkv=new_req_tkv,
+            n_blocks=n_blocks,
+            running=decoding_requests,
+            max_batch_tkv_limit=self.max_batch_tkv_limit)

         return cond1 and cond2 and cond3()
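
Taken together, the change swaps self.running for decoding_requests in all three constraint checks, so requests that are still mid-prefill no longer count against the decode batch. The standalone sketch below only illustrates the shape of those checks outside the scheduler; it is not the vLLM-Spyre implementation. The Req dataclass, the flattened parameter list, the per-request cond2 comparison, and the stand-in for check_batch_tkv_limit_cp are assumptions made for illustration, since those parts are collapsed in the diff above.

from dataclasses import dataclass


@dataclass
class Req:
    max_tokens: int           # generation budget of the request
    num_computed_tokens: int  # tokens already computed for this request


def satisfies_last_chunk_constraints(
        new_req: Req,
        new_req_tkv: int,
        decoding_requests: list[Req],
        num_waiting: int,
        *,
        n_blocks: int,
        block_size: int,
        max_num_running_reqs: int,
        max_context_len: int,
        max_batch_tkv_limit: int) -> bool:
    # cond1: room in the decode batch. Only requests that are already
    # decoding are counted; ongoing prefills are excluded, which is the
    # change this PR makes.
    cond1 = len(decoding_requests) + num_waiting < max_num_running_reqs

    # cond2: the new request must still fit in the context window; the -1
    # accounts for the token generated during prefill.
    cond2 = new_req.max_tokens - 1 <= max_context_len - new_req_tkv
    for req in decoding_requests:
        # current tkv of the (left aligned) decode sequence, as in the diff
        dec_req_tkv = (n_blocks * block_size
                       + req.num_computed_tokens % block_size)
        # the per-request comparison is collapsed in the diff; mirroring the
        # new-request check here is an assumption of this sketch
        cond2 = cond2 and (req.max_tokens - 1
                           <= max_context_len - dec_req_tkv)

    # cond3: batch size x tkv must stay below the supported limit. The real
    # code defers to self.check_batch_tkv_limit_cp(); this product check is
    # a stand-in based on the diff comment (conservative bound via max tkv).
    batch_tkvs = [new_req_tkv] + [
        n_blocks * block_size + r.num_computed_tokens % block_size
        for r in decoding_requests
    ]
    cond3 = len(batch_tkvs) * max(batch_tkvs) <= max_batch_tkv_limit

    return cond1 and cond2 and cond3


# example: three decode requests plus one waiting request, toy limits
ok = satisfies_last_chunk_constraints(
    Req(max_tokens=64, num_computed_tokens=0), new_req_tkv=256,
    decoding_requests=[Req(48, 300), Req(32, 200), Req(16, 100)],
    num_waiting=1, n_blocks=4, block_size=64, max_num_running_reqs=8,
    max_context_len=2048, max_batch_tkv_limit=8192)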
