|
4 | 4 |
|
5 | 5 | import inspect
|
6 | 6 | import re
|
7 |
| -import types |
8 | 7 | from collections.abc import AsyncIterator, Awaitable, Callable, Iterable, Sequence
|
9 | 8 | from contextlib import (
|
10 | 9 | AbstractAsyncContextManager,
|
11 | 10 | asynccontextmanager,
|
12 | 11 | )
|
13 | 12 | from itertools import chain
|
14 |
| -from typing import Any, Generic, Literal, TypeVar, Union, get_args, get_origin |
| 13 | +from typing import Any, Generic, Literal |
15 | 14 |
|
16 | 15 | import anyio
|
17 | 16 | import pydantic_core
|
18 |
| -from pydantic import BaseModel, Field, ValidationError |
19 |
| -from pydantic.fields import FieldInfo |
| 17 | +from pydantic import BaseModel, Field |
20 | 18 | from pydantic.networks import AnyUrl
|
21 | 19 | from pydantic_settings import BaseSettings, SettingsConfigDict
|
22 | 20 | from starlette.applications import Starlette
|
|
36 | 34 | from mcp.server.auth.settings import (
|
37 | 35 | AuthSettings,
|
38 | 36 | )
|
| 37 | +from mcp.server.elicitation import ElicitationResult, ElicitSchemaModelT, elicit_with_validation |
39 | 38 | from mcp.server.fastmcp.exceptions import ResourceError
|
40 | 39 | from mcp.server.fastmcp.prompts import Prompt, PromptManager
|
41 | 40 | from mcp.server.fastmcp.resources import FunctionResource, Resource, ResourceManager
|
|
67 | 66 |
|
68 | 67 | logger = get_logger(__name__)
|
69 | 68 |
|
70 |
| -ElicitSchemaModelT = TypeVar("ElicitSchemaModelT", bound=BaseModel) |
71 |
| - |
72 |
| - |
73 |
| -class ElicitationResult(BaseModel, Generic[ElicitSchemaModelT]): |
74 |
| - """Result of an elicitation request.""" |
75 |
| - |
76 |
| - action: Literal["accept", "decline", "cancel"] |
77 |
| - """The user's action in response to the elicitation.""" |
78 |
| - |
79 |
| - data: ElicitSchemaModelT | None = None |
80 |
| - """The validated data if action is 'accept', None otherwise.""" |
81 |
| - |
82 |
| - validation_error: str | None = None |
83 |
| - """Validation error message if data failed to validate.""" |
84 |
| - |
85 | 69 |
|
86 | 70 | class Settings(BaseSettings, Generic[LifespanResultT]):
|
87 | 71 | """FastMCP server settings.
|
@@ -875,43 +859,6 @@ def _convert_to_content(
|
875 | 859 | return [TextContent(type="text", text=result)]
|
876 | 860 |
|
877 | 861 |
|
878 |
| -# Primitive types allowed in elicitation schemas |
879 |
| -_ELICITATION_PRIMITIVE_TYPES = (str, int, float, bool) |
880 |
| - |
881 |
| - |
882 |
| -def _validate_elicitation_schema(schema: type[BaseModel]) -> None: |
883 |
| - """Validate that a Pydantic model only contains primitive field types.""" |
884 |
| - for field_name, field_info in schema.model_fields.items(): |
885 |
| - if not _is_primitive_field(field_info): |
886 |
| - raise TypeError( |
887 |
| - f"Elicitation schema field '{field_name}' must be a primitive type " |
888 |
| - f"{_ELICITATION_PRIMITIVE_TYPES} or Optional of these types. " |
889 |
| - f"Complex types like lists, dicts, or nested models are not allowed." |
890 |
| - ) |
891 |
| - |
892 |
| - |
893 |
| -def _is_primitive_field(field_info: FieldInfo) -> bool: |
894 |
| - """Check if a field is a primitive type allowed in elicitation schemas.""" |
895 |
| - annotation = field_info.annotation |
896 |
| - |
897 |
| - # Handle None type |
898 |
| - if annotation is types.NoneType: |
899 |
| - return True |
900 |
| - |
901 |
| - # Handle basic primitive types |
902 |
| - if annotation in _ELICITATION_PRIMITIVE_TYPES: |
903 |
| - return True |
904 |
| - |
905 |
| - # Handle Union types |
906 |
| - origin = get_origin(annotation) |
907 |
| - if origin is Union or origin is types.UnionType: |
908 |
| - args = get_args(annotation) |
909 |
| - # All args must be primitive types or None |
910 |
| - return all(arg is types.NoneType or arg in _ELICITATION_PRIMITIVE_TYPES for arg in args) |
911 |
| - |
912 |
| - return False |
913 |
| - |
914 |
| - |
915 | 862 | class Context(BaseModel, Generic[ServerSessionT, LifespanContextT, RequestT]):
|
916 | 863 | """Context object providing access to MCP capabilities.
|
917 | 864 |
|
@@ -1035,27 +982,10 @@ async def elicit(
|
1035 | 982 | The result.data will only be populated if action is "accept" and validation succeeded.
|
1036 | 983 | """
|
1037 | 984 |
|
1038 |
| - # Validate that schema only contains primitive types and fail loudly if not |
1039 |
| - _validate_elicitation_schema(schema) |
1040 |
| - |
1041 |
| - json_schema = schema.model_json_schema() |
1042 |
| - |
1043 |
| - result = await self.request_context.session.elicit( |
1044 |
| - message=message, |
1045 |
| - requestedSchema=json_schema, |
1046 |
| - related_request_id=self.request_id, |
| 985 | + return await elicit_with_validation( |
| 986 | + session=self.request_context.session, message=message, schema=schema, related_request_id=self.request_id |
1047 | 987 | )
|
1048 | 988 |
|
1049 |
| - if result.action == "accept" and result.content: |
1050 |
| - # Validate and parse the content using the schema |
1051 |
| - try: |
1052 |
| - validated_data = schema.model_validate(result.content) |
1053 |
| - return ElicitationResult(action="accept", data=validated_data) |
1054 |
| - except ValidationError as e: |
1055 |
| - return ElicitationResult(action="accept", validation_error=str(e)) |
1056 |
| - else: |
1057 |
| - return ElicitationResult(action=result.action) |
1058 |
| - |
1059 | 989 | async def log(
|
1060 | 990 | self,
|
1061 | 991 | level: Literal["debug", "info", "warning", "error"],
|
|
0 commit comments