Skip to content

Commit 3ea2993

Browse files
committed
feat: allow users to add constraints and validators to ToolParam
1 parent 5d9d16a commit 3ea2993

File tree

10 files changed

+458
-44
lines changed

10 files changed

+458
-44
lines changed

protos/logical_plan/v1/tools.proto

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,24 @@ import "logical_plan/v1/datatypes.proto";
66
import "logical_plan/v1/complex_types.proto";
77
import "logical_plan/v1/plans.proto";
88

9+
message NumericConstraint {
10+
oneof kind {
11+
sint32 int_value = 1;
12+
float float_value = 2;
13+
}
14+
}
15+
16+
message ToolParameterConstraints {
17+
optional NumericConstraint gt = 1;
18+
optional NumericConstraint ge = 2;
19+
optional NumericConstraint lt = 3;
20+
optional NumericConstraint le = 4;
21+
optional NumericConstraint multiple_of = 5;
22+
23+
optional uint32 min_length = 6;
24+
optional uint32 max_length = 7;
25+
optional string pattern = 8;
26+
}
927

1028
message ToolParameter {
1129
string name = 1;
@@ -15,6 +33,8 @@ message ToolParameter {
1533
bool has_default = 5;
1634
optional ScalarValue default_value = 6;
1735
repeated ScalarValue allowed_values = 7;
36+
optional ToolParameterConstraints constraints = 8;
37+
repeated string validator_names = 9;
1838
}
1939

2040
message ToolDefinition {

src/fenic/_gen/protos/logical_plan/v1/tools_pb2.py

Lines changed: 9 additions & 5 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

src/fenic/_gen/protos/logical_plan/v1/tools_pb2.pyi

Lines changed: 34 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,23 +8,55 @@ from typing import ClassVar as _ClassVar, Iterable as _Iterable, Mapping as _Map
88

99
DESCRIPTOR: _descriptor.FileDescriptor
1010

11+
class NumericConstraint(_message.Message):
12+
__slots__ = ("int_value", "float_value")
13+
INT_VALUE_FIELD_NUMBER: _ClassVar[int]
14+
FLOAT_VALUE_FIELD_NUMBER: _ClassVar[int]
15+
int_value: int
16+
float_value: float
17+
def __init__(self, int_value: _Optional[int] = ..., float_value: _Optional[float] = ...) -> None: ...
18+
19+
class ToolParameterConstraints(_message.Message):
20+
__slots__ = ("gt", "ge", "lt", "le", "multiple_of", "min_length", "max_length", "pattern")
21+
GT_FIELD_NUMBER: _ClassVar[int]
22+
GE_FIELD_NUMBER: _ClassVar[int]
23+
LT_FIELD_NUMBER: _ClassVar[int]
24+
LE_FIELD_NUMBER: _ClassVar[int]
25+
MULTIPLE_OF_FIELD_NUMBER: _ClassVar[int]
26+
MIN_LENGTH_FIELD_NUMBER: _ClassVar[int]
27+
MAX_LENGTH_FIELD_NUMBER: _ClassVar[int]
28+
PATTERN_FIELD_NUMBER: _ClassVar[int]
29+
gt: NumericConstraint
30+
ge: NumericConstraint
31+
lt: NumericConstraint
32+
le: NumericConstraint
33+
multiple_of: NumericConstraint
34+
min_length: int
35+
max_length: int
36+
pattern: str
37+
def __init__(self, gt: _Optional[_Union[NumericConstraint, _Mapping]] = ..., ge: _Optional[_Union[NumericConstraint, _Mapping]] = ..., lt: _Optional[_Union[NumericConstraint, _Mapping]] = ..., le: _Optional[_Union[NumericConstraint, _Mapping]] = ..., multiple_of: _Optional[_Union[NumericConstraint, _Mapping]] = ..., min_length: _Optional[int] = ..., max_length: _Optional[int] = ..., pattern: _Optional[str] = ...) -> None: ...
38+
1139
class ToolParameter(_message.Message):
12-
__slots__ = ("name", "description", "data_type", "required", "has_default", "default_value", "allowed_values")
40+
__slots__ = ("name", "description", "data_type", "required", "has_default", "default_value", "allowed_values", "constraints", "validator_names")
1341
NAME_FIELD_NUMBER: _ClassVar[int]
1442
DESCRIPTION_FIELD_NUMBER: _ClassVar[int]
1543
DATA_TYPE_FIELD_NUMBER: _ClassVar[int]
1644
REQUIRED_FIELD_NUMBER: _ClassVar[int]
1745
HAS_DEFAULT_FIELD_NUMBER: _ClassVar[int]
1846
DEFAULT_VALUE_FIELD_NUMBER: _ClassVar[int]
1947
ALLOWED_VALUES_FIELD_NUMBER: _ClassVar[int]
48+
CONSTRAINTS_FIELD_NUMBER: _ClassVar[int]
49+
VALIDATOR_NAMES_FIELD_NUMBER: _ClassVar[int]
2050
name: str
2151
description: str
2252
data_type: _datatypes_pb2.DataType
2353
required: bool
2454
has_default: bool
2555
default_value: _complex_types_pb2.ScalarValue
2656
allowed_values: _containers.RepeatedCompositeFieldContainer[_complex_types_pb2.ScalarValue]
27-
def __init__(self, name: _Optional[str] = ..., description: _Optional[str] = ..., data_type: _Optional[_Union[_datatypes_pb2.DataType, _Mapping]] = ..., required: bool = ..., has_default: bool = ..., default_value: _Optional[_Union[_complex_types_pb2.ScalarValue, _Mapping]] = ..., allowed_values: _Optional[_Iterable[_Union[_complex_types_pb2.ScalarValue, _Mapping]]] = ...) -> None: ...
57+
constraints: ToolParameterConstraints
58+
validator_names: _containers.RepeatedScalarFieldContainer[str]
59+
def __init__(self, name: _Optional[str] = ..., description: _Optional[str] = ..., data_type: _Optional[_Union[_datatypes_pb2.DataType, _Mapping]] = ..., required: bool = ..., has_default: bool = ..., default_value: _Optional[_Union[_complex_types_pb2.ScalarValue, _Mapping]] = ..., allowed_values: _Optional[_Iterable[_Union[_complex_types_pb2.ScalarValue, _Mapping]]] = ..., constraints: _Optional[_Union[ToolParameterConstraints, _Mapping]] = ..., validator_names: _Optional[_Iterable[str]] = ...) -> None: ...
2860

2961
class ToolDefinition(_message.Message):
3062
__slots__ = ("name", "description", "params", "parameterized_view", "result_limit")

src/fenic/core/_serde/proto/serde_context.py

Lines changed: 57 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@
3939
FenicSchemaProto,
4040
LogicalExprProto,
4141
LogicalPlanProto,
42+
NumericConstraintProto,
4243
NumpyArrayProto,
4344
ResolvedClassDefinitionProto,
4445
ResolvedModelAliasProto,
@@ -48,12 +49,18 @@
4849
ScalarStructProto,
4950
ScalarValueProto,
5051
ToolDefinitionProto,
52+
ToolParameterConstraintsProto,
5153
ToolParameterProto,
5254
)
5355
from fenic.core._utils.structured_outputs import (
5456
check_if_model_uses_unserializable_features,
5557
)
56-
from fenic.core.mcp.types import BoundToolParam, ParameterizedToolDefinition
58+
from fenic.core.mcp._validators import get_param_validator
59+
from fenic.core.mcp.types import (
60+
BoundToolParam,
61+
ParameterizedToolDefinition,
62+
ToolParamConstraints,
63+
)
5764
from fenic.core.types.datatypes import DataType
5865
from fenic.core.types.schema import ColumnField, Schema
5966

@@ -873,6 +880,20 @@ def serialize_tool_parameter(
873880
"""Serialize a ToolParameter."""
874881
with self.path_context(field_name):
875882
try:
883+
c = tool_param.constraints
884+
if c is not None:
885+
constraints = ToolParameterConstraintsProto(
886+
gt=_to_numeric_constraint(c.gt) if c.gt is not None else None,
887+
ge=_to_numeric_constraint(c.ge) if c.ge is not None else None,
888+
lt=_to_numeric_constraint(c.lt) if c.lt is not None else None,
889+
le=_to_numeric_constraint(c.le) if c.le is not None else None,
890+
multiple_of=_to_numeric_constraint(c.multiple_of) if c.multiple_of is not None else None,
891+
min_length=c.min_length,
892+
max_length=c.max_length,
893+
pattern=c.pattern,
894+
)
895+
else:
896+
constraints = None
876897
allowed_values = None
877898
if tool_param.allowed_values:
878899
allowed_values = [
@@ -887,6 +908,8 @@ def serialize_tool_parameter(
887908
has_default=tool_param.has_default,
888909
default_value=self.serialize_scalar_value("default_value", tool_param.default_value),
889910
allowed_values=allowed_values,
911+
constraints=constraints,
912+
validator_names=[validator.name() for validator in tool_param.validators],
890913
)
891914
except Exception as e:
892915
self._handle_serde_error(e)
@@ -904,6 +927,20 @@ def deserialize_tool_parameter(
904927
allowed_values = [
905928
self.deserialize_scalar_value("allowed_values", allowed_value) for allowed_value in
906929
tool_param_proto.allowed_values]
930+
931+
constraints = None
932+
if tool_param_proto.constraints is not None:
933+
c = tool_param_proto.constraints
934+
constraints = ToolParamConstraints(
935+
gt=_from_numeric_constraint(c.gt) if c.HasField("gt") else None,
936+
ge=_from_numeric_constraint(c.ge) if c.HasField("ge") else None,
937+
lt=_from_numeric_constraint(c.lt) if c.HasField("lt") else None,
938+
le=_from_numeric_constraint(c.le) if c.HasField("le") else None,
939+
multiple_of=_from_numeric_constraint(c.multiple_of) if c.HasField("multiple_of") else None,
940+
min_length=c.min_length if c.HasField("min_length") else None,
941+
max_length=c.max_length if c.HasField("max_length") else None,
942+
pattern=c.pattern if c.HasField("pattern") else None,
943+
)
907944
return BoundToolParam(
908945
name=tool_param_proto.name,
909946
description=tool_param_proto.description,
@@ -912,6 +949,8 @@ def deserialize_tool_parameter(
912949
has_default=tool_param_proto.has_default,
913950
default_value=self.deserialize_scalar_value("default_value", tool_param_proto.default_value),
914951
allowed_values=allowed_values,
952+
constraints=constraints,
953+
validators=[get_param_validator(validator_name) for validator_name in tool_param_proto.validator_names],
915954
)
916955
except Exception as e:
917956
self._handle_serde_error(e)
@@ -993,3 +1032,20 @@ def pop(self) -> None:
9931032
def clear(self) -> None:
9941033
"""Clear the entire path stack."""
9951034
self._path_stack.clear()
1035+
1036+
def _to_numeric_constraint(value):
1037+
if isinstance(value, int):
1038+
return NumericConstraintProto(int_value=value)
1039+
if isinstance(value, float):
1040+
return NumericConstraintProto(float_value=value)
1041+
return None
1042+
1043+
def _from_numeric_constraint(nc: Optional[NumericConstraintProto]):
1044+
if nc is None:
1045+
return None
1046+
which = nc.WhichOneof("kind")
1047+
if which == "int_value":
1048+
return nc.int_value
1049+
if which == "float_value":
1050+
return nc.float_value
1051+
return None

src/fenic/core/_serde/proto/types.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -433,12 +433,18 @@
433433
from fenic._gen.protos.logical_plan.v1.plans_pb2 import (
434434
Unnest as UnnestProto,
435435
)
436+
from fenic._gen.protos.logical_plan.v1.tools_pb2 import (
437+
NumericConstraint as NumericConstraintProto,
438+
)
436439
from fenic._gen.protos.logical_plan.v1.tools_pb2 import (
437440
ToolDefinition as ToolDefinitionProto,
438441
)
439442
from fenic._gen.protos.logical_plan.v1.tools_pb2 import (
440443
ToolParameter as ToolParameterProto,
441444
)
445+
from fenic._gen.protos.logical_plan.v1.tools_pb2 import (
446+
ToolParameterConstraints as ToolParameterConstraintsProto,
447+
)
442448

443449
# Export all protobuf classes for easy importing
444450
__all__ = [
@@ -596,5 +602,7 @@
596602
"TableSinkProto",
597603
# Tools
598604
"ToolParameterProto",
599-
"ToolDefinitionProto"
605+
"ToolDefinitionProto",
606+
"ToolParameterConstraintsProto",
607+
"NumericConstraintProto",
600608
]

src/fenic/core/mcp/_server.py

Lines changed: 15 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
from typing import Any, Callable, Dict, List, Optional, Union
1717

1818
import polars as pl
19-
from pydantic import BaseModel, ConfigDict
19+
from pydantic import BaseModel, ConfigDict, Field
2020
from typing_extensions import Annotated, Literal
2121

2222
from fenic.core._interfaces.session_state import BaseSessionState
@@ -37,6 +37,7 @@
3737
DynamicToolDefinition,
3838
ParameterizedToolDefinition,
3939
TableFormat,
40+
ToolParamConstraints,
4041
)
4142
from fenic.core.types.datatypes import ArrayType
4243
from fenic.logging import configure_logging
@@ -212,36 +213,32 @@ async def tool_fn_wrapper(*args, **kwargs) -> MCPResultSet:
212213
# Add one keyword-only parameter per tool param
213214
for param in tool.params:
214215
param_type = _type_for_param(param)
215-
param_annotation = _annotate_with_description(param_type, param.description)
216216
default_value = param.default_value if param.has_default else inspect._empty
217+
param_annotation = _annotate_with_description(param_type, param.description, param.constraints)
217218
params.append(
218219
inspect.Parameter(
219220
name=param.name,
220221
kind=inspect.Parameter.KEYWORD_ONLY,
221-
default=default_value,
222222
annotation=param_annotation,
223+
default=default_value,
223224
)
224225
)
225226
annotations[param.name] = param_annotation
226227

227228
# Add table_format and limit just like dynamic tools
228-
tf_ann = Annotated[TableFormat, (
229-
TABLE_FORMAT_DESCRIPTION
230-
)]
231-
lim_ann = Annotated[Optional[Union[str, int]], LIMIT_DESCRIPTION]
229+
tf_ann = Annotated[TableFormat, Field(description=TABLE_FORMAT_DESCRIPTION, default="markdown")]
230+
lim_ann = Annotated[Optional[Union[str, int]], Field(description=LIMIT_DESCRIPTION, gt=0, le=tool.max_result_limit, default=tool.max_result_limit)]
232231
params.append(
233232
inspect.Parameter(
234233
name="table_format",
235234
kind=inspect.Parameter.KEYWORD_ONLY,
236-
default="markdown",
237235
annotation=tf_ann,
238236
)
239237
)
240238
params.append(
241239
inspect.Parameter(
242240
name="limit",
243241
kind=inspect.Parameter.KEYWORD_ONLY,
244-
default=tool.max_result_limit,
245242
annotation=lim_ann,
246243
)
247244
)
@@ -376,10 +373,15 @@ def _type_for_param(p: BoundToolParam) -> type:
376373
base_py = Optional[base_py]
377374
return base_py
378375

379-
def _annotate_with_description(base_ann: type, description: Optional[str] = None):
380-
if description:
381-
return Annotated[base_ann, description]
382-
return base_ann
376+
def _annotate_with_description(
377+
py_type: type,
378+
description: Optional[str] = None,
379+
constraints: Optional[ToolParamConstraints] = None
380+
) -> Union[type, Annotated[type, Field]]:
381+
if description or constraints:
382+
constraints_dict = constraints.model_dump() if constraints else {}
383+
return Annotated[py_type, Field(description=description, **constraints_dict)]
384+
return py_type
383385

384386
def _render_markdown_preview(rows: List[Dict[str, Any]]) -> str:
385387
if not rows:

0 commit comments

Comments
 (0)