@@ -1,277 +1,14 @@
 # Copyright (c) OpenMMLab. All rights reserved.
-from logging import Logger
-from typing import List
-
-from lmdeploy.utils import get_logger
-
-
-def _handle_exception(e: Exception,
-                      mod_name: str,
-                      logger: Logger,
-                      message: str = None):
-    red_color = '\033[31m'
-    reset_color = '\033[0m'
-    if message is None:
-        message = 'Please ensure it has been installed correctly.'
-    logger.debug('Exception', exc_info=1)
-    logger.error(f'{type(e).__name__}: {e}')
-    logger.error(f'{red_color}'
-                 f'<{mod_name}> test failed!\n'
-                 f'{message}'
-                 f'{reset_color}')
-    exit(1)
+from .base import BaseChecker  # noqa: F401
 
 
 def check_env_deeplink(device_type: str):
     """check Deeplink environment."""
-    try_import_deeplink(device_type)
+    from .deeplink import DeeplinkChecker
+    checker = DeeplinkChecker(device_type)
+    checker.handle()
 
 
 def try_import_deeplink(device_type: str):
-    """import dlinfer if specific device_type is set."""
-    deeplink_device_type_list = [
-        'ascend',
-        'npu',
-        'maca',
-    ]
-    if device_type in deeplink_device_type_list:
-        logger = get_logger('lmdeploy')
-        try:
-            import dlinfer.framework.lmdeploy_ext  # noqa: F401
-        except Exception as e:
-            _handle_exception(e, 'PyTorch', logger)
-
-
-def check_env_torch():
-    """check PyTorch environment."""
-    logger = get_logger('lmdeploy')
-
-    try:
-        logger.debug('Checking <PyTorch> environment.')
-        import torch
-
-        a = torch.tensor([1, 2], device='cuda')
-        b = a.new_tensor([3, 4], device='cuda')
-        c = a + b
-        torch.testing.assert_close(c, a.new_tensor([4, 6]))
-    except Exception as e:
-        _handle_exception(e, 'PyTorch', logger)
-
-
-MAX_TRITON_VERSION = '3.0.0'
-
-
-def check_env_triton(device: str):
-    """check OpenAI Triton environment."""
-    from packaging import version
-    logger = get_logger('lmdeploy')
-
-    msg = (
-        'Please ensure that your device is functioning properly with <Triton>.\n'  # noqa: E501
-        'You can verify your environment by running '
-        '`python -m lmdeploy.pytorch.check_env.triton_custom_add`.')
-    try:
-        logger.debug('Checking <Triton> environment.')
-        import torch
-        import triton
-        triton_version = version.parse(triton.__version__)
-        if triton_version > version.parse(MAX_TRITON_VERSION):
-            logger.warning(
-                f'Engine has not been tested on triton>{MAX_TRITON_VERSION}.')
-
-        from .triton_custom_add import custom_add
-        a = torch.tensor([1, 2], device='cuda')
-        b = a.new_tensor([3, 4], device='cuda')
-        c = custom_add(a, b)
-        torch.testing.assert_close(c, a + b)
-    except RuntimeError as e:
-        ptxas_error = 'device kernel image is invalid'
-        if len(e.args) > 0 and ptxas_error in e.args[0]:
-            msg = (
-                'This Error might caused by mismatching between NVIDIA Driver and nvcc compiler. \n'  # noqa: E501
-                'Try solution https://github.com/triton-lang/triton/issues/1955#issuecomment-1929908209'  # noqa: E501
-                ' or reinstall the driver.')
-        _handle_exception(e, 'Triton', logger, msg)
-    except Exception as e:
-        _handle_exception(e, 'Triton', logger, msg)
-
-    if device == 'cuda':
-        device_cap = torch.cuda.get_device_capability()
-        TRITON_VER_231 = version.parse('2.3.1')
-
-        if device_cap[0] <= 7:
-            if triton_version <= TRITON_VER_231:
-                err = RuntimeError(
-                    'Attention triton kernel does not fully support '
-                    'triton<3.0.0 on device with capability<8. '
-                    'Please upgrade your triton version.')
-                _handle_exception(err, 'Triton', logger)
-
-
-def check_env(device_type: str):
-    """check all environment."""
-    logger = get_logger('lmdeploy')
-    logger.info('Checking environment for PyTorch Engine.')
+    """check Deeplink environment."""
     check_env_deeplink(device_type)
-    check_env_torch()
-    if device_type == 'cuda':
-        check_env_triton('cuda')
-
-
-MIN_TRANSFORMERS_VERSION = '4.33.0'
-MAX_TRANSFORMERS_VERSION = '4.44.1'
-
-
-def check_awq(hf_config, device_type):
-    """check awq support."""
-    logger = get_logger('lmdeploy')
-    if device_type == 'cuda':
-        quantization_config = getattr(hf_config, 'quantization_config', dict())
-        quant_method = quantization_config.get('quant_method', None)
-        if quant_method != 'awq':
-            return
-        try:
-            import awq  # noqa
-        except Exception as e:
-            _handle_exception(e, 'autoawq', logger)
-
-        try:
-            import awq_ext  # noqa
-        except Exception:
-            logger.debug('Exception:', exc_info=1)
-            logger.warning('Failed to import `awq_ext`. '
-                           'Try reinstall it from source: '
-                           'https://github.com/casper-hansen/AutoAWQ_kernels')
-
-
-def check_transformers_version(model_path: str,
-                               trust_remote_code: bool = True,
-                               dtype: str = 'auto',
-                               device_type: str = 'cuda'):
-    """check transformers version."""
-    from packaging import version
-    logger = get_logger('lmdeploy')
-
-    def __check_transformers_version():
-        """check transformers version."""
-        logger.debug('Checking <transformers> version.')
-        trans_version = None
-        try:
-            import transformers
-            trans_version = version.parse(transformers.__version__)
-            min_version = version.parse(MIN_TRANSFORMERS_VERSION)
-            max_version = version.parse(MAX_TRANSFORMERS_VERSION)
-            if trans_version < min_version or trans_version > max_version:
-                logger.warning('LMDeploy requires transformers version: '
-                               f'[{MIN_TRANSFORMERS_VERSION} ~ '
-                               f'{MAX_TRANSFORMERS_VERSION}], '
-                               'but found version: '
-                               f'{transformers.__version__}')
-        except Exception as e:
-            _handle_exception(e, 'transformers', logger)
-        return transformers, trans_version
-
-    def __check_config(trans_version):
-        """check config."""
-        logger.debug('Checking <Model> AutoConfig.from_pretrained.')
-        try:
-            from transformers import AutoConfig
-            config = AutoConfig.from_pretrained(
-                model_path, trust_remote_code=trust_remote_code)
-        except Exception as e:
-            message = (
-                f'Load model config with transformers=={trans_version}'
-                ' failed. '
-                'Please make sure model can be loaded with transformers API.')
-            _handle_exception(e, 'transformers', logger, message=message)
-        return config
-
-    def __check_model_transformers_version(config, trans_version):
-        """check model transformers version."""
-        logger.debug('Checking <Model> required transformers version.')
-        try:
-            model_trans_version = getattr(config, 'transformers_version', None)
-            if model_trans_version is not None:
-                model_trans_version = version.parse(model_trans_version)
-                assert trans_version >= model_trans_version, \
-                    'Version mismatch.'
-        except Exception as e:
-            message = (f'model `{model_path}` requires '
-                       f'transformers version {model_trans_version} '
-                       f'but transformers {trans_version} is installed.')
-            _handle_exception(e, 'transformers', logger, message=message)
-
-    def __check_model_dtype_support(config, device_type):
-        """Checking model dtype support."""
-        logger.debug('Checking <Model> dtype support.')
-
-        import torch
-
-        from lmdeploy.pytorch.config import ModelConfig
-        from lmdeploy.utils import is_bf16_supported
-
-        try:
-            model_config = ModelConfig.from_hf_config(config,
-                                                      model_path=model_path,
-                                                      dtype=dtype)
-            if model_config.dtype == torch.bfloat16:
-                assert is_bf16_supported(device_type), (
-                    'bf16 is not supported on your device')
-        except AssertionError as e:
-            message = (
-                f'Your device does not support `{model_config.dtype}`. '
-                'You can set `dtype` to float16 in PyTorchEngineConfig or '
-                '`--dtype float16` to api_server.\n'
-                'Note that this might have negative effect!')
-            _handle_exception(e, 'Model', logger, message=message)
-        except Exception as e:
-            message = (f'Checking failed with error {e}',
-                       'Please send issue to LMDeploy with error logs.')
-            _handle_exception(e, 'Model', logger, message=message)
-
-        return model_config
-
-    _, trans_version = __check_transformers_version()
-    config = __check_config(trans_version)
-    __check_model_transformers_version(config, trans_version)
-    __check_model_dtype_support(config, device_type)
-    check_awq(config, device_type)
-
-
-def check_model(model_path: str,
-                trust_remote_code: bool = True,
-                dtype: str = 'auto',
-                device_type: str = 'cuda'):
-    """check model requirements."""
-    logger = get_logger('lmdeploy')
-    logger.info('Checking model.')
-    check_transformers_version(model_path, trust_remote_code, dtype,
-                               device_type)
-
-
-def check_adapter(path: str):
-    """check adapter."""
-    logger = get_logger('lmdeploy')
-    logger.debug(f'Checking <Adapter>: {path}.')
-
-    try:
-        from peft import PeftConfig
-        PeftConfig.from_pretrained(path)
-    except Exception as e:
-        message = ('Please make sure the adapter can be loaded with '
-                   '`peft.PeftConfig.from_pretrained`\n')
-        err_msg = '' if len(e.args) == 0 else e.args[0]
-        if 'got an unexpected keyword argument' in err_msg:
-            message += ('Or try remove all unexpected keywords '
-                        'in `adapter_config.json`.')
-        _handle_exception(e, 'Model', logger, message=message)
-
-
-def check_adapters(adapter_paths: List[str]):
-    """check adapters."""
-    if len(adapter_paths) <= 0:
-        return
-    logger = get_logger('lmdeploy')
-    logger.info('Checking adapters.')
-    for path in adapter_paths:
-        check_adapter(path)
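For readers of the diff above: the removed checks are not simply deleted; the new imports indicate the logic now lives in sibling modules (`base.py`, `deeplink.py`) as checker classes, which this module re-exports. Those files are not part of this hunk, so the following is only a minimal sketch of the interface implied by `DeeplinkChecker(device_type).handle()`, reconstructed from the removed `_handle_exception` and `try_import_deeplink` bodies. Everything except the names `BaseChecker`, `DeeplinkChecker`, and `handle` is an illustrative assumption, not the actual upstream code.

# Hypothetical sketch only -- base.py and deeplink.py are not shown
# in this diff; the shape below is inferred from how they are called.
from logging import Logger

from lmdeploy.utils import get_logger


class BaseChecker:
    """Assumed base class: run a check and exit loudly on failure."""

    def __init__(self, logger: Logger = None):
        self.logger = logger or get_logger('lmdeploy')

    def handle(self):
        """Run check(); on failure, log the error and exit, mirroring
        the removed _handle_exception() behavior."""
        try:
            self.check()
        except Exception as e:
            self.logger.debug('Exception', exc_info=1)
            self.logger.error(f'{type(e).__name__}: {e}')
            exit(1)

    def check(self):
        """Override in subclasses to perform the actual check."""
        raise NotImplementedError


class DeeplinkChecker(BaseChecker):
    """Import dlinfer for Deeplink device types, as the removed
    try_import_deeplink() body did."""

    def __init__(self, device_type: str, logger: Logger = None):
        super().__init__(logger)
        self.device_type = device_type

    def check(self):
        # Device types copied from the removed deeplink_device_type_list.
        if self.device_type in ('ascend', 'npu', 'maca'):
            import dlinfer.framework.lmdeploy_ext  # noqa: F401

Under this shape, `check_env_deeplink('ascend')` reduces to instantiating a checker and calling `handle()`, so each backend owns its own failure handling instead of routing everything through the removed module-level `_handle_exception` helper.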