-
-
Notifications
You must be signed in to change notification settings - Fork 11.1k
[BUGFIX] Adjust kv block sizes #27704
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
ac1020f
3c284e2
2221dd1
3a51814
6e4d374
ecd76fd
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,57 @@ | ||
| # SPDX-License-Identifier: Apache-2.0 | ||
| # SPDX-FileCopyrightText: Copyright contributors to the vLLM project | ||
| from math import lcm | ||
| from typing import TYPE_CHECKING | ||
|
|
||
| from vllm.attention.backends.abstract import MultipleOf | ||
| from vllm.attention.selector import get_attn_backend | ||
| from vllm.logger import init_logger | ||
| from vllm.model_executor.models.config import VerifyAndUpdateConfig | ||
|
|
||
| if TYPE_CHECKING: | ||
| from vllm.config import VllmConfig | ||
|
|
||
| logger = init_logger(__name__) | ||
|
|
||
|
|
||
| class AttentionConfig(VerifyAndUpdateConfig): | ||
| @classmethod | ||
| def verify_and_update_config(cls, vllm_config: "VllmConfig") -> None: | ||
| """ | ||
| Align cache_config.block_size with attention backend's minimum | ||
| supported kernel block size. | ||
|
|
||
| Args: | ||
| vllm_config: vLLM Config | ||
| """ | ||
| model_config = vllm_config.model_config | ||
| cache_config = vllm_config.cache_config | ||
| assert cache_config is not None | ||
|
|
||
| backend_cls = get_attn_backend( | ||
| head_size=model_config.get_head_size(), | ||
| dtype=model_config.dtype, | ||
| kv_cache_dtype=cache_config.cache_dtype, | ||
| block_size=cache_config.block_size | ||
| if cache_config.block_size is not None | ||
| else 16, | ||
| ) | ||
|
|
||
| supported_sizes = backend_cls.get_supported_kernel_block_size() | ||
| supported_sizes = [ | ||
| s.base if isinstance(s, MultipleOf) else s for s in supported_sizes | ||
| ] | ||
| min_size = min(supported_sizes) | ||
| if cache_config.block_size is None: | ||
| new_block_size = min_size | ||
| else: | ||
| new_block_size = lcm(cache_config.block_size, min_size) | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Prefer to raise an error if the user sets a block_size but the block_size is not supported by the attention backend it selects. |
||
|
|
||
| if cache_config.block_size is None or new_block_size != cache_config.block_size: | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think we don't need to add info-level logging if block_size is None and is initialized normally. |
||
| cache_config.block_size = new_block_size | ||
| logger.info( | ||
| "Setting attention block size to %d tokens " | ||
| "to align with %s attention backend's supported kernel block sizes.", | ||
| new_block_size, | ||
| backend_cls.get_name(), | ||
| ) | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -356,7 +356,11 @@ def verify_and_update_config(cls, vllm_config: "VllmConfig") -> None: | |
| dtype=kv_cache_dtype, | ||
| ).page_size_bytes | ||
| else: | ||
| kernel_block_alignment_size = 16 | ||
| if cache_config.block_size is not None: | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. is this part called before or after AttentionConfig.verify_and_update_config(self)? I think the kernel_block_alignment_size should be resolved from backend_cls.get_supported_kernel_block_size |
||
| kernel_block_alignment_size = cache_config.block_size | ||
| else: | ||
| kernel_block_alignment_size = 16 | ||
|
|
||
| if ( | ||
| current_platform.is_device_capability(100) | ||
| and model_config.get_head_size() == 256 | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -173,7 +173,10 @@ def get_supported_kernel_block_size() -> list[int | MultipleOf]: | |
| # Note: Not sure for all platforms, | ||
| # but on Blackwell, only support a page size of | ||
| # 16, 32, 64 | ||
| return [16, 32, 64] | ||
| # TODO: 16 is temporary removed because TRT-LLM kernel has a bug when using 16. | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. if the problem only exist in trtllm, what about only remove 16 for trtllm like There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Agreed, we should only override on Blackwell |
||
| # See https://github.com/flashinfer-ai/flashinfer/issues/1993 | ||
| # for more details. | ||
| return [32, 64] | ||
|
|
||
| @classmethod | ||
| def validate_head_size(cls, head_size: int) -> None: | ||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.