|
4 | 4 |
|
5 | 5 | from setuptools import find_packages, setup |
6 | 6 |
|
7 | | -npu_available = False |
8 | | -try: |
9 | | - import torch_npu |
10 | | - |
11 | | - npu_available = torch_npu.npu.is_available() |
12 | | -except ImportError: |
13 | | - pass |
14 | | - |
15 | 7 | pwd = os.path.dirname(__file__) |
16 | 8 | version_file = 'lmdeploy/version.py' |
17 | 9 |
|
18 | 10 |
|
| 11 | +def get_target_device(): |
| 12 | + return os.getenv('LMDEPLOY_TARGET_DEVICE', 'cuda') |
| 13 | + |
| 14 | + |
19 | 15 | def readme(): |
20 | 16 | with open(os.path.join(pwd, 'README.md'), encoding='utf-8') as f: |
21 | 17 | content = f.read() |
@@ -154,16 +150,12 @@ def gen_packages_items(): |
154 | 150 | setup_requires=parse_requirements('requirements/build.txt'), |
155 | 151 | tests_require=parse_requirements('requirements/test.txt'), |
156 | 152 | install_requires=parse_requirements( |
157 | | - 'requirements/runtime_ascend.txt' |
158 | | - if npu_available else 'requirements/runtime.txt'), |
| 153 | + f'requirements/runtime_{get_target_device()}.txt'), |
159 | 154 | extras_require={ |
160 | 155 | 'all': |
161 | | - parse_requirements('requirements_ascend.txt' |
162 | | - if npu_available else 'requirements.txt'), |
163 | | - 'lite': |
164 | | - parse_requirements('requirements/lite.txt'), |
165 | | - 'serve': |
166 | | - parse_requirements('requirements/serve.txt') |
| 156 | + parse_requirements(f'requirements_{get_target_device()}.txt'), |
| 157 | + 'lite': parse_requirements('requirements/lite.txt'), |
| 158 | + 'serve': parse_requirements('requirements/serve.txt') |
167 | 159 | }, |
168 | 160 | has_ext_modules=check_ext_modules, |
169 | 161 | classifiers=[ |
|
0 commit comments