Skip to content

Commit b99a5da

Browse files
authored
refactor PyTorchEngine check env (#2870)
* refactor checker
* config builder
* fix
* fix
* update triton
* remove dockerfile update
* update torch version
1 parent af7157a commit b99a5da

File tree

11 files changed

+474
-303
lines changed

11 files changed

+474
-303
lines changed
Lines changed: 5 additions & 268 deletions
Original file line numberDiff line numberDiff line change
@@ -1,277 +1,14 @@
11
# Copyright (c) OpenMMLab. All rights reserved.
2-
from logging import Logger
3-
from typing import List
4-
5-
from lmdeploy.utils import get_logger
6-
7-
8-
def _handle_exception(e: Exception,
9-
mod_name: str,
10-
logger: Logger,
11-
message: str = None):
12-
red_color = '\033[31m'
13-
reset_color = '\033[0m'
14-
if message is None:
15-
message = 'Please ensure it has been installed correctly.'
16-
logger.debug('Exception', exc_info=1)
17-
logger.error(f'{type(e).__name__}: {e}')
18-
logger.error(f'{red_color}'
19-
f'<{mod_name}> test failed!\n'
20-
f'{message}'
21-
f'{reset_color}')
22-
exit(1)
2+
from .base import BaseChecker # noqa: F401
233

244

255
def check_env_deeplink(device_type: str):
    """Validate the Deeplink environment for the given device type."""
    # Import lazily so the deeplink dependency is only touched when needed.
    from .deeplink import DeeplinkChecker
    DeeplinkChecker(device_type).handle()
2810

2911

3012
def try_import_deeplink(device_type: str):
31-
"""import dlinfer if specific device_type is set."""
32-
deeplink_device_type_list = [
33-
'ascend',
34-
'npu',
35-
'maca',
36-
]
37-
if device_type in deeplink_device_type_list:
38-
logger = get_logger('lmdeploy')
39-
try:
40-
import dlinfer.framework.lmdeploy_ext # noqa: F401
41-
except Exception as e:
42-
_handle_exception(e, 'PyTorch', logger)
43-
44-
45-
def check_env_torch():
46-
"""check PyTorch environment."""
47-
logger = get_logger('lmdeploy')
48-
49-
try:
50-
logger.debug('Checking <PyTorch> environment.')
51-
import torch
52-
53-
a = torch.tensor([1, 2], device='cuda')
54-
b = a.new_tensor([3, 4], device='cuda')
55-
c = a + b
56-
torch.testing.assert_close(c, a.new_tensor([4, 6]))
57-
except Exception as e:
58-
_handle_exception(e, 'PyTorch', logger)
59-
60-
61-
MAX_TRITON_VERSION = '3.0.0'
62-
63-
64-
def check_env_triton(device: str):
65-
"""check OpenAI Triton environment."""
66-
from packaging import version
67-
logger = get_logger('lmdeploy')
68-
69-
msg = (
70-
'Please ensure that your device is functioning properly with <Triton>.\n' # noqa: E501
71-
'You can verify your environment by running '
72-
'`python -m lmdeploy.pytorch.check_env.triton_custom_add`.')
73-
try:
74-
logger.debug('Checking <Triton> environment.')
75-
import torch
76-
import triton
77-
triton_version = version.parse(triton.__version__)
78-
if triton_version > version.parse(MAX_TRITON_VERSION):
79-
logger.warning(
80-
f'Engine has not been tested on triton>{MAX_TRITON_VERSION}.')
81-
82-
from .triton_custom_add import custom_add
83-
a = torch.tensor([1, 2], device='cuda')
84-
b = a.new_tensor([3, 4], device='cuda')
85-
c = custom_add(a, b)
86-
torch.testing.assert_close(c, a + b)
87-
except RuntimeError as e:
88-
ptxas_error = 'device kernel image is invalid'
89-
if len(e.args) > 0 and ptxas_error in e.args[0]:
90-
msg = (
91-
'This Error might caused by mismatching between NVIDIA Driver and nvcc compiler. \n' # noqa: E501
92-
'Try solution https://github.com/triton-lang/triton/issues/1955#issuecomment-1929908209' # noqa: E501
93-
' or reinstall the driver.')
94-
_handle_exception(e, 'Triton', logger, msg)
95-
except Exception as e:
96-
_handle_exception(e, 'Triton', logger, msg)
97-
98-
if device == 'cuda':
99-
device_cap = torch.cuda.get_device_capability()
100-
TRITON_VER_231 = version.parse('2.3.1')
101-
102-
if device_cap[0] <= 7:
103-
if triton_version <= TRITON_VER_231:
104-
err = RuntimeError(
105-
'Attention triton kernel does not fully support '
106-
'triton<3.0.0 on device with capability<8. '
107-
'Please upgrade your triton version.')
108-
_handle_exception(err, 'Triton', logger)
109-
110-
111-
def check_env(device_type: str):
112-
"""check all environment."""
113-
logger = get_logger('lmdeploy')
114-
logger.info('Checking environment for PyTorch Engine.')
13+
"""check Deeplink environment."""
11514
check_env_deeplink(device_type)
116-
check_env_torch()
117-
if device_type == 'cuda':
118-
check_env_triton('cuda')
119-
120-
121-
MIN_TRANSFORMERS_VERSION = '4.33.0'
122-
MAX_TRANSFORMERS_VERSION = '4.44.1'
123-
124-
125-
def check_awq(hf_config, device_type):
126-
"""check awq support."""
127-
logger = get_logger('lmdeploy')
128-
if device_type == 'cuda':
129-
quantization_config = getattr(hf_config, 'quantization_config', dict())
130-
quant_method = quantization_config.get('quant_method', None)
131-
if quant_method != 'awq':
132-
return
133-
try:
134-
import awq # noqa
135-
except Exception as e:
136-
_handle_exception(e, 'autoawq', logger)
137-
138-
try:
139-
import awq_ext # noqa
140-
except Exception:
141-
logger.debug('Exception:', exc_info=1)
142-
logger.warning('Failed to import `awq_ext`. '
143-
'Try reinstall it from source: '
144-
'https://github.com/casper-hansen/AutoAWQ_kernels')
145-
146-
147-
def check_transformers_version(model_path: str,
148-
trust_remote_code: bool = True,
149-
dtype: str = 'auto',
150-
device_type: str = 'cuda'):
151-
"""check transformers version."""
152-
from packaging import version
153-
logger = get_logger('lmdeploy')
154-
155-
def __check_transformers_version():
156-
"""check transformers version."""
157-
logger.debug('Checking <transformers> version.')
158-
trans_version = None
159-
try:
160-
import transformers
161-
trans_version = version.parse(transformers.__version__)
162-
min_version = version.parse(MIN_TRANSFORMERS_VERSION)
163-
max_version = version.parse(MAX_TRANSFORMERS_VERSION)
164-
if trans_version < min_version or trans_version > max_version:
165-
logger.warning('LMDeploy requires transformers version: '
166-
f'[{MIN_TRANSFORMERS_VERSION} ~ '
167-
f'{MAX_TRANSFORMERS_VERSION}], '
168-
'but found version: '
169-
f'{transformers.__version__}')
170-
except Exception as e:
171-
_handle_exception(e, 'transformers', logger)
172-
return transformers, trans_version
173-
174-
def __check_config(trans_version):
175-
"""check config."""
176-
logger.debug('Checking <Model> AutoConfig.from_pretrained.')
177-
try:
178-
from transformers import AutoConfig
179-
config = AutoConfig.from_pretrained(
180-
model_path, trust_remote_code=trust_remote_code)
181-
except Exception as e:
182-
message = (
183-
f'Load model config with transformers=={trans_version}'
184-
' failed. '
185-
'Please make sure model can be loaded with transformers API.')
186-
_handle_exception(e, 'transformers', logger, message=message)
187-
return config
188-
189-
def __check_model_transformers_version(config, trans_version):
190-
"""check model transformers version."""
191-
logger.debug('Checking <Model> required transformers version.')
192-
try:
193-
model_trans_version = getattr(config, 'transformers_version', None)
194-
if model_trans_version is not None:
195-
model_trans_version = version.parse(model_trans_version)
196-
assert trans_version >= model_trans_version, \
197-
'Version mismatch.'
198-
except Exception as e:
199-
message = (f'model `{model_path}` requires '
200-
f'transformers version {model_trans_version} '
201-
f'but transformers {trans_version} is installed.')
202-
_handle_exception(e, 'transformers', logger, message=message)
203-
204-
def __check_model_dtype_support(config, device_type):
205-
"""Checking model dtype support."""
206-
logger.debug('Checking <Model> dtype support.')
207-
208-
import torch
209-
210-
from lmdeploy.pytorch.config import ModelConfig
211-
from lmdeploy.utils import is_bf16_supported
212-
213-
try:
214-
model_config = ModelConfig.from_hf_config(config,
215-
model_path=model_path,
216-
dtype=dtype)
217-
if model_config.dtype == torch.bfloat16:
218-
assert is_bf16_supported(device_type), (
219-
'bf16 is not supported on your device')
220-
except AssertionError as e:
221-
message = (
222-
f'Your device does not support `{model_config.dtype}`. '
223-
'You can set `dtype` to float16 in PyTorchEngineConfig or '
224-
'`--dtype float16` to api_server.\n'
225-
'Note that this might have negative effect!')
226-
_handle_exception(e, 'Model', logger, message=message)
227-
except Exception as e:
228-
message = (f'Checking failed with error {e}',
229-
'Please send issue to LMDeploy with error logs.')
230-
_handle_exception(e, 'Model', logger, message=message)
231-
232-
return model_config
233-
234-
_, trans_version = __check_transformers_version()
235-
config = __check_config(trans_version)
236-
__check_model_transformers_version(config, trans_version)
237-
__check_model_dtype_support(config, device_type)
238-
check_awq(config, device_type)
239-
240-
241-
def check_model(model_path: str,
242-
trust_remote_code: bool = True,
243-
dtype: str = 'auto',
244-
device_type: str = 'cuda'):
245-
"""check model requirements."""
246-
logger = get_logger('lmdeploy')
247-
logger.info('Checking model.')
248-
check_transformers_version(model_path, trust_remote_code, dtype,
249-
device_type)
250-
251-
252-
def check_adapter(path: str):
253-
"""check adapter."""
254-
logger = get_logger('lmdeploy')
255-
logger.debug(f'Checking <Adapter>: {path}.')
256-
257-
try:
258-
from peft import PeftConfig
259-
PeftConfig.from_pretrained(path)
260-
except Exception as e:
261-
message = ('Please make sure the adapter can be loaded with '
262-
'`peft.PeftConfig.from_pretrained`\n')
263-
err_msg = '' if len(e.args) == 0 else e.args[0]
264-
if 'got an unexpected keyword argument' in err_msg:
265-
message += ('Or try remove all unexpected keywords '
266-
'in `adapter_config.json`.')
267-
_handle_exception(e, 'Model', logger, message=message)
268-
269-
270-
def check_adapters(adapter_paths: List[str]):
271-
"""check adapters."""
272-
if len(adapter_paths) <= 0:
273-
return
274-
logger = get_logger('lmdeploy')
275-
logger.info('Checking adapters.')
276-
for path in adapter_paths:
277-
check_adapter(path)
Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
# Copyright (c) OpenMMLab. All rights reserved.
2+
from .base import BaseChecker
3+
4+
5+
class AdapterChecker(BaseChecker):
    """Checker verifying that a peft adapter is available and loadable."""

    def __init__(self, adapter_path: str, logger=None):
        super().__init__(logger)
        self.adapter_path = adapter_path

    def check(self):
        """Run the adapter availability check."""
        adapter = self.adapter_path

        # peft must be importable before the adapter config can be inspected.
        try:
            import peft  # noqa: F401
        except Exception as e:
            self.log_and_exit(e, 'Adapter', message='Failed to import peft.')

        # Loading the PeftConfig proves the adapter files are well-formed.
        try:
            from peft import PeftConfig
            PeftConfig.from_pretrained(adapter)
        except Exception as e:
            message = ('Please make sure the adapter can be loaded with '
                       '`peft.PeftConfig.from_pretrained`\n')
            err_msg = e.args[0] if e.args else ''
            if 'got an unexpected keyword argument' in err_msg:
                message += ('Or try remove all unexpected keywords '
                            'in `adapter_config.json`.')
            self.log_and_exit(e, 'Adapter', message=message)

lmdeploy/pytorch/check_env/base.py

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,62 @@
1+
# Copyright (c) OpenMMLab. All rights reserved.
2+
from logging import Logger
3+
from typing import List
4+
5+
from lmdeploy.utils import get_logger
6+
7+
RED_COLOR = '\033[31m'
8+
RESET_COLOR = '\033[0m'
9+
10+
11+
def _red_text(text: str):
12+
"""red text."""
13+
return f'{RED_COLOR}{text}{RESET_COLOR}'
14+
15+
16+
class BaseChecker:
    """Base class for environment checkers.

    A concrete checker implements :meth:`check` to validate one aspect of
    the environment. Checkers may depend on other checkers, registered via
    :meth:`register_required_checker`; :meth:`handle` runs dependencies
    first, then this checker, at most once per instance.

    Args:
        logger (Logger): logger used for reporting. Defaults to the
            lmdeploy logger when ``None``.
    """

    def __init__(self, logger: Logger = None):
        if logger is None:
            logger = get_logger('lmdeploy')
        self.logger = logger
        # Set to True after a successful handle() so repeated calls are no-ops.
        self._is_passed = False
        self._required_checker: List[BaseChecker] = list()

    def get_logger(self):
        """Return the logger used by this checker."""
        return self.logger

    def register_required_checker(self, checker: 'BaseChecker'):
        """Register a checker that must pass before this one runs."""
        self._required_checker.append(checker)

    def handle(self):
        """Run required checkers, then this checker, at most once."""
        is_passed = getattr(self, '_is_passed', False)
        if not is_passed:
            checker_name = type(self).__name__
            self.logger.debug(f'Checking <{checker_name}>:')
            for checker in self._required_checker:
                checker.handle()
            self.check()
            # BUGFIX: the original assigned `self.is_passed`, so the
            # `_is_passed` guard above never flipped and every handle()
            # call re-ran the whole check chain.
            self._is_passed = True

    def log_and_exit(self,
                     e: Exception = None,
                     mod_name: str = None,
                     message: str = None):
        """Log the failure (and exception, if given) and exit the process.

        Args:
            e (Exception): the triggering exception, if any.
            mod_name (str): module name shown in the error; defaults to the
                checker class name.
            message (str): remediation hint shown in red.
        """
        logger = self.logger
        if mod_name is None:
            mod_name = type(self).__name__
        if message is None:
            message = 'Please check your environment.'
        logger.debug('Exception', exc_info=1)
        if e is not None:
            logger.error(f'{type(e).__name__}: {e}')
        logger.error(f'<{mod_name}> check failed!\n{_red_text(message)}')
        exit(1)

    def check(self):
        """Perform the actual check; subclasses must override."""
        raise NotImplementedError('check not implemented.')

0 commit comments

Comments
 (0)