 import vllm.envs as envs
 from vllm.logger import init_logger
 from vllm.model_executor.models import ModelRegistry
-from vllm.utils.math_utils import cdiv, next_power_of_2, round_up
+from vllm.utils.math_utils import cdiv, round_up
 from vllm.utils.torch_utils import STR_DTYPE_TO_TORCH_DTYPE
+from vllm.v1.attention.backends.flashinfer import FlashInferBackend
 from vllm.v1.kv_cache_interface import FullAttentionSpec, MambaSpec, MLAAttentionSpec
 
 if TYPE_CHECKING:
@@ -364,6 +365,10 @@ def verify_and_update_config(cls, vllm_config: "VllmConfig") -> None:
             ).page_size_bytes
         else:
             kernel_block_alignment_size = 16
+            if envs.VLLM_ATTENTION_BACKEND == "FLASHINFER":
+                kernel_block_alignment_size = min(
+                    FlashInferBackend.get_supported_kernel_block_size()
+                )
             attn_page_size_1_token = FullAttentionSpec(
                 block_size=1,
                 num_kv_heads=model_config.get_num_kv_heads(parallel_config),
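The added branch replaces the hard-coded alignment of 16 with the smallest kernel block size the FlashInfer backend reports. A minimal sketch of that selection, using hypothetical size lists in place of the real `FlashInferBackend.get_supported_kernel_block_size()` return value:

```python
# Minimal sketch of the alignment selection above. The supported-size lists are
# hypothetical stand-ins, not the values actually reported by FlashInfer.
def pick_kernel_block_alignment(supported_kernel_block_sizes: list[int]) -> int:
    # The smallest supported kernel block size is the finest granularity the
    # attention block size can be aligned to.
    return min(supported_kernel_block_sizes)

print(pick_kernel_block_alignment([32, 64]))  # 32 (hypothetical FlashInfer sizes)
print(pick_kernel_block_alignment([16]))      # 16 (matches the default path)
```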
@@ -410,7 +415,6 @@ def verify_and_update_config(cls, vllm_config: "VllmConfig") -> None:
             attn_tokens_per_mamba_state = cdiv(mamba_page_size, attn_page_size_1_token)
             chunk_size = lcm(base_chunk_size, kernel_block_alignment_size)
             attn_block_size = chunk_size * cdiv(attn_tokens_per_mamba_state, chunk_size)
-            attn_block_size = next_power_of_2(attn_block_size)
             cache_config.mamba_block_size = attn_block_size
         else:
             # Without prefix caching, select minimum valid attention block size
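With prefix caching enabled, the block size is now left at the smallest multiple of `chunk_size` that covers one mamba state, rather than being pushed up to the next power of two. A worked sketch with made-up sizes (`cdiv` mirrors `vllm.utils.math_utils.cdiv`):

```python
from math import lcm

def cdiv(a: int, b: int) -> int:
    # Ceiling division, as in vllm.utils.math_utils.cdiv.
    return -(-a // b)

# Hypothetical sizes, chosen only to make the arithmetic concrete.
base_chunk_size = 256               # mamba chunk size
kernel_block_alignment_size = 32    # e.g. smallest supported kernel block size
attn_tokens_per_mamba_state = 600   # cdiv(mamba_page_size, attn_page_size_1_token)

chunk_size = lcm(base_chunk_size, kernel_block_alignment_size)                 # 256
attn_block_size = chunk_size * cdiv(attn_tokens_per_mamba_state, chunk_size)   # 768
# The removed next_power_of_2(768) call would have inflated this to 1024.
```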
@@ -422,7 +426,6 @@ def verify_and_update_config(cls, vllm_config: "VllmConfig") -> None:
             attn_block_size = kernel_block_alignment_size * cdiv(
                 mamba_page_size, kernel_block_alignment_size * attn_page_size_1_token
             )
-            attn_block_size = next_power_of_2(attn_block_size)
 
         # override attention block size if either (a) the
         # user has not set it or (b) the user has set it
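Without prefix caching, the block size is now the smallest multiple of the kernel alignment whose attention page covers one mamba state, again without the power-of-two round-up. A sketch with made-up byte sizes:

```python
def cdiv(a: int, b: int) -> int:
    # Ceiling division, as in vllm.utils.math_utils.cdiv.
    return -(-a // b)

# Hypothetical byte sizes, for illustration only.
kernel_block_alignment_size = 16
attn_page_size_1_token = 4096       # bytes of attention KV cache per token
mamba_page_size = 1_000_000         # bytes per mamba state

attn_block_size = kernel_block_alignment_size * cdiv(
    mamba_page_size, kernel_block_alignment_size * attn_page_size_1_token
)
assert attn_block_size == 256                                  # 16 * ceil(1e6 / 65536)
assert attn_block_size * attn_page_size_1_token >= mamba_page_size
# Here next_power_of_2(256) happens to be a no-op; a result such as 272 would
# previously have been rounded up to 512.
```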