Skip to content

Commit b88df3f

Browse files
authored
fix pipeline init turbomind from workspace (#1126)
* fix pipeline init turbomind from workspace * stricter check
1 parent d28e98e commit b88df3f

File tree

2 files changed

+21
-1
lines changed

2 files changed

+21
-1
lines changed

lmdeploy/serve/async_engine.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -100,6 +100,11 @@ def _build_turbomind(
100100
**kwargs):
101101
"""Innter build method for turbomind backend."""
102102
self.model_name = model_name
103+
# model mayebe from workspace
104+
from lmdeploy.turbomind.utils import \
105+
get_model_name_from_workspace_model
106+
if self.model_name is None:
107+
self.model_name = get_model_name_from_workspace_model(model_path)
103108
# try fuzzy matching to get a model_name
104109
if self.model_name is None and (backend_config is None
105110
or backend_config.model_name == ''
@@ -113,7 +118,8 @@ def _build_turbomind(
113118
logger.warning(
114119
f'Best matched chat template name: {self.model_name}')
115120
elif self.model_name is not None and backend_config is not None:
116-
if self.model_name != backend_config.model_name:
121+
if backend_config.model_name is not None \
122+
and self.model_name != backend_config.model_name:
117123
raise ArgumentError(
118124
f'Got different model names from model_name = '
119125
f'{self.model_name}, backend_config = {backend_config}')

lmdeploy/turbomind/utils.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -117,6 +117,20 @@ def get_gen_param(cap,
117117
return gen_param
118118

119119

120+
def get_model_name_from_workspace_model(model_dir: str):
121+
"""Get model name from workspace model."""
122+
from configparser import ConfigParser
123+
triton_model_path = os.path.join(model_dir, 'triton_models', 'weights')
124+
if not os.path.exists(triton_model_path):
125+
return None
126+
ini_path = os.path.join(triton_model_path, 'config.ini')
127+
# load cfg
128+
with open(ini_path, 'r') as f:
129+
parser = ConfigParser()
130+
parser.read_file(f)
131+
return parser['llama']['model_name']
132+
133+
120134
def get_model_from_config(model_dir: str):
121135
import json
122136
config_file = os.path.join(model_dir, 'config.json')

0 commit comments

Comments
 (0)