
Commit 59947ac

prefill_max_num_batched_tokens optimization
Signed-off-by: Ther-LF <[email protected]>
1 parent 22b54d6 commit 59947ac

4 files changed: +41, -50 lines


vllm/config/scheduler.py

Lines changed: 16 additions & 30 deletions
@@ -5,7 +5,7 @@
 from dataclasses import InitVar, field
 from typing import Any, Literal

-from pydantic import Field, SkipValidation, model_validator
+from pydantic import SkipValidation, model_validator
 from pydantic.dataclasses import dataclass
 from typing_extensions import Self

@@ -37,10 +37,10 @@ class SchedulerConfig:
     This config has no static default. If left unspecified by the user, it will
     be set in `EngineArgs.create_engine_config` based on the usage context."""

-    prefill_max_num_batched_tokens: int = Field(init=False)
-    """Prefill maximum number of tokens to be processed in a single iteration.
-
-    This config is used when there are no decoding requests."""
+    prefill_max_num_batched_tokens: int | None = None
+    """Maximum number of tokens to be processed in a single iteration when there
+    are no decode requests. If not set (None), defaults to max_num_batched_tokens.
+    Must satisfy: prefill_max_num_batched_tokens >= max_num_batched_tokens."""

     max_num_seqs: SkipValidation[int] = None  # type: ignore
     """Maximum number of sequences to be processed in a single iteration.
@@ -80,11 +80,6 @@ class SchedulerConfig:
     """If True, prefill requests can be chunked based
     on the remaining max_num_batched_tokens."""

-    enable_hybrid_chunked_prefill: bool = False
-    """If True, prefill requests will only be chunked when there are decode
-    requests present, otherwise they will proceed with normal prefill
-    computation to increase throughput."""
-
     is_multimodal_model: bool = False
     """True if the model is multimodal."""

@@ -183,9 +178,6 @@ def __post_init__(self, is_encoder_decoder: bool) -> None:
                 " prefix caching; disabling both."
             )

-        self.prefill_max_num_batched_tokens = max(
-            self.max_model_len, DEFAULT_MAX_NUM_BATCHED_TOKENS
-        )
         if self.max_num_batched_tokens is None:
             if self.enable_chunked_prefill:
                 self.max_num_batched_tokens = DEFAULT_MAX_NUM_BATCHED_TOKENS
@@ -203,30 +195,23 @@ def __post_init__(self, is_encoder_decoder: bool) -> None:
                     self.max_num_batched_tokens,
                     POOLING_MODEL_MAX_NUM_BATCHED_TOKENS,
                 )
-                self.prefill_max_num_batched_tokens = max(
-                    self.prefill_max_num_batched_tokens,
-                    POOLING_MODEL_MAX_NUM_BATCHED_TOKENS,
-                )
             if self.is_multimodal_model:
                 # The value needs to be at least the number of multimodal tokens
                 self.max_num_batched_tokens = max(
                     self.max_num_batched_tokens,
                     MULTIMODAL_MODEL_MAX_NUM_BATCHED_TOKENS,
                 )
-                self.prefill_max_num_batched_tokens = max(
-                    self.prefill_max_num_batched_tokens,
-                    MULTIMODAL_MODEL_MAX_NUM_BATCHED_TOKENS,
-                )
             # When using default settings,
             # Ensure max_num_batched_tokens does not exceed model limit.
             # Some models (e.g., Whisper) have embeddings tied to max length.
             self.max_num_batched_tokens = min(
                 self.max_num_seqs * self.max_model_len, self.max_num_batched_tokens
             )
-            self.prefill_max_num_batched_tokens = min(
-                self.max_num_seqs * self.max_model_len,
-                self.prefill_max_num_batched_tokens,
-            )
+
+        # Initialize prefill_max_num_batched_tokens based on user input
+        if self.prefill_max_num_batched_tokens is None:
+            # Default to max_num_batched_tokens
+            self.prefill_max_num_batched_tokens = self.max_num_batched_tokens
         self.max_num_encoder_input_tokens = self.max_num_batched_tokens
         self.encoder_cache_size = self.max_num_batched_tokens

@@ -318,12 +303,13 @@ def _verify_args(self) -> Self:
                 f"max_num_partial_prefills ({self.max_num_partial_prefills})."
             )

-        if self.enable_hybrid_chunked_prefill and not self.chunked_prefill_enabled:
+        # Validate prefill_max_num_batched_tokens
+        if self.prefill_max_num_batched_tokens < self.max_num_batched_tokens:
             raise ValueError(
-                "Hybrid chunked prefill can only be enabled when chunked "
-                "prefill is enabled. Please set --enable-chunked-prefill=True "
-                "or disable hybrid chunked prefill by setting "
-                "--enable-hybrid-chunked-prefill=False."
+                f"prefill_max_num_batched_tokens "
+                f"({self.prefill_max_num_batched_tokens}) must be greater "
+                f"than or equal to max_num_batched_tokens "
+                f"({self.max_num_batched_tokens})."
             )

         return self
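
The defaulting and validation rules introduced above reduce to a small pure function. The sketch below is illustrative only, not vLLM code: the parameter names mirror the SchedulerConfig fields and the token values in the asserts are made up.

# Illustrative sketch (not vLLM code) of the defaulting and validation rules
# added to SchedulerConfig above.

def resolve_prefill_budget(
    max_num_batched_tokens: int,
    prefill_max_num_batched_tokens: int | None,
) -> int:
    """Return the effective prefill-only token budget."""
    if prefill_max_num_batched_tokens is None:
        # Unset: fall back to the regular batch budget.
        return max_num_batched_tokens
    if prefill_max_num_batched_tokens < max_num_batched_tokens:
        # Mirrors the check in _verify_args above.
        raise ValueError(
            f"prefill_max_num_batched_tokens ({prefill_max_num_batched_tokens}) "
            f"must be greater than or equal to max_num_batched_tokens "
            f"({max_num_batched_tokens})."
        )
    return prefill_max_num_batched_tokens


# Example values (made up): 8192-token mixed batches, 16384-token prefill-only steps.
assert resolve_prefill_budget(8192, None) == 8192
assert resolve_prefill_budget(8192, 16384) == 16384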

vllm/engine/arg_utils.py

Lines changed: 8 additions & 6 deletions
@@ -424,6 +424,9 @@ class EngineArgs:
     gpu_memory_utilization: float = CacheConfig.gpu_memory_utilization
     kv_cache_memory_bytes: int | None = CacheConfig.kv_cache_memory_bytes
     max_num_batched_tokens: int | None = SchedulerConfig.max_num_batched_tokens
+    prefill_max_num_batched_tokens: int | None = (
+        SchedulerConfig.prefill_max_num_batched_tokens
+    )
     max_num_partial_prefills: int = SchedulerConfig.max_num_partial_prefills
     max_long_partial_prefills: int = SchedulerConfig.max_long_partial_prefills
     long_prefill_token_threshold: int = SchedulerConfig.long_prefill_token_threshold
@@ -483,7 +486,6 @@ class EngineArgs:
     ignore_patterns: str | list[str] = get_field(LoadConfig, "ignore_patterns")

     enable_chunked_prefill: bool | None = SchedulerConfig.enable_chunked_prefill
-    enable_hybrid_chunked_prefill: bool = SchedulerConfig.enable_hybrid_chunked_prefill
     disable_chunked_mm_input: bool = SchedulerConfig.disable_chunked_mm_input

     disable_hybrid_kv_cache_manager: bool = (
@@ -1005,6 +1007,10 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
         scheduler_group.add_argument(
             "--max-num-batched-tokens", **scheduler_kwargs["max_num_batched_tokens"]
         )
+        scheduler_group.add_argument(
+            "--prefill-max-num-batched-tokens",
+            **scheduler_kwargs["prefill_max_num_batched_tokens"],
+        )
         scheduler_group.add_argument(
             "--max-num-seqs", **scheduler_kwargs["max_num_seqs"]
         )
@@ -1030,10 +1036,6 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
         scheduler_group.add_argument(
             "--enable-chunked-prefill", **scheduler_kwargs["enable_chunked_prefill"]
         )
-        scheduler_group.add_argument(
-            "--enable-hybrid-chunked-prefill",
-            **scheduler_kwargs["enable_hybrid_chunked_prefill"],
-        )
         scheduler_group.add_argument(
             "--disable-chunked-mm-input", **scheduler_kwargs["disable_chunked_mm_input"]
         )
@@ -1578,11 +1580,11 @@ def create_engine_config(
         scheduler_config = SchedulerConfig(
             runner_type=model_config.runner_type,
             max_num_batched_tokens=self.max_num_batched_tokens,
+            prefill_max_num_batched_tokens=self.prefill_max_num_batched_tokens,
             max_num_seqs=self.max_num_seqs,
             max_model_len=model_config.max_model_len,
             num_lookahead_slots=num_lookahead_slots,
             enable_chunked_prefill=self.enable_chunked_prefill,
-            enable_hybrid_chunked_prefill=self.enable_hybrid_chunked_prefill,
             disable_chunked_mm_input=self.disable_chunked_mm_input,
             is_multimodal_model=model_config.is_multimodal_model,
             is_encoder_decoder=model_config.is_encoder_decoder,
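
With the field and CLI flag wired through EngineArgs above, the new limit can be set alongside the existing one. The snippet below is an illustrative sketch, not part of the commit: the model name and token values are hypothetical, and only the EngineArgs fields and the --prefill-max-num-batched-tokens flag are taken from the diff.

# Illustrative only: model name and token values are hypothetical.
from vllm.engine.arg_utils import EngineArgs

# Python-API equivalent of something like:
#   vllm serve <model> --enable-chunked-prefill \
#       --max-num-batched-tokens 8192 --prefill-max-num-batched-tokens 16384
args = EngineArgs(
    model="meta-llama/Llama-3.1-8B-Instruct",  # hypothetical model choice
    enable_chunked_prefill=True,
    max_num_batched_tokens=8192,               # budget for mixed prefill/decode steps
    prefill_max_num_batched_tokens=16384,      # larger budget when nothing is decoding
)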

vllm/v1/core/sched/scheduler.py

Lines changed: 11 additions & 10 deletions
@@ -224,13 +224,15 @@ def schedule(self) -> SchedulerOutput:
         num_scheduled_tokens: dict[str, int] = {}

         token_budget = self.max_num_scheduled_tokens
-        # Check if there are any requests in the decode phase in the running queue
-        # when hybrid chunked prefill is enabled.
-        has_decode_requests = True
-        if self.scheduler_config.enable_hybrid_chunked_prefill:
-            has_decode_requests = self._has_decode_reqs
-            if not has_decode_requests:
-                token_budget = self.prefill_max_num_scheduled_tokens
+        # Check if there are any requests in the decode phase in the running queue.
+        # If no decode requests and prefill_max_num_batched_tokens is larger,
+        # use the larger budget for better throughput.
+        has_decode_requests = self._has_decode_reqs
+        if (
+            not has_decode_requests
+            and self.prefill_max_num_scheduled_tokens > self.max_num_scheduled_tokens
+        ):
+            token_budget = self.prefill_max_num_scheduled_tokens

         # Encoder-related.
         scheduled_encoder_inputs: dict[str, list[int]] = {}
@@ -499,7 +501,6 @@ def schedule(self) -> SchedulerOutput:
                # pooling requests to be chunked
                if (
                    not self.scheduler_config.chunked_prefill_enabled
-                    and not self.scheduler_config.enable_hybrid_chunked_prefill
                    and num_new_tokens > token_budget
                ):
                    self.waiting.pop_request()
@@ -626,8 +627,8 @@ def schedule(self) -> SchedulerOutput:
         # Check if the scheduling constraints are satisfied.
         total_num_scheduled_tokens = sum(num_scheduled_tokens.values())
         if (
-            self.scheduler_config.enable_hybrid_chunked_prefill
-            and not has_decode_requests
+            not has_decode_requests
+            and self.prefill_max_num_scheduled_tokens > self.max_num_scheduled_tokens
         ):
             assert total_num_scheduled_tokens <= self.prefill_max_num_scheduled_tokens
         else:
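
The budget selection at the top of schedule() boils down to the rule sketched below. This is a standalone illustration, not the scheduler code; the names mirror the attributes used in the diff and the numeric values are examples.

# Standalone illustration of the token-budget rule: use the larger prefill
# budget only when nothing is decoding and the prefill limit is strictly
# larger than the regular limit.

def select_token_budget(
    max_num_scheduled_tokens: int,
    prefill_max_num_scheduled_tokens: int,
    has_decode_requests: bool,
) -> int:
    if (
        not has_decode_requests
        and prefill_max_num_scheduled_tokens > max_num_scheduled_tokens
    ):
        return prefill_max_num_scheduled_tokens
    return max_num_scheduled_tokens


# Example values (made up): prefill-only steps get 16384 tokens, mixed steps 8192.
assert select_token_budget(8192, 16384, has_decode_requests=False) == 16384
assert select_token_budget(8192, 16384, has_decode_requests=True) == 8192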

vllm/v1/worker/gpu_model_runner.py

Lines changed: 6 additions & 4 deletions
@@ -249,10 +249,12 @@ def __init__(
         self.is_multimodal_pruning_enabled = False
         self.max_model_len = model_config.max_model_len
         self.dcp_world_size = self.parallel_config.decode_context_parallel_size
-        if scheduler_config.enable_hybrid_chunked_prefill:
-            self.max_num_tokens = scheduler_config.prefill_max_num_batched_tokens
-        else:
-            self.max_num_tokens = scheduler_config.max_num_batched_tokens
+        # Use the larger of max_num_batched_tokens and prefill_max_num_batched_tokens
+        # for memory profiling to ensure we allocate enough memory
+        self.max_num_tokens = max(
+            scheduler_config.max_num_batched_tokens,
+            scheduler_config.prefill_max_num_batched_tokens,
+        )
         self.max_num_reqs = scheduler_config.max_num_seqs

         # Broadcast PP output for external_launcher (torchrun)
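
Sizing note: the runner now profiles memory with the larger of the two limits, so prefill-only steps scheduled with the bigger budget still fit. A tiny illustrative check of that invariant (values are made up; names mirror the scheduler_config fields):

# The runner's token capacity must cover whichever budget the scheduler uses.
max_num_batched_tokens = 8192
prefill_max_num_batched_tokens = 16384
max_num_tokens = max(max_num_batched_tokens, prefill_max_num_batched_tokens)
assert max_num_tokens >= max_num_batched_tokens          # mixed prefill/decode steps fit
assert max_num_tokens >= prefill_max_num_batched_tokens  # prefill-only steps fit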
