
Commit ec7af93

fix: Extend add_tensor_model so that ModelDeploymentCard can be correctly picked up (#4169)
Signed-off-by: zhongdaor <[email protected]>
1 parent f30d76c commit ec7af93

File tree

3 files changed: +178 -2 lines changed
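
At a glance: `add_tensor_model` gains an optional `runtime_config` parameter. When it is supplied, the service saves a `ModelDeploymentCard` so the KServe gRPC ModelConfig endpoint can return the model's configuration, and `remove_tensor_model` now cleans that card up again. A minimal Python sketch of the new registration flow, assuming the host, port, model name, and checksum are illustrative and the engine stub mirrors the one in the new test file below:

```python
import asyncio

from dynamo.llm import KserveGrpcService, ModelRuntimeConfig, PythonAsyncEngine


class EchoTensorEngine:
    """Tiny tensor engine stub, modeled on the one in the new test file."""

    def generate(self, request, context=None):
        async def _generator():
            yield {"tensors": request.get("tensors", [])}

        return _generator()


async def serve(runtime):
    loop = asyncio.get_running_loop()
    engine = PythonAsyncEngine(EchoTensorEngine().generate, loop)
    service = KserveGrpcService(port=8787, host="127.0.0.1")

    # New in this commit: passing a ModelRuntimeConfig makes the service save
    # a ModelDeploymentCard, so get_model_config() requests can be answered.
    runtime_config = ModelRuntimeConfig()
    runtime_config.set_tensor_model_config(
        {
            "name": "my-tensor-model",
            "inputs": [{"name": "input_text", "data_type": "Bytes", "shape": [-1]}],
            "outputs": [{"name": "results", "data_type": "Bytes", "shape": [-1]}],
        }
    )
    service.add_tensor_model(
        "my-tensor-model", "dummy-mdcsum", engine, runtime_config=runtime_config
    )
    await service.run(runtime.child_token())
```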

lib/bindings/python/rust/kserve_grpc.rs

Lines changed: 28 additions & 2 deletions
```diff
@@ -3,9 +3,12 @@
 
 use std::sync::Arc;
 
+use dynamo_llm::{self as llm_rs};
+use llm_rs::model_card::ModelDeploymentCard as RsModelDeploymentCard;
+use llm_rs::model_type::{ModelInput, ModelType};
 use pyo3::prelude::*;
 
-use crate::{CancellationToken, engine::*, to_pyerr};
+use crate::{CancellationToken, engine::*, llm::local_model::ModelRuntimeConfig, to_pyerr};
 
 pub use dynamo_llm::grpc::service::kserve;
 
@@ -56,12 +59,28 @@ impl KserveGrpcService {
             .map_err(to_pyerr)
     }
 
+    #[pyo3(signature = (model, checksum, engine, runtime_config=None))]
     pub fn add_tensor_model(
         &self,
         model: String,
         checksum: String,
         engine: PythonAsyncEngine,
+        runtime_config: Option<ModelRuntimeConfig>,
     ) -> PyResult<()> {
+        // If runtime_config is provided, create and save a ModelDeploymentCard
+        // so the ModelConfig endpoint can return model configuration
+        if let Some(runtime_config) = runtime_config {
+            let mut card = RsModelDeploymentCard::with_name_only(&model);
+            card.model_type = ModelType::TensorBased;
+            card.model_input = ModelInput::Tensor;
+            card.runtime_config = runtime_config.inner;
+
+            self.inner
+                .model_manager()
+                .save_model_card(&model, card)
+                .map_err(to_pyerr)?;
+        }
+
         let engine = Arc::new(engine);
         self.inner
             .model_manager()
@@ -84,10 +103,17 @@ impl KserveGrpcService {
     }
 
     pub fn remove_tensor_model(&self, model: String) -> PyResult<()> {
+        // Remove the engine
         self.inner
             .model_manager()
             .remove_tensor_model(&model)
-            .map_err(to_pyerr)
+            .map_err(to_pyerr)?;
+
+        // Also remove the model card if it exists
+        // (It's ok if it doesn't exist since runtime_config is optional, we just ignore the None return)
+        let _ = self.inner.model_manager().remove_model_card(&model);
+
+        Ok(())
     }
 
     pub fn list_chat_completions_models(&self) -> PyResult<Vec<String>> {
```
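
The removal path mirrors registration: the engine is removed first, then the model card is dropped, and a missing card is treated as a non-error since `runtime_config` is optional. From Python, continuing the sketch above:

```python
# Unregistering now also removes the saved ModelDeploymentCard; a model that
# was registered without a runtime_config simply has no card to remove.
service.remove_tensor_model("my-tensor-model")
```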

lib/bindings/python/src/dynamo/_core.pyi

Lines changed: 1 addition & 0 deletions
```diff
@@ -894,6 +894,7 @@ class KserveGrpcService:
         model: str,
         checksum: str,
         engine: PythonAsyncEngine,
+        runtime_config: Optional[ModelRuntimeConfig],
     ) -> None:
         """
         Register a tensor-based model with the service.
```
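
Although the stub lists `runtime_config` as a plain parameter, the Rust binding declares `runtime_config=None` in its `#[pyo3(signature = ...)]`, so calls that omit it remain valid. For example:

```python
# Backwards compatible: runtime_config defaults to None on the Rust side,
# in which case no ModelDeploymentCard is saved and the ModelConfig
# endpoint returns NOT_FOUND for this model (see the second test below).
service.add_tensor_model("my-tensor-model", "dummy-mdcsum", engine)
```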
New test file: Lines changed: 149 additions & 0 deletions
All 149 lines are new:

```python
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0

import asyncio
import contextlib
from contextlib import asynccontextmanager
from typing import Any, AsyncIterator, Optional, Tuple

import pytest
import tritonclient.grpc.model_config_pb2 as mc
from tritonclient.utils import InferenceServerException

from dynamo.llm import KserveGrpcService, ModelRuntimeConfig, PythonAsyncEngine

pytestmark = pytest.mark.pre_merge


async def _fetch_model_config(
    client,
    model_name: str,
    retries: int = 30,
) -> Any:
    last_error: Optional[Exception] = None
    for _ in range(retries):
        try:
            return await asyncio.to_thread(client.get_model_config, model_name)
        except InferenceServerException as err:
            last_error = err
            await asyncio.sleep(0.1)
    raise AssertionError(
        f"Unable to fetch model config for '{model_name}': {last_error}"
    )


class EchoTensorEngine:
    """Minimal tensor engine stub for registering tensor models."""

    def __init__(self, model_name: str):
        self._model_name = model_name

    def generate(self, request, context=None):
        async def _generator():
            yield {
                "model": self._model_name,
                "tensors": request.get("tensors", []),
                "parameters": request.get("parameters", {}),
            }

        return _generator()


@pytest.fixture
def tensor_service(runtime):
    @asynccontextmanager
    async def _start(
        model_name: str,
        *,
        runtime_config: Optional[ModelRuntimeConfig] = None,
        checksum: str = "dummy-mdcsum",
    ) -> AsyncIterator[Tuple[str, int]]:
        host = "127.0.0.1"
        port = 8787
        loop = asyncio.get_running_loop()
        engine = PythonAsyncEngine(EchoTensorEngine(model_name).generate, loop)
        tensor_model_service = KserveGrpcService(port=port, host=host)

        tensor_model_service.add_tensor_model(
            model_name, checksum, engine, runtime_config=runtime_config
        )

        cancel_token = runtime.child_token()

        async def _serve():
            await tensor_model_service.run(cancel_token)

        server_task = asyncio.create_task(_serve())
        try:
            await asyncio.sleep(1)  # wait for the service to start
            yield host, port
        finally:
            cancel_token.cancel()
            with contextlib.suppress(asyncio.TimeoutError, asyncio.CancelledError):
                await asyncio.wait_for(server_task, timeout=5)

    return _start


@pytest.mark.asyncio
@pytest.mark.forked
async def test_model_config_uses_runtime_config(tensor_service):
    """Ensure tensor runtime_config is returned via the ModelConfig endpoint."""
    import tritonclient.grpc as grpcclient

    model_name = "tensor-config-model"
    tensor_config = {
        "name": model_name,
        "inputs": [
            {"name": "input_text", "data_type": "Bytes", "shape": [-1]},
            {"name": "control_flag", "data_type": "Bool", "shape": [1]},
        ],
        "outputs": [
            {"name": "results", "data_type": "Bytes", "shape": [-1]},
        ],
    }
    runtime_config = ModelRuntimeConfig()
    runtime_config.set_tensor_model_config(tensor_config)

    async with tensor_service(model_name, runtime_config=runtime_config) as (
        host,
        port,
    ):
        client = grpcclient.InferenceServerClient(url=f"{host}:{port}")
        try:
            response = await _fetch_model_config(client, model_name)
        finally:
            client.close()

    model_config = response.config
    assert model_config.name == model_name
    assert model_config.platform == "dynamo"
    assert model_config.backend == "dynamo"

    inputs = {spec.name: spec for spec in model_config.input}
    assert list(inputs["input_text"].dims) == [-1]
    assert inputs["input_text"].data_type == mc.TYPE_STRING
    assert list(inputs["control_flag"].dims) == [1]
    assert inputs["control_flag"].data_type == mc.TYPE_BOOL

    outputs = {spec.name: spec for spec in model_config.output}
    assert list(outputs["results"].dims) == [-1]
    assert outputs["results"].data_type == mc.TYPE_STRING


@pytest.mark.asyncio
@pytest.mark.forked
async def test_model_config_missing_runtime_config_errors(tensor_service):
    """ModelConfig should return NOT_FOUND when no tensor runtime_config is saved."""
    model_name = "tensor-config-missing"
    import tritonclient.grpc as grpcclient

    async with tensor_service(model_name, runtime_config=None) as (host, port):
        client = grpcclient.InferenceServerClient(url=f"{host}:{port}")
        try:
            with pytest.raises(InferenceServerException) as excinfo:
                await asyncio.to_thread(client.get_model_config, model_name)
        finally:
            client.close()

    assert "not found" in str(excinfo.value).lower()
```
