Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 20 additions & 0 deletions protos/logical_plan/v1/tools.proto
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,24 @@ import "logical_plan/v1/datatypes.proto";
import "logical_plan/v1/complex_types.proto";
import "logical_plan/v1/plans.proto";

message NumericConstraint {
oneof kind {
sint32 int_value = 1;
float float_value = 2;
}
}

message ToolParameterConstraints {
optional NumericConstraint gt = 1;
optional NumericConstraint ge = 2;
optional NumericConstraint lt = 3;
optional NumericConstraint le = 4;
optional NumericConstraint multiple_of = 5;

optional uint32 min_length = 6;
optional uint32 max_length = 7;
optional string pattern = 8;
}

message ToolParameter {
string name = 1;
Expand All @@ -15,6 +33,8 @@ message ToolParameter {
bool has_default = 5;
optional ScalarValue default_value = 6;
repeated ScalarValue allowed_values = 7;
optional ToolParameterConstraints constraints = 8;
repeated string validator_names = 9;
}

message ToolDefinition {
Expand Down
14 changes: 9 additions & 5 deletions src/fenic/_gen/protos/logical_plan/v1/tools_pb2.py

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

36 changes: 34 additions & 2 deletions src/fenic/_gen/protos/logical_plan/v1/tools_pb2.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -8,23 +8,55 @@ from typing import ClassVar as _ClassVar, Iterable as _Iterable, Mapping as _Map

DESCRIPTOR: _descriptor.FileDescriptor

class NumericConstraint(_message.Message):
__slots__ = ("int_value", "float_value")
INT_VALUE_FIELD_NUMBER: _ClassVar[int]
FLOAT_VALUE_FIELD_NUMBER: _ClassVar[int]
int_value: int
float_value: float
def __init__(self, int_value: _Optional[int] = ..., float_value: _Optional[float] = ...) -> None: ...

class ToolParameterConstraints(_message.Message):
__slots__ = ("gt", "ge", "lt", "le", "multiple_of", "min_length", "max_length", "pattern")
GT_FIELD_NUMBER: _ClassVar[int]
GE_FIELD_NUMBER: _ClassVar[int]
LT_FIELD_NUMBER: _ClassVar[int]
LE_FIELD_NUMBER: _ClassVar[int]
MULTIPLE_OF_FIELD_NUMBER: _ClassVar[int]
MIN_LENGTH_FIELD_NUMBER: _ClassVar[int]
MAX_LENGTH_FIELD_NUMBER: _ClassVar[int]
PATTERN_FIELD_NUMBER: _ClassVar[int]
gt: NumericConstraint
ge: NumericConstraint
lt: NumericConstraint
le: NumericConstraint
multiple_of: NumericConstraint
min_length: int
max_length: int
pattern: str
def __init__(self, gt: _Optional[_Union[NumericConstraint, _Mapping]] = ..., ge: _Optional[_Union[NumericConstraint, _Mapping]] = ..., lt: _Optional[_Union[NumericConstraint, _Mapping]] = ..., le: _Optional[_Union[NumericConstraint, _Mapping]] = ..., multiple_of: _Optional[_Union[NumericConstraint, _Mapping]] = ..., min_length: _Optional[int] = ..., max_length: _Optional[int] = ..., pattern: _Optional[str] = ...) -> None: ...

class ToolParameter(_message.Message):
__slots__ = ("name", "description", "data_type", "required", "has_default", "default_value", "allowed_values")
__slots__ = ("name", "description", "data_type", "required", "has_default", "default_value", "allowed_values", "constraints", "validator_names")
NAME_FIELD_NUMBER: _ClassVar[int]
DESCRIPTION_FIELD_NUMBER: _ClassVar[int]
DATA_TYPE_FIELD_NUMBER: _ClassVar[int]
REQUIRED_FIELD_NUMBER: _ClassVar[int]
HAS_DEFAULT_FIELD_NUMBER: _ClassVar[int]
DEFAULT_VALUE_FIELD_NUMBER: _ClassVar[int]
ALLOWED_VALUES_FIELD_NUMBER: _ClassVar[int]
CONSTRAINTS_FIELD_NUMBER: _ClassVar[int]
VALIDATOR_NAMES_FIELD_NUMBER: _ClassVar[int]
name: str
description: str
data_type: _datatypes_pb2.DataType
required: bool
has_default: bool
default_value: _complex_types_pb2.ScalarValue
allowed_values: _containers.RepeatedCompositeFieldContainer[_complex_types_pb2.ScalarValue]
def __init__(self, name: _Optional[str] = ..., description: _Optional[str] = ..., data_type: _Optional[_Union[_datatypes_pb2.DataType, _Mapping]] = ..., required: bool = ..., has_default: bool = ..., default_value: _Optional[_Union[_complex_types_pb2.ScalarValue, _Mapping]] = ..., allowed_values: _Optional[_Iterable[_Union[_complex_types_pb2.ScalarValue, _Mapping]]] = ...) -> None: ...
constraints: ToolParameterConstraints
validator_names: _containers.RepeatedScalarFieldContainer[str]
def __init__(self, name: _Optional[str] = ..., description: _Optional[str] = ..., data_type: _Optional[_Union[_datatypes_pb2.DataType, _Mapping]] = ..., required: bool = ..., has_default: bool = ..., default_value: _Optional[_Union[_complex_types_pb2.ScalarValue, _Mapping]] = ..., allowed_values: _Optional[_Iterable[_Union[_complex_types_pb2.ScalarValue, _Mapping]]] = ..., constraints: _Optional[_Union[ToolParameterConstraints, _Mapping]] = ..., validator_names: _Optional[_Iterable[str]] = ...) -> None: ...

class ToolDefinition(_message.Message):
__slots__ = ("name", "description", "params", "parameterized_view", "result_limit")
Expand Down
58 changes: 57 additions & 1 deletion src/fenic/core/_serde/proto/serde_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
FenicSchemaProto,
LogicalExprProto,
LogicalPlanProto,
NumericConstraintProto,
NumpyArrayProto,
ResolvedClassDefinitionProto,
ResolvedModelAliasProto,
Expand All @@ -49,12 +50,18 @@
ScalarStructProto,
ScalarValueProto,
ToolDefinitionProto,
ToolParameterConstraintsProto,
ToolParameterProto,
)
from fenic.core._utils.structured_outputs import (
check_if_model_uses_unserializable_features,
)
from fenic.core.mcp.types import BoundToolParam, UserDefinedTool
from fenic.core.mcp._validators import get_param_validator
from fenic.core.mcp.types import (
BoundToolParam,
ToolParamConstraints,
UserDefinedTool,
)
from fenic.core.types.datatypes import DataType
from fenic.core.types.schema import ColumnField, Schema

Expand Down Expand Up @@ -884,6 +891,20 @@ def serialize_tool_parameter(
"""Serialize a ToolParameter."""
with self.path_context(field_name):
try:
c = tool_param.constraints
if c is not None:
constraints = ToolParameterConstraintsProto(
gt=_to_numeric_constraint(c.gt) if c.gt is not None else None,
ge=_to_numeric_constraint(c.ge) if c.ge is not None else None,
lt=_to_numeric_constraint(c.lt) if c.lt is not None else None,
le=_to_numeric_constraint(c.le) if c.le is not None else None,
multiple_of=_to_numeric_constraint(c.multiple_of) if c.multiple_of is not None else None,
min_length=c.min_length,
max_length=c.max_length,
pattern=c.pattern,
)
else:
constraints = None
allowed_values = None
if tool_param.allowed_values:
allowed_values = [
Expand All @@ -898,6 +919,8 @@ def serialize_tool_parameter(
has_default=tool_param.has_default,
default_value=self.serialize_scalar_value("default_value", tool_param.default_value),
allowed_values=allowed_values,
constraints=constraints,
validator_names=[validator.name() for validator in tool_param.validators],
)
except Exception as e:
self._handle_serde_error(e)
Expand All @@ -915,6 +938,20 @@ def deserialize_tool_parameter(
allowed_values = [
self.deserialize_scalar_value("allowed_values", allowed_value) for allowed_value in
tool_param_proto.allowed_values]

constraints = None
if tool_param_proto.constraints is not None:
c = tool_param_proto.constraints
constraints = ToolParamConstraints(
gt=_from_numeric_constraint(c.gt) if c.HasField("gt") else None,
ge=_from_numeric_constraint(c.ge) if c.HasField("ge") else None,
lt=_from_numeric_constraint(c.lt) if c.HasField("lt") else None,
le=_from_numeric_constraint(c.le) if c.HasField("le") else None,
multiple_of=_from_numeric_constraint(c.multiple_of) if c.HasField("multiple_of") else None,
min_length=c.min_length if c.HasField("min_length") else None,
max_length=c.max_length if c.HasField("max_length") else None,
pattern=c.pattern if c.HasField("pattern") else None,
)
return BoundToolParam(
name=tool_param_proto.name,
description=tool_param_proto.description,
Expand All @@ -923,6 +960,8 @@ def deserialize_tool_parameter(
has_default=tool_param_proto.has_default,
default_value=self.deserialize_scalar_value("default_value", tool_param_proto.default_value),
allowed_values=allowed_values,
constraints=constraints,
validators=[get_param_validator(validator_name) for validator_name in tool_param_proto.validator_names],
)
except Exception as e:
self._handle_serde_error(e)
Expand Down Expand Up @@ -1004,3 +1043,20 @@ def pop(self) -> None:
def clear(self) -> None:
"""Clear the entire path stack."""
self._path_stack.clear()

def _to_numeric_constraint(value):
if isinstance(value, int):
return NumericConstraintProto(int_value=value)
if isinstance(value, float):
return NumericConstraintProto(float_value=value)
return None

def _from_numeric_constraint(nc: Optional[NumericConstraintProto]):
if nc is None:
return None
which = nc.WhichOneof("kind")
if which == "int_value":
return nc.int_value
if which == "float_value":
return nc.float_value
return None
9 changes: 9 additions & 0 deletions src/fenic/core/_serde/proto/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -574,12 +574,18 @@
from fenic._gen.protos.logical_plan.v1.plans_pb2 import (
Unnest as UnnestProto,
)
from fenic._gen.protos.logical_plan.v1.tools_pb2 import (
NumericConstraint as NumericConstraintProto,
)
from fenic._gen.protos.logical_plan.v1.tools_pb2 import (
ToolDefinition as ToolDefinitionProto,
)
from fenic._gen.protos.logical_plan.v1.tools_pb2 import (
ToolParameter as ToolParameterProto,
)
from fenic._gen.protos.logical_plan.v1.tools_pb2 import (
ToolParameterConstraints as ToolParameterConstraintsProto,
)

# Export all protobuf classes for easy importing
__all__ = [
Expand Down Expand Up @@ -767,6 +773,9 @@
# Tools
"ToolParameterProto",
"ToolDefinitionProto",
"ToolParameterConstraintsProto",
"NumericConstraintProto",
"ToolDefinitionProto",
# Date time related classes
"YearExprProto",
"MonthExprProto",
Expand Down
41 changes: 29 additions & 12 deletions src/fenic/core/mcp/_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from typing import Any, Callable, Coroutine, Dict, List, Optional, Union

import polars as pl
from pydantic import BaseModel, ConfigDict
from pydantic import BaseModel, ConfigDict, Field
from typing_extensions import Annotated, Literal

from fenic.core._interfaces.session_state import BaseSessionState
Expand All @@ -37,6 +37,7 @@
from fenic.core.mcp.types import (
SystemTool,
TableFormat,
ToolParamConstraints,
UserDefinedTool,
)
from fenic.core.types.datatypes import ArrayType
Expand Down Expand Up @@ -216,28 +217,39 @@ async def tool_fn_wrapper(*args, **kwargs) -> MCPResultSet:
# Add one keyword-only parameter per tool param
for param in tool_definition.params:
param_type = _type_for_param(param)
param_annotation = _annotate_with_description(param_type, param.description)
default_value = param.default_value if param.has_default else inspect._empty
param_annotation = _annotate_with_description(param_type, param.description, param.constraints)
params.append(
inspect.Parameter(
name=param.name,
kind=inspect.Parameter.KEYWORD_ONLY,
default=default_value,
annotation=param_annotation,
default=default_value,
)
)
annotations[param.name] = param_annotation

# Add table_format and limit just like system tools
tf_ann = Annotated[TableFormat, (
TABLE_FORMAT_DESCRIPTION
)]
lim_ann = Annotated[Optional[Union[str, int]], LIMIT_DESCRIPTION]
tf_ann = Annotated[
TableFormat,
Field(
description=TABLE_FORMAT_DESCRIPTION,
default="markdown"
)
]
lim_ann = Annotated[
Optional[Union[str, int]],
Field(
description=LIMIT_DESCRIPTION,
gt=0,
le=tool_definition.max_result_limit,
default=tool_definition.max_result_limit
)
]
params.append(
inspect.Parameter(
name="table_format",
kind=inspect.Parameter.KEYWORD_ONLY,
default="markdown",
annotation=tf_ann,
)
)
Expand Down Expand Up @@ -384,10 +396,15 @@ def _type_for_param(p: BoundToolParam) -> type:
base_py = Optional[base_py]
return base_py

def _annotate_with_description(base_ann: type, description: Optional[str] = None):
if description:
return Annotated[base_ann, description]
return base_ann
def _annotate_with_description(
py_type: type,
description: Optional[str] = None,
constraints: Optional[ToolParamConstraints] = None
) -> Union[type, Annotated[type, Field]]:
if description or constraints:
constraints_dict = constraints.model_dump() if constraints else {}
return Annotated[py_type, Field(description=description, **constraints_dict)]
return py_type

def _render_markdown_preview(rows: List[Dict[str, Any]]) -> str:
if not rows:
Expand Down
Loading