Skip to content

Commit 71fdd1d

Browse files
committed
feat: allow users to add constraints and validators to ToolParam
1 parent 343929d commit 71fdd1d

File tree

10 files changed

+458
-49
lines changed

10 files changed

+458
-49
lines changed

protos/logical_plan/v1/tools.proto

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,24 @@ import "logical_plan/v1/datatypes.proto";
66
import "logical_plan/v1/complex_types.proto";
77
import "logical_plan/v1/plans.proto";
88

9+
message NumericConstraint {
10+
oneof kind {
11+
sint32 int_value = 1;
12+
float float_value = 2;
13+
}
14+
}
15+
16+
message ToolParameterConstraints {
17+
optional NumericConstraint gt = 1;
18+
optional NumericConstraint ge = 2;
19+
optional NumericConstraint lt = 3;
20+
optional NumericConstraint le = 4;
21+
optional NumericConstraint multiple_of = 5;
22+
23+
optional uint32 min_length = 6;
24+
optional uint32 max_length = 7;
25+
optional string pattern = 8;
26+
}
927

1028
message ToolParameter {
1129
string name = 1;
@@ -15,6 +33,8 @@ message ToolParameter {
1533
bool has_default = 5;
1634
optional ScalarValue default_value = 6;
1735
repeated ScalarValue allowed_values = 7;
36+
optional ToolParameterConstraints constraints = 8;
37+
repeated string validator_names = 9;
1838
}
1939

2040
message ToolDefinition {

src/fenic/_gen/protos/logical_plan/v1/tools_pb2.py

Lines changed: 9 additions & 5 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

src/fenic/_gen/protos/logical_plan/v1/tools_pb2.pyi

Lines changed: 34 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,23 +8,55 @@ from typing import ClassVar as _ClassVar, Iterable as _Iterable, Mapping as _Map
88

99
DESCRIPTOR: _descriptor.FileDescriptor
1010

11+
class NumericConstraint(_message.Message):
12+
__slots__ = ("int_value", "float_value")
13+
INT_VALUE_FIELD_NUMBER: _ClassVar[int]
14+
FLOAT_VALUE_FIELD_NUMBER: _ClassVar[int]
15+
int_value: int
16+
float_value: float
17+
def __init__(self, int_value: _Optional[int] = ..., float_value: _Optional[float] = ...) -> None: ...
18+
19+
class ToolParameterConstraints(_message.Message):
20+
__slots__ = ("gt", "ge", "lt", "le", "multiple_of", "min_length", "max_length", "pattern")
21+
GT_FIELD_NUMBER: _ClassVar[int]
22+
GE_FIELD_NUMBER: _ClassVar[int]
23+
LT_FIELD_NUMBER: _ClassVar[int]
24+
LE_FIELD_NUMBER: _ClassVar[int]
25+
MULTIPLE_OF_FIELD_NUMBER: _ClassVar[int]
26+
MIN_LENGTH_FIELD_NUMBER: _ClassVar[int]
27+
MAX_LENGTH_FIELD_NUMBER: _ClassVar[int]
28+
PATTERN_FIELD_NUMBER: _ClassVar[int]
29+
gt: NumericConstraint
30+
ge: NumericConstraint
31+
lt: NumericConstraint
32+
le: NumericConstraint
33+
multiple_of: NumericConstraint
34+
min_length: int
35+
max_length: int
36+
pattern: str
37+
def __init__(self, gt: _Optional[_Union[NumericConstraint, _Mapping]] = ..., ge: _Optional[_Union[NumericConstraint, _Mapping]] = ..., lt: _Optional[_Union[NumericConstraint, _Mapping]] = ..., le: _Optional[_Union[NumericConstraint, _Mapping]] = ..., multiple_of: _Optional[_Union[NumericConstraint, _Mapping]] = ..., min_length: _Optional[int] = ..., max_length: _Optional[int] = ..., pattern: _Optional[str] = ...) -> None: ...
38+
1139
class ToolParameter(_message.Message):
12-
__slots__ = ("name", "description", "data_type", "required", "has_default", "default_value", "allowed_values")
40+
__slots__ = ("name", "description", "data_type", "required", "has_default", "default_value", "allowed_values", "constraints", "validator_names")
1341
NAME_FIELD_NUMBER: _ClassVar[int]
1442
DESCRIPTION_FIELD_NUMBER: _ClassVar[int]
1543
DATA_TYPE_FIELD_NUMBER: _ClassVar[int]
1644
REQUIRED_FIELD_NUMBER: _ClassVar[int]
1745
HAS_DEFAULT_FIELD_NUMBER: _ClassVar[int]
1846
DEFAULT_VALUE_FIELD_NUMBER: _ClassVar[int]
1947
ALLOWED_VALUES_FIELD_NUMBER: _ClassVar[int]
48+
CONSTRAINTS_FIELD_NUMBER: _ClassVar[int]
49+
VALIDATOR_NAMES_FIELD_NUMBER: _ClassVar[int]
2050
name: str
2151
description: str
2252
data_type: _datatypes_pb2.DataType
2353
required: bool
2454
has_default: bool
2555
default_value: _complex_types_pb2.ScalarValue
2656
allowed_values: _containers.RepeatedCompositeFieldContainer[_complex_types_pb2.ScalarValue]
27-
def __init__(self, name: _Optional[str] = ..., description: _Optional[str] = ..., data_type: _Optional[_Union[_datatypes_pb2.DataType, _Mapping]] = ..., required: bool = ..., has_default: bool = ..., default_value: _Optional[_Union[_complex_types_pb2.ScalarValue, _Mapping]] = ..., allowed_values: _Optional[_Iterable[_Union[_complex_types_pb2.ScalarValue, _Mapping]]] = ...) -> None: ...
57+
constraints: ToolParameterConstraints
58+
validator_names: _containers.RepeatedScalarFieldContainer[str]
59+
def __init__(self, name: _Optional[str] = ..., description: _Optional[str] = ..., data_type: _Optional[_Union[_datatypes_pb2.DataType, _Mapping]] = ..., required: bool = ..., has_default: bool = ..., default_value: _Optional[_Union[_complex_types_pb2.ScalarValue, _Mapping]] = ..., allowed_values: _Optional[_Iterable[_Union[_complex_types_pb2.ScalarValue, _Mapping]]] = ..., constraints: _Optional[_Union[ToolParameterConstraints, _Mapping]] = ..., validator_names: _Optional[_Iterable[str]] = ...) -> None: ...
2860

2961
class ToolDefinition(_message.Message):
3062
__slots__ = ("name", "description", "params", "parameterized_view", "result_limit")

src/fenic/core/_serde/proto/serde_context.py

Lines changed: 57 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@
3939
FenicSchemaProto,
4040
LogicalExprProto,
4141
LogicalPlanProto,
42+
NumericConstraintProto,
4243
NumpyArrayProto,
4344
ResolvedClassDefinitionProto,
4445
ResolvedModelAliasProto,
@@ -48,9 +49,15 @@
4849
ScalarStructProto,
4950
ScalarValueProto,
5051
ToolDefinitionProto,
52+
ToolParameterConstraintsProto,
5153
ToolParameterProto,
5254
)
53-
from fenic.core.mcp.types import BoundToolParam, ParameterizedToolDefinition
55+
from fenic.core.mcp._validators import get_param_validator
56+
from fenic.core.mcp.types import (
57+
BoundToolParam,
58+
ParameterizedToolDefinition,
59+
ToolParamConstraints,
60+
)
5461
from fenic.core.types.datatypes import DataType
5562
from fenic.core.types.schema import ColumnField, Schema
5663

@@ -869,6 +876,20 @@ def serialize_tool_parameter(
869876
"""Serialize a ToolParameter."""
870877
with self.path_context(field_name):
871878
try:
879+
c = tool_param.constraints
880+
if c is not None:
881+
constraints = ToolParameterConstraintsProto(
882+
gt=_to_numeric_constraint(c.gt) if c.gt is not None else None,
883+
ge=_to_numeric_constraint(c.ge) if c.ge is not None else None,
884+
lt=_to_numeric_constraint(c.lt) if c.lt is not None else None,
885+
le=_to_numeric_constraint(c.le) if c.le is not None else None,
886+
multiple_of=_to_numeric_constraint(c.multiple_of) if c.multiple_of is not None else None,
887+
min_length=c.min_length,
888+
max_length=c.max_length,
889+
pattern=c.pattern,
890+
)
891+
else:
892+
constraints = None
872893
allowed_values = None
873894
if tool_param.allowed_values:
874895
allowed_values = [
@@ -883,6 +904,8 @@ def serialize_tool_parameter(
883904
has_default=tool_param.has_default,
884905
default_value=self.serialize_scalar_value("default_value", tool_param.default_value),
885906
allowed_values=allowed_values,
907+
constraints=constraints,
908+
validator_names=[validator.name() for validator in tool_param.validators],
886909
)
887910
except Exception as e:
888911
self._handle_serde_error(e)
@@ -900,6 +923,20 @@ def deserialize_tool_parameter(
900923
allowed_values = [
901924
self.deserialize_scalar_value("allowed_values", allowed_value) for allowed_value in
902925
tool_param_proto.allowed_values]
926+
927+
constraints = None
928+
if tool_param_proto.constraints is not None:
929+
c = tool_param_proto.constraints
930+
constraints = ToolParamConstraints(
931+
gt=_from_numeric_constraint(c.gt) if c.HasField("gt") else None,
932+
ge=_from_numeric_constraint(c.ge) if c.HasField("ge") else None,
933+
lt=_from_numeric_constraint(c.lt) if c.HasField("lt") else None,
934+
le=_from_numeric_constraint(c.le) if c.HasField("le") else None,
935+
multiple_of=_from_numeric_constraint(c.multiple_of) if c.HasField("multiple_of") else None,
936+
min_length=c.min_length if c.HasField("min_length") else None,
937+
max_length=c.max_length if c.HasField("max_length") else None,
938+
pattern=c.pattern if c.HasField("pattern") else None,
939+
)
903940
return BoundToolParam(
904941
name=tool_param_proto.name,
905942
description=tool_param_proto.description,
@@ -908,6 +945,8 @@ def deserialize_tool_parameter(
908945
has_default=tool_param_proto.has_default,
909946
default_value=self.deserialize_scalar_value("default_value", tool_param_proto.default_value),
910947
allowed_values=allowed_values,
948+
constraints=constraints,
949+
validators=[get_param_validator(validator_name) for validator_name in tool_param_proto.validator_names],
911950
)
912951
except Exception as e:
913952
self._handle_serde_error(e)
@@ -989,3 +1028,20 @@ def pop(self) -> None:
9891028
def clear(self) -> None:
9901029
"""Clear the entire path stack."""
9911030
self._path_stack.clear()
1031+
1032+
def _to_numeric_constraint(value):
1033+
if isinstance(value, int):
1034+
return NumericConstraintProto(int_value=value)
1035+
if isinstance(value, float):
1036+
return NumericConstraintProto(float_value=value)
1037+
return None
1038+
1039+
def _from_numeric_constraint(nc: Optional[NumericConstraintProto]):
1040+
if nc is None:
1041+
return None
1042+
which = nc.WhichOneof("kind")
1043+
if which == "int_value":
1044+
return nc.int_value
1045+
if which == "float_value":
1046+
return nc.float_value
1047+
return None

src/fenic/core/_serde/proto/types.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -433,12 +433,18 @@
433433
from fenic._gen.protos.logical_plan.v1.plans_pb2 import (
434434
Unnest as UnnestProto,
435435
)
436+
from fenic._gen.protos.logical_plan.v1.tools_pb2 import (
437+
NumericConstraint as NumericConstraintProto,
438+
)
436439
from fenic._gen.protos.logical_plan.v1.tools_pb2 import (
437440
ToolDefinition as ToolDefinitionProto,
438441
)
439442
from fenic._gen.protos.logical_plan.v1.tools_pb2 import (
440443
ToolParameter as ToolParameterProto,
441444
)
445+
from fenic._gen.protos.logical_plan.v1.tools_pb2 import (
446+
ToolParameterConstraints as ToolParameterConstraintsProto,
447+
)
442448

443449
# Export all protobuf classes for easy importing
444450
__all__ = [
@@ -596,5 +602,7 @@
596602
"TableSinkProto",
597603
# Tools
598604
"ToolParameterProto",
599-
"ToolDefinitionProto"
605+
"ToolDefinitionProto",
606+
"ToolParameterConstraintsProto",
607+
"NumericConstraintProto",
600608
]

src/fenic/core/mcp/_server.py

Lines changed: 15 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
from functools import wraps
1616
from typing import Any, Callable, Dict, List, Optional, Union
1717

18-
from pydantic import BaseModel
18+
from pydantic import BaseModel, Field
1919
from typing_extensions import Annotated, Literal
2020

2121
from fenic.core._interfaces.session_state import BaseSessionState
@@ -36,6 +36,7 @@
3636
DynamicToolDefinition,
3737
ParameterizedToolDefinition,
3838
TableFormat,
39+
ToolParamConstraints,
3940
)
4041
from fenic.core.types.datatypes import ArrayType
4142
from fenic.logging import configure_logging
@@ -187,36 +188,32 @@ async def tool_fn_wrapper(*args, **kwargs) -> MCPResultSet:
187188
# Add one keyword-only parameter per tool param
188189
for param in tool.params:
189190
param_type = _type_for_param(param)
190-
param_annotation = _annotate_with_description(param_type, param.description)
191191
default_value = param.default_value if param.has_default else inspect._empty
192+
param_annotation = _annotate_with_description(param_type, param.description, param.constraints)
192193
params.append(
193194
inspect.Parameter(
194195
name=param.name,
195196
kind=inspect.Parameter.KEYWORD_ONLY,
196-
default=default_value,
197197
annotation=param_annotation,
198+
default=default_value,
198199
)
199200
)
200201
annotations[param.name] = param_annotation
201202

202203
# Add table_format and limit just like dynamic tools
203-
tf_ann = Annotated[TableFormat, (
204-
TABLE_FORMAT_DESCRIPTION
205-
)]
206-
lim_ann = Annotated[Optional[Union[str, int]], LIMIT_DESCRIPTION]
204+
tf_ann = Annotated[TableFormat, Field(description=TABLE_FORMAT_DESCRIPTION, default="markdown")]
205+
lim_ann = Annotated[Optional[Union[str, int]], Field(description=LIMIT_DESCRIPTION, gt=0, le=tool.max_result_limit, default=tool.max_result_limit)]
207206
params.append(
208207
inspect.Parameter(
209208
name="table_format",
210209
kind=inspect.Parameter.KEYWORD_ONLY,
211-
default="markdown",
212210
annotation=tf_ann,
213211
)
214212
)
215213
params.append(
216214
inspect.Parameter(
217215
name="limit",
218216
kind=inspect.Parameter.KEYWORD_ONLY,
219-
default=tool.max_result_limit,
220217
annotation=lim_ann,
221218
)
222219
)
@@ -355,10 +352,15 @@ def _type_for_param(p: BoundToolParam) -> type:
355352
base_py = Optional[base_py]
356353
return base_py
357354

358-
def _annotate_with_description(base_ann: type, description: Optional[str] = None):
359-
if description:
360-
return Annotated[base_ann, description]
361-
return base_ann
355+
def _annotate_with_description(
356+
py_type: type,
357+
description: Optional[str] = None,
358+
constraints: Optional[ToolParamConstraints] = None
359+
) -> Union[type, Annotated[type, Field]]:
360+
if description or constraints:
361+
constraints_dict = constraints.model_dump() if constraints else {}
362+
return Annotated[py_type, Field(description=description, **constraints_dict)]
363+
return py_type
362364

363365
def _render_markdown_preview(rows: List[Dict[str, Any]]) -> str:
364366
if not rows:

0 commit comments

Comments
 (0)