 from dataclasses import InitVar, field
 from typing import Any, Literal
 
-from pydantic import Field, SkipValidation, model_validator
+from pydantic import SkipValidation, model_validator
 from pydantic.dataclasses import dataclass
 from typing_extensions import Self
 
@@ -37,10 +37,10 @@ class SchedulerConfig:
     This config has no static default. If left unspecified by the user, it will
     be set in `EngineArgs.create_engine_config` based on the usage context."""
 
-    prefill_max_num_batched_tokens: int = Field(init=False)
-    """Prefill maximum number of tokens to be processed in a single iteration.
-
-    This config is used when there are no decoding requests."""
+    prefill_max_num_batched_tokens: int | None = None
+    """Maximum number of tokens to be processed in a single iteration when there
+    are no decode requests. If not set (None), defaults to max_num_batched_tokens.
+    Must satisfy: prefill_max_num_batched_tokens >= max_num_batched_tokens."""
 
     max_num_seqs: SkipValidation[int] = None  # type: ignore
     """Maximum number of sequences to be processed in a single iteration.
@@ -80,11 +80,6 @@ class SchedulerConfig:
8080 """If True, prefill requests can be chunked based
8181 on the remaining max_num_batched_tokens."""
8282
83- enable_hybrid_chunked_prefill : bool = False
84- """If True, prefill requests will only be chunked when there are decode
85- requests present, otherwise they will proceed with normal prefill
86- computation to increase throughput."""
87-
8883 is_multimodal_model : bool = False
8984 """True if the model is multimodal."""
9085
@@ -183,9 +178,6 @@ def __post_init__(self, is_encoder_decoder: bool) -> None:
183178 " prefix caching; disabling both."
184179 )
185180
186- self .prefill_max_num_batched_tokens = max (
187- self .max_model_len , DEFAULT_MAX_NUM_BATCHED_TOKENS
188- )
189181 if self .max_num_batched_tokens is None :
190182 if self .enable_chunked_prefill :
191183 self .max_num_batched_tokens = DEFAULT_MAX_NUM_BATCHED_TOKENS
@@ -203,30 +195,23 @@ def __post_init__(self, is_encoder_decoder: bool) -> None:
                 self.max_num_batched_tokens,
                 POOLING_MODEL_MAX_NUM_BATCHED_TOKENS,
             )
-            self.prefill_max_num_batched_tokens = max(
-                self.prefill_max_num_batched_tokens,
-                POOLING_MODEL_MAX_NUM_BATCHED_TOKENS,
-            )
         if self.is_multimodal_model:
             # The value needs to be at least the number of multimodal tokens
             self.max_num_batched_tokens = max(
                 self.max_num_batched_tokens,
                 MULTIMODAL_MODEL_MAX_NUM_BATCHED_TOKENS,
             )
-            self.prefill_max_num_batched_tokens = max(
-                self.prefill_max_num_batched_tokens,
-                MULTIMODAL_MODEL_MAX_NUM_BATCHED_TOKENS,
-            )
         # When using default settings,
         # Ensure max_num_batched_tokens does not exceed model limit.
         # Some models (e.g., Whisper) have embeddings tied to max length.
         self.max_num_batched_tokens = min(
             self.max_num_seqs * self.max_model_len, self.max_num_batched_tokens
         )
-        self.prefill_max_num_batched_tokens = min(
-            self.max_num_seqs * self.max_model_len,
-            self.prefill_max_num_batched_tokens,
-        )
+
+        # Initialize prefill_max_num_batched_tokens based on user input
+        if self.prefill_max_num_batched_tokens is None:
+            # Default to max_num_batched_tokens
+            self.prefill_max_num_batched_tokens = self.max_num_batched_tokens
         self.max_num_encoder_input_tokens = self.max_num_batched_tokens
         self.encoder_cache_size = self.max_num_batched_tokens
 
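For reference, a minimal standalone sketch of the None-fallback added above, assuming it runs after `max_num_batched_tokens` has been fully resolved. The toy class and values are hypothetical, not the real `SchedulerConfig`:

```python
from dataclasses import dataclass


# Toy mirror of the fallback in __post_init__ above; not vLLM code.
@dataclass
class TokenBudget:
    max_num_batched_tokens: int
    prefill_max_num_batched_tokens: int | None = None

    def __post_init__(self) -> None:
        if self.prefill_max_num_batched_tokens is None:
            # Unset by the user: inherit the resolved regular budget.
            self.prefill_max_num_batched_tokens = self.max_num_batched_tokens


assert TokenBudget(2048).prefill_max_num_batched_tokens == 2048
assert TokenBudget(2048, 8192).prefill_max_num_batched_tokens == 8192
```
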
@@ -318,12 +303,13 @@ def _verify_args(self) -> Self:
318303 f"max_num_partial_prefills ({ self .max_num_partial_prefills } )."
319304 )
320305
321- if self .enable_hybrid_chunked_prefill and not self .chunked_prefill_enabled :
306+ # Validate prefill_max_num_batched_tokens
307+ if self .prefill_max_num_batched_tokens < self .max_num_batched_tokens :
322308 raise ValueError (
323- "Hybrid chunked prefill can only be enabled when chunked "
324- "prefill is enabled. Please set --enable-chunked-prefill=True "
325- " or disable hybrid chunked prefill by setting "
326- "--enable-hybrid-chunked-prefill=False ."
309+ f"prefill_max_num_batched_tokens "
310+ f"( { self . prefill_max_num_batched_tokens } ) must be greater "
311+ f"than or equal to max_num_batched_tokens "
312+ f"( { self . max_num_batched_tokens } ) ."
327313 )
328314
329315 return self
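A toy mirror of the new check in `_verify_args` (the helper name and values below are illustrative only):

```python
def verify_prefill_budget(prefill_max: int, max_batched: int) -> None:
    # Mirrors the new validation: the prefill-only budget may match or
    # exceed the regular budget, but never undercut it.
    if prefill_max < max_batched:
        raise ValueError(
            f"prefill_max_num_batched_tokens ({prefill_max}) must be greater "
            f"than or equal to max_num_batched_tokens ({max_batched})."
        )


verify_prefill_budget(8192, 2048)  # OK: larger prefill-only budget
try:
    verify_prefill_budget(1024, 2048)
except ValueError as exc:
    print(exc)  # rejected: prefill budget below the regular budget
```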