
Commit 59c7c62

suppression warning
1 parent e8771be commit 59c7c62

File tree: 5 files changed, 36 additions & 8 deletions

lmdeploy/pytorch/backends/cuda/attention.py

Lines changed: 25 additions & 3 deletions
@@ -1,5 +1,6 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 
+import functools
 from dataclasses import dataclass
 from typing import Literal
 
@@ -20,8 +21,8 @@
     assert torch.ops.flash_attn_3 is not None
     use_fa3 = True
 except Exception:
-    logger.warning('For higher performance, please install FlashAttention-3 '
-                   'https://github.com/Dao-AILab/flash-attention')
+    logger.debug('For higher performance, please install FlashAttention-3 '
+                 'https://github.com/Dao-AILab/flash-attention')
 
 
 @dataclass
@@ -221,6 +222,15 @@ def forward(
         return attn_output
 
 
+@functools.lru_cache
+def use_fa3_warning():
+    if use_fa3:
+        return True
+    logger.warning('For higher performance, please install FlashAttention-3 '
+                   'https://github.com/Dao-AILab/flash-attention')
+    return False
+
+
 class FlashMLAImpl(TritonAttentionImpl):
 
     def __init__(
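The new `use_fa3_warning` helper relies on `functools.lru_cache`: the function takes no arguments, so the cache holds a single entry and the body (and therefore the warning) runs at most once per process, no matter how many attention modules are built. A minimal standalone sketch of the same pattern, with a hypothetical `feature_available` flag standing in for the module-level `use_fa3`:

import functools
import logging

logger = logging.getLogger('example')
feature_available = False  # stand-in for the module-level `use_fa3` flag


@functools.lru_cache
def warn_if_feature_missing() -> bool:
    """Return True if the feature is usable; warn (at most once) otherwise."""
    if feature_available:
        return True
    logger.warning('optional feature is not installed')
    return False


# The first call executes the body and logs; later calls hit the cache and
# return the memoized False without logging again.
warn_if_feature_missing()
warn_if_feature_missing()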
@@ -255,6 +265,7 @@ def __init__(
         from lmdeploy.pytorch.kernels.cuda import flash_mla_fwd
         self.flash_mla_fwd = flash_mla_fwd
         assert num_kv_heads == 1, 'MLA requires num kv heads equal to 1'
+        use_fa3_warning()
 
     def forward(
         self,
@@ -515,6 +526,14 @@ def forward(
         return attn_output
 
 
+@functools.lru_cache
+def _enable_fa3(alibi: bool, learnable_sink: bool, block_sparse_size: int):
+    enable = not alibi and not learnable_sink and block_sparse_size == 1
+    if enable and not use_fa3_warning():
+        enable = False
+    return enable
+
+
 class TritonAttentionBuilder(AttentionBuilder[TritonAttentionMetadata]):
     """Triton attention builder."""
 
@@ -535,8 +554,9 @@ def build(
         **kwargs,
     ) -> TritonAttentionImpl:
         """build."""
-        enable_fa3 = use_fa3 and not alibi and not learnable_sink and block_sparse_size == 1
+        enable_fa3 = _enable_fa3(alibi, learnable_sink, block_sparse_size)
         if use_flash_mla is True:
+            logger.debug('Build FlashMLAImpl Attention')
             return FlashMLAImpl(num_heads,
                                 head_size,
                                 scale=scale,
@@ -548,6 +568,7 @@
                                 causal=causal,
                                 **kwargs)
         elif enable_fa3:
+            logger.debug('Build FA3Impl Attention')
             return FA3Impl(num_heads,
                            head_size,
                            scale=scale,
@@ -559,6 +580,7 @@
                            causal=causal,
                            **kwargs)
         else:
+            logger.debug('Build TritonAttentionImpl Attention')
             return TritonAttentionImpl(num_heads,
                                        head_size,
                                        scale=scale,
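After this change, the builder's selection order is: FlashMLA when `use_flash_mla` is True, FA3 when `_enable_fa3(...)` allows it, otherwise the Triton implementation, and each branch now announces itself at DEBUG level. `_enable_fa3` is cached per `(alibi, learnable_sink, block_sparse_size)` combination and returns False (after at most one warning) when FlashAttention-3 is absent. A hedged sketch of how to surface those new messages, assuming lmdeploy's loggers live in the standard `logging` hierarchy under the 'lmdeploy' name:

import logging

# Assumption: lmdeploy uses stdlib logging under the 'lmdeploy' prefix.
# Raising that logger to DEBUG should surface the new
# 'Build ...Impl Attention' messages; if lmdeploy has attached its own
# handlers, their levels may need lowering as well.
logging.basicConfig()
logging.getLogger('lmdeploy').setLevel(logging.DEBUG)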

lmdeploy/pytorch/check_env/transformers.py

Lines changed: 1 addition & 1 deletion
@@ -4,7 +4,7 @@
 from .base import BaseChecker
 
 MIN_TRANSFORMERS_VERSION = '4.33.0'
-MAX_TRANSFORMERS_VERSION = '4.53.3'
+MAX_TRANSFORMERS_VERSION = '4.56.1'
 
 
 class TransformersChecker(BaseChecker):
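Only the version bounds change here; the checker body is not part of this commit. For context, a minimal sketch of how a min/max range check like this is typically written with `packaging.version` (illustrative names, not the actual `TransformersChecker` implementation):

from packaging import version

MIN_TRANSFORMERS_VERSION = '4.33.0'
MAX_TRANSFORMERS_VERSION = '4.56.1'


def check_transformers_version() -> None:
    """Warn when the installed transformers falls outside the tested range."""
    import transformers
    installed = version.parse(transformers.__version__)
    low = version.parse(MIN_TRANSFORMERS_VERSION)
    high = version.parse(MAX_TRANSFORMERS_VERSION)
    if not (low <= installed <= high):
        print(f'transformers {installed} is outside the tested range '
              f'[{MIN_TRANSFORMERS_VERSION}, {MAX_TRANSFORMERS_VERSION}]')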

lmdeploy/pytorch/config.py

Lines changed: 4 additions & 1 deletion
@@ -28,7 +28,10 @@ def _update_torch_dtype(config: 'ModelConfig', dtype: str):
         config.dtype = torch.float16
         return config
 
-    torch_dtype = getattr(config.hf_config, 'torch_dtype', None)
+    torch_dtype = getattr(config.hf_config, 'dtype', None)
+    if torch_dtype is None:
+        torch_dtype = getattr(config.hf_config, 'torch_dtype', None)
+
     # deal with case when torch_dtype is not string but torch.dtype
     if isinstance(torch_dtype, torch.dtype):
        torch_dtype = str(torch_dtype).split('.')[1]
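Recent transformers releases expose the model dtype as `config.dtype`, keeping `torch_dtype` only for backward compatibility, so the code now reads the new attribute first and falls back to the old one. A small self-contained sketch of the same fallback, using a stand-in namespace object rather than a real `hf_config`:

import types

import torch

# Stand-in for `config.hf_config`; newer transformers sets `dtype`,
# older releases only set `torch_dtype`.
hf_config = types.SimpleNamespace(torch_dtype=torch.bfloat16)

torch_dtype = getattr(hf_config, 'dtype', None)
if torch_dtype is None:
    torch_dtype = getattr(hf_config, 'torch_dtype', None)

# Normalize torch.dtype objects to their string name, e.g. 'bfloat16'.
if isinstance(torch_dtype, torch.dtype):
    torch_dtype = str(torch_dtype).split('.')[1]

print(torch_dtype)  # -> 'bfloat16'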

lmdeploy/pytorch/disagg/backend/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -7,7 +7,7 @@
     logger.debug('Registering DLSlime Backend')
     from .dlslime import DLSlimeBackend
 except ImportError:
-    logger.warning('Disable DLSlime Backend')
+    logger.debug('Disable DLSlime Backend')
 
 try:
     logger.debug('Registering Mooncake Backend')
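A missing transfer backend is an expected condition rather than an error, so the log level drops from WARNING to DEBUG. A hypothetical standalone sketch of the same optional-import pattern (the helper name and module list are illustrative, not lmdeploy API):

import importlib
import logging

logger = logging.getLogger('example')


def register_optional_backends(module_names):
    """Import each backend module if present; quietly skip missing ones."""
    available = []
    for name in module_names:
        try:
            importlib.import_module(name)
            logger.debug('Registered %s backend', name)
            available.append(name)
        except ImportError:
            logger.debug('Disable %s backend', name)
    return available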

lmdeploy/tokenizer.py

Lines changed: 5 additions & 2 deletions
@@ -423,8 +423,11 @@ class Tokenizer:
     """
 
     def __init__(self, model_path: str):
-        from transformers import PretrainedConfig
-        model_cfg = PretrainedConfig.from_pretrained(model_path, trust_remote_code=True)
+        from transformers import AutoConfig, PretrainedConfig
+        try:
+            model_cfg = AutoConfig.from_pretrained(model_path, trust_remote_code=True)
+        except BaseException:
+            model_cfg = PretrainedConfig.from_pretrained(model_path, trust_remote_code=True)
         is_gpt_oss = getattr(model_cfg, 'model_type', '') == 'gpt_oss'
         from transformers.models.auto.tokenization_auto import get_tokenizer_config
         tokenizer_config = get_tokenizer_config(model_path, trust_remote_code=True)
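`AutoConfig.from_pretrained` resolves the concrete config class (and, with `trust_remote_code=True`, custom remote configs), which bare `PretrainedConfig.from_pretrained` does not, so attributes like `model_type` are more likely to be populated; the broad `except BaseException` keeps the old path as a fallback. A minimal sketch of the same load-with-fallback (the model id below is only illustrative):

from transformers import AutoConfig, PretrainedConfig


def load_model_config(model_path: str):
    """Prefer AutoConfig; fall back to the generic PretrainedConfig loader."""
    try:
        return AutoConfig.from_pretrained(model_path, trust_remote_code=True)
    except BaseException:
        return PretrainedConfig.from_pretrained(model_path, trust_remote_code=True)


cfg = load_model_config('openai/gpt-oss-20b')  # illustrative model id
is_gpt_oss = getattr(cfg, 'model_type', '') == 'gpt_oss'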
