diff --git a/RELEASE.md b/RELEASE.md new file mode 100644 index 0000000000..df7f0040c3 --- /dev/null +++ b/RELEASE.md @@ -0,0 +1,3 @@ +Release type: minor + +TODO diff --git a/docs/integrations/pydantic.md b/docs/integrations/pydantic.md index e1c1967772..96e68dcc6b 100644 --- a/docs/integrations/pydantic.md +++ b/docs/integrations/pydantic.md @@ -1,85 +1,490 @@ --- title: Pydantic support -experimental: true --- # Pydantic support -Strawberry comes with support for -[Pydantic](https://pydantic-docs.helpmanual.io/). This allows for the creation -of Strawberry types from pydantic models without having to write code twice. +Strawberry provides first-class support for [Pydantic](https://pydantic.dev/) +models, allowing you to directly decorate your Pydantic `BaseModel` classes to +create GraphQL types without writing code twice. -Here's a basic example of how this works, let's say we have a pydantic Model for -a user, like this: +## Installation + +```bash +pip install strawberry-graphql[pydantic] +``` + +## Basic Usage + +The simplest way to use Pydantic with Strawberry is to decorate your Pydantic +models directly: ```python -from datetime import datetime -from typing import List, Optional +import strawberry from pydantic import BaseModel +@strawberry.pydantic.type class User(BaseModel): id: int name: str - signup_ts: Optional[datetime] = None - friends: List[int] = [] + email: str + + +@strawberry.type +class Query: + @strawberry.field + def get_user(self) -> User: + return User(id=1, name="John", email="john@example.com") + + +schema = strawberry.Schema(query=Query) ``` -We can create a Strawberry type by using the -`strawberry.experimental.pydantic.type` decorator: +This automatically creates a GraphQL type that includes all fields from your +Pydantic model. + +## Type Decorators + +### `@strawberry.pydantic.type` + +Creates a GraphQL object type from a Pydantic model: + +```python +@strawberry.pydantic.type +class User(BaseModel): + name: str + age: int + is_active: bool = True +``` + +### `@strawberry.pydantic.input` + +Creates a GraphQL input type from a Pydantic model: + +```python +@strawberry.pydantic.input +class CreateUserInput(BaseModel): + name: str + age: int + email: str + + +@strawberry.type +class Mutation: + @strawberry.field + def create_user(self, input: CreateUserInput) -> User: + return User(name=input.name, age=input.age, email=input.email) +``` + +### `@strawberry.pydantic.interface` + +Creates a GraphQL interface from a Pydantic model: + +```python +@strawberry.pydantic.interface +class Node(BaseModel): + id: str + + +@strawberry.pydantic.type +class User(BaseModel): + id: str + name: str + # User implements Node interface +``` + +## Configuration Options + +All decorators accept optional configuration parameters: + +```python +@strawberry.pydantic.type( + name="CustomUser", # Override the GraphQL type name + description="A user in the system", # Add type description +) +class User(BaseModel): + name: str = Field(alias="fullName") + age: int +``` + +## Field Features + +### Field Descriptions + +Pydantic field descriptions are automatically preserved in the GraphQL schema: + +```python +from pydantic import Field + + +@strawberry.pydantic.type +class User(BaseModel): + name: str = Field(description="The user's full name") + age: int = Field(description="The user's age in years") +``` + +### Field Aliases + +Pydantic field aliases are automatically used as GraphQL field names: + +```python +@strawberry.pydantic.type +class User(BaseModel): + name: str = Field(alias="fullName") + age: int = Field(alias="yearsOld") +``` + +### Optional Fields + +Pydantic optional fields are properly handled: + +```python +from typing import Optional + + +@strawberry.pydantic.type +class User(BaseModel): + name: str + email: Optional[str] = None + age: Optional[int] = None +``` + +### Private Fields + +You can use `strawberry.Private` to mark fields that should not be exposed in +the GraphQL schema but are still accessible in your Python code: ```python import strawberry -from .models import User +@strawberry.pydantic.type +class User(BaseModel): + id: int + name: str + password: strawberry.Private[str] # Not exposed in GraphQL + email: str +``` -@strawberry.experimental.pydantic.type(model=User) -class UserType: - id: strawberry.auto - name: strawberry.auto - friends: strawberry.auto +This generates a GraphQL schema with only the public fields: + +```graphql +type User { + id: Int! + name: String! + email: String! +} ``` -The `strawberry.experimental.pydantic.type` decorator accepts a Pydantic model -and wraps a class that contains dataclass style fields with `strawberry.auto` as -the type annotation. The fields marked with `strawberry.auto` will inherit their -types from the Pydantic model. +The private fields are still accessible in Python code for use in resolvers or +business logic: + +```python +@strawberry.type +class Query: + @strawberry.field + def get_user(self) -> User: + user = User(id=1, name="John", password="secret", email="john@example.com") + # Can access private field in Python + if user.password: + return user + return None +``` -If you want to include all of the fields from your Pydantic model, you can -instead pass `all_fields=True` to the decorator. +## Advanced Usage --> **Note** Care should be taken to avoid accidentally exposing fields that -> -weren't meant to be exposed on an API using this feature. +### Nested Types + +Pydantic models can contain other Pydantic models: + +```python +@strawberry.pydantic.type +class Address(BaseModel): + street: str + city: str + zipcode: str + + +@strawberry.pydantic.type +class User(BaseModel): + name: str + address: Address +``` + +### Lists and Collections + +Lists of Pydantic models work seamlessly: ```python +from typing import List + + +@strawberry.pydantic.type +class User(BaseModel): + name: str + age: int + + +@strawberry.type +class Query: + @strawberry.field + def get_users(self) -> List[User]: + return [User(name="John", age=30), User(name="Jane", age=25)] +``` + +### Validation + +Pydantic validation is automatically applied to input types: + +```python +from pydantic import validator + + +@strawberry.pydantic.input +class CreateUserInput(BaseModel): + name: str + age: int + + @validator("age") + def validate_age(cls, v): + if v < 0: + raise ValueError("Age must be non-negative") + return v +``` + +### Field Directives and Customization + +You can use `strawberry.field()` with `Annotated` types to add GraphQL-specific +features like directives, permissions, and deprecation to individual Pydantic +model fields: + +```python +from typing import Annotated +from pydantic import BaseModel, Field import strawberry -from .models import User + +@strawberry.schema_directive( + locations=[strawberry.schema_directive.Location.FIELD_DEFINITION] +) +class Sensitive: + reason: str -@strawberry.experimental.pydantic.type(model=User, all_fields=True) -class UserType: +@strawberry.schema_directive( + locations=[strawberry.schema_directive.Location.FIELD_DEFINITION] +) +class Range: + min: int + max: int + + +@strawberry.pydantic.type +class User(BaseModel): + # Regular field - uses Pydantic description + name: Annotated[str, Field(description="The user's full name")] + + # Field with directive + email: Annotated[str, strawberry.field(directives=[Sensitive(reason="PII")])] + + # Field with multiple directives and Pydantic features + age: Annotated[ + int, + Field(alias="userAge", description="User's age"), + strawberry.field(directives=[Range(min=0, max=150)]), + ] + + # Field with permissions + phone: Annotated[ + str, + strawberry.field( + permission_classes=[IsAuthenticated], + directives=[Sensitive(reason="Contact Info")], + ), + ] + + # Deprecated field + old_id: Annotated[int, strawberry.field(deprecation_reason="Use 'id' instead")] +``` + +#### Field Customization Options + +When using `strawberry.field()` with Pydantic models, you can specify: + +- **`directives`**: List of GraphQL directives to apply to the field +- **`permission_classes`**: List of permission classes for field-level + authorization +- **`deprecation_reason`**: Mark a field as deprecated with a reason +- **`description`**: Override the Pydantic field description for GraphQL +- **`name`**: Override the GraphQL field name (takes precedence over Pydantic + aliases) + +#### Input Types with Directives + +Field directives work with input types too: + +```python +@strawberry.schema_directive( + locations=[strawberry.schema_directive.Location.INPUT_FIELD_DEFINITION] +) +class Validate: + pattern: str + + +@strawberry.pydantic.input +class CreateUserInput(BaseModel): + name: str + email: Annotated[ + str, strawberry.field(directives=[Validate(pattern=r"^[^@]+@[^@]+\.[^@]+")]) + ] +``` + +## Conversion Methods + +Decorated models automatically get conversion methods: + +```python +@strawberry.pydantic.type +class User(BaseModel): + name: str + age: int + + +# Create from existing Pydantic instance +pydantic_user = User(name="John", age=30) +strawberry_user = User.from_pydantic(pydantic_user) + +# Convert back to Pydantic +converted_back = strawberry_user.to_pydantic() +``` + +## Migration from Experimental + +If you're using the experimental Pydantic integration, here's how to migrate: + +### Before (Experimental) + +```python +from strawberry.experimental.pydantic import type as pydantic_type + + +class UserModel(BaseModel): + name: str + age: int + + +@pydantic_type(UserModel, all_fields=True) +class User: pass ``` -By default, computed fields are excluded. To also include all computed fields -pass `include_computed=True` to the decorator. +### After (First-class) + +```python +@strawberry.pydantic.type +class User(BaseModel): + name: str + age: int +``` + +## Complete Example ```python +from pydantic import BaseModel, Field, validator +from typing import List, Optional import strawberry -from .models import User +@strawberry.pydantic.type +class User(BaseModel): + id: int + name: str = Field(description="The user's full name") + email: str + age: int = Field(ge=0, description="The user's age in years") + is_active: bool = True + tags: List[str] = Field(default_factory=list) -@strawberry.experimental.pydantic.type( - model=User, all_fields=True, include_computed=True -) + +@strawberry.pydantic.input +class CreateUserInput(BaseModel): + name: str + email: str + age: int + tags: Optional[List[str]] = None + + @validator("age") + def validate_age(cls, v): + if v < 0: + raise ValueError("Age must be non-negative") + return v + + +@strawberry.type +class Query: + @strawberry.field + def get_user(self, id: int) -> Optional[User]: + return User( + id=id, + name="John Doe", + email="john@example.com", + age=30, + tags=["developer", "python"], + ) + + +@strawberry.type +class Mutation: + @strawberry.field + def create_user(self, input: CreateUserInput) -> User: + return User( + id=1, + name=input.name, + email=input.email, + age=input.age, + tags=input.tags or [], + ) + + +schema = strawberry.Schema(query=Query, mutation=Mutation) +``` + +--- + +# Experimental Pydantic Support (Deprecated) + +The experimental Pydantic integration is deprecated in favor of the first-class +support above. The experimental integration will be removed in a future version. + +## Experimental Usage + +The experimental integration required creating separate wrapper classes: + +```python +from strawberry.experimental.pydantic import type as pydantic_type + + +class UserModel(BaseModel): + id: int + name: str + signup_ts: Optional[datetime] = None + friends: List[int] = [] + + +@pydantic_type(model=UserModel) +class UserType: + id: strawberry.auto + name: strawberry.auto + friends: strawberry.auto + + +# Or include all fields +@pydantic_type(model=UserModel, all_fields=True) class UserType: pass ``` -## Input types +### Input types Input types are similar to types; we can create one by using the `strawberry.experimental.pydantic.input` decorator: diff --git a/strawberry/__init__.py b/strawberry/__init__.py index 3cedd7c9b8..db51564877 100644 --- a/strawberry/__init__.py +++ b/strawberry/__init__.py @@ -4,7 +4,7 @@ specification and allow for a more natural way of defining GraphQL schemas. """ -from . import experimental, federation, relay +from . import experimental, federation, pydantic, relay from .directive import directive, directive_field from .parent import Parent from .permission import BasePermission @@ -54,6 +54,7 @@ "interface", "lazy", "mutation", + "pydantic", "relay", "scalar", "schema_directive", diff --git a/strawberry/experimental/pydantic/_compat.py b/strawberry/experimental/pydantic/_compat.py index baf4ea14a1..9fce2d594a 100644 --- a/strawberry/experimental/pydantic/_compat.py +++ b/strawberry/experimental/pydantic/_compat.py @@ -8,7 +8,6 @@ import pydantic from pydantic import BaseModel from pydantic.version import VERSION as PYDANTIC_VERSION - from strawberry.experimental.pydantic.exceptions import UnsupportedTypeError if TYPE_CHECKING: diff --git a/strawberry/experimental/pydantic/error_type.py b/strawberry/experimental/pydantic/error_type.py index 63727720ed..f6bb4c2b86 100644 --- a/strawberry/experimental/pydantic/error_type.py +++ b/strawberry/experimental/pydantic/error_type.py @@ -12,7 +12,6 @@ ) from pydantic import BaseModel - from strawberry.experimental.pydantic._compat import ( CompatModelField, PydanticCompat, diff --git a/strawberry/experimental/pydantic/fields.py b/strawberry/experimental/pydantic/fields.py index 447dcd9e6a..6122386329 100644 --- a/strawberry/experimental/pydantic/fields.py +++ b/strawberry/experimental/pydantic/fields.py @@ -2,7 +2,6 @@ from typing import Annotated, Any, Union from pydantic import BaseModel - from strawberry.experimental.pydantic._compat import ( PydanticCompat, get_args, diff --git a/strawberry/experimental/pydantic/utils.py b/strawberry/experimental/pydantic/utils.py index adaef2ea13..565217a028 100644 --- a/strawberry/experimental/pydantic/utils.py +++ b/strawberry/experimental/pydantic/utils.py @@ -10,7 +10,6 @@ ) from pydantic import BaseModel - from strawberry.experimental.pydantic._compat import ( CompatModelField, PydanticCompat, diff --git a/strawberry/pydantic/__init__.py b/strawberry/pydantic/__init__.py new file mode 100644 index 0000000000..ea5bd7f81a --- /dev/null +++ b/strawberry/pydantic/__init__.py @@ -0,0 +1,22 @@ +"""Strawberry Pydantic integration. + +This module provides first-class support for Pydantic models in Strawberry GraphQL. +You can directly decorate Pydantic BaseModel classes to create GraphQL types. + +Example: + @strawberry.pydantic.type + class User(BaseModel): + name: str + age: int +""" + +from .error import Error +from .object_type import input as input_decorator +from .object_type import interface +from .object_type import type as type_decorator + +# Re-export with proper names +input = input_decorator +type = type_decorator + +__all__ = ["Error", "input", "interface", "type"] diff --git a/strawberry/pydantic/error.py b/strawberry/pydantic/error.py new file mode 100644 index 0000000000..77223e0335 --- /dev/null +++ b/strawberry/pydantic/error.py @@ -0,0 +1,51 @@ +"""Generic error type for Pydantic validation errors in Strawberry GraphQL. + +This module provides a generic Error type that can be used to represent +Pydantic validation errors in GraphQL responses. +""" + +from __future__ import annotations + +from typing import TYPE_CHECKING + +from strawberry.types.object_type import type as strawberry_type + +if TYPE_CHECKING: + from pydantic import ValidationError + + +@strawberry_type +class ErrorDetail: + """Represents a single validation error detail.""" + + type: str + loc: list[str] + msg: str + + +@strawberry_type +class Error: + """Generic error type for Pydantic validation errors.""" + + errors: list[ErrorDetail] + + @staticmethod + def from_validation_error(exc: ValidationError) -> Error: + """Create an Error instance from a Pydantic ValidationError. + + Args: + exc: The Pydantic ValidationError to convert + + Returns: + An Error instance containing all validation errors + """ + return Error( + errors=[ + ErrorDetail( + type=error["type"], + loc=[str(loc) for loc in error["loc"]], + msg=error["msg"], + ) + for error in exc.errors() + ] + ) diff --git a/strawberry/pydantic/exceptions.py b/strawberry/pydantic/exceptions.py new file mode 100644 index 0000000000..0fa71ea882 --- /dev/null +++ b/strawberry/pydantic/exceptions.py @@ -0,0 +1,15 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from pydantic import BaseModel + + +class UnregisteredTypeException(Exception): + def __init__(self, type: type[BaseModel]) -> None: + message = ( + f"Cannot find a Strawberry Type for {type} did you forget to register it?" + ) + + super().__init__(message) diff --git a/strawberry/pydantic/fields.py b/strawberry/pydantic/fields.py new file mode 100644 index 0000000000..d10ca0d729 --- /dev/null +++ b/strawberry/pydantic/fields.py @@ -0,0 +1,217 @@ +"""Field processing utilities for Pydantic models in Strawberry GraphQL. + +This module provides functions to extract and process fields from Pydantic BaseModel +classes, converting them to StrawberryField instances that can be used in GraphQL schemas. +""" + +from __future__ import annotations + +import sys +from typing import TYPE_CHECKING, Any, get_args, get_origin +from typing import Union as TypingUnion +from typing import _GenericAlias as TypingGenericAlias + +from strawberry.annotation import StrawberryAnnotation +from strawberry.experimental.pydantic._compat import PydanticCompat +from strawberry.experimental.pydantic.utils import get_default_factory_for_field +from strawberry.types.field import StrawberryField +from strawberry.types.private import is_private +from strawberry.utils.typing import is_union + +from .exceptions import UnregisteredTypeException + +if TYPE_CHECKING: + from pydantic import BaseModel + from pydantic.fields import FieldInfo + +from strawberry.experimental.pydantic._compat import lenient_issubclass + + +def _extract_strawberry_field_from_annotation( + annotation: Any, +) -> StrawberryField | None: + """Extract StrawberryField from an Annotated type annotation. + + Args: + annotation: The type annotation, possibly Annotated[Type, strawberry.field(...)] + + Returns: + StrawberryField instance if found in annotation metadata, None otherwise + """ + # Check if this is an Annotated type + if hasattr(annotation, "__metadata__"): + # Look for StrawberryField in the metadata + for metadata_item in annotation.__metadata__: + if isinstance(metadata_item, StrawberryField): + return metadata_item + + return None + + +def replace_pydantic_types(type_: Any, is_input: bool) -> Any: + """Replace Pydantic types with their Strawberry equivalents for first-class integration.""" + from pydantic import BaseModel + + if lenient_issubclass(type_, BaseModel): + if hasattr(type_, "__strawberry_definition__"): + return type_ + + raise UnregisteredTypeException(type_) + + return type_ + + +def replace_types_recursively( + type_: Any, + is_input: bool, + compat: PydanticCompat, +) -> Any: + """Recursively replace Pydantic types with their Strawberry equivalents.""" + # For now, use a simpler approach similar to the experimental module + basic_type = compat.get_basic_type(type_) + replaced_type = replace_pydantic_types(basic_type, is_input) + + origin = get_origin(type_) + + if not origin or not hasattr(type_, "__args__"): + return replaced_type + + converted = tuple( + replace_types_recursively(t, is_input=is_input, compat=compat) + for t in get_args(replaced_type) + ) + + # Handle special cases for typing generics + if isinstance(replaced_type, TypingGenericAlias): + return TypingGenericAlias(origin, converted) + if is_union(replaced_type): + return TypingUnion[converted] + + # Fallback to origin[converted] for standard generic types + return origin[converted] + + +def get_type_for_field(field: FieldInfo, is_input: bool, compat: PydanticCompat) -> Any: + """Get the GraphQL type for a Pydantic field.""" + return replace_types_recursively(field.outer_type_, is_input, compat=compat) + + +def _get_pydantic_fields( + cls: type[BaseModel], + original_type_annotations: dict[str, type[Any]], + is_input: bool = False, + include_computed: bool = False, +) -> list[StrawberryField]: + """Extract StrawberryFields from a Pydantic BaseModel class. + + This function processes a Pydantic BaseModel and extracts its fields, + converting them to StrawberryField instances that can be used in GraphQL schemas. + All fields from the Pydantic model are included by default, except those marked + with strawberry.Private. + + Fields can be customized using strawberry.field() overrides: + + @strawberry.pydantic.type + class User(BaseModel): + name: str + age: int = strawberry.field(directives=[SomeDirective()]) + + Args: + cls: The Pydantic BaseModel class to extract fields from + original_type_annotations: Type annotations that may override field types + is_input: Whether this is for an input type + include_computed: Whether to include computed fields + + Returns: + List of StrawberryField instances + """ + fields: list[StrawberryField] = [] + + # Get compatibility layer for this model + compat = PydanticCompat.from_model(cls) + + # Extract Pydantic model fields + model_fields = compat.get_model_fields(cls, include_computed=include_computed) + + # Get annotations from the class to check for strawberry.Private and strawberry.field() overrides + existing_annotations = getattr(cls, "__annotations__", {}) + + # Process each field from the Pydantic model + for field_name, pydantic_field in model_fields.items(): + # Check if this field is marked as private or has strawberry.field() metadata + strawberry_override = None + if field_name in existing_annotations: + field_annotation = existing_annotations[field_name] + + # Skip private fields - they shouldn't be included in GraphQL schema + if is_private(field_annotation): + continue + + # Check for strawberry.field() in Annotated metadata + strawberry_override = _extract_strawberry_field_from_annotation( + field_annotation + ) + + # Get the field type from the Pydantic model + field_type = get_type_for_field(pydantic_field, is_input, compat=compat) + + # Start with values from Pydantic field + graphql_name = pydantic_field.alias if pydantic_field.has_alias else None + description = pydantic_field.description + directives = [] + permission_classes = [] + extensions = [] + deprecation_reason = None + + # If there's a strawberry.field() override, merge its values + if strawberry_override: + # strawberry.field() overrides take precedence for GraphQL-specific settings + if strawberry_override.graphql_name is not None: + graphql_name = strawberry_override.graphql_name + if strawberry_override.description is not None: + description = strawberry_override.description + if strawberry_override.directives: + directives = list(strawberry_override.directives) + if strawberry_override.permission_classes: + permission_classes = list(strawberry_override.permission_classes) + if strawberry_override.extensions: + extensions = list(strawberry_override.extensions) + if strawberry_override.deprecation_reason is not None: + deprecation_reason = strawberry_override.deprecation_reason + + strawberry_field = StrawberryField( + python_name=field_name, + graphql_name=graphql_name, + type_annotation=StrawberryAnnotation.from_annotation(field_type), + description=description, + default_factory=get_default_factory_for_field( + pydantic_field, compat=compat + ), + directives=directives, + permission_classes=permission_classes, + extensions=extensions, + deprecation_reason=deprecation_reason, + ) + + # Set the origin module for proper type resolution + origin = cls + module = sys.modules[origin.__module__] + + if ( + isinstance(strawberry_field.type_annotation, StrawberryAnnotation) + and strawberry_field.type_annotation.namespace is None + ): + strawberry_field.type_annotation.namespace = module.__dict__ + + strawberry_field.origin = origin + + fields.append(strawberry_field) + + return fields + + +__all__ = [ + "_get_pydantic_fields", + "replace_pydantic_types", + "replace_types_recursively", +] diff --git a/strawberry/pydantic/object_type.py b/strawberry/pydantic/object_type.py new file mode 100644 index 0000000000..e7cf2ee30c --- /dev/null +++ b/strawberry/pydantic/object_type.py @@ -0,0 +1,315 @@ +"""Object type decorators for Pydantic models in Strawberry GraphQL. + +This module provides decorators to convert Pydantic BaseModel classes directly +into GraphQL types, inputs, and interfaces without requiring a separate wrapper class. +""" + +from __future__ import annotations + +from typing import TYPE_CHECKING, Any, Callable, Optional, Union, overload + +if TYPE_CHECKING: + import builtins + from collections.abc import Sequence + +from strawberry.types.base import StrawberryObjectDefinition +from strawberry.types.cast import get_strawberry_type_cast +from strawberry.utils.str_converters import to_camel_case + +from .fields import _get_pydantic_fields + +if TYPE_CHECKING: + from graphql import GraphQLResolveInfo + + from pydantic import BaseModel + + +def _get_interfaces(cls: builtins.type[Any]) -> list[StrawberryObjectDefinition]: + """Extract interfaces from a class's inheritance hierarchy.""" + interfaces: list[StrawberryObjectDefinition] = [] + + for base in cls.__mro__[1:]: # Exclude current class + if hasattr(base, "__strawberry_definition__"): + type_definition = base.__strawberry_definition__ + if type_definition.is_interface: + interfaces.append(type_definition) + + return interfaces + + +def _process_pydantic_type( + cls: type[BaseModel], + *, + name: Optional[str] = None, + is_input: bool = False, + is_interface: bool = False, + description: Optional[str] = None, + directives: Optional[Sequence[object]] = (), + include_computed: bool = False, +) -> type[BaseModel]: + """Process a Pydantic BaseModel class and add GraphQL metadata. + + Args: + cls: The Pydantic BaseModel class to process + name: The GraphQL type name (defaults to class name) + is_input: Whether this is an input type + is_interface: Whether this is an interface type + description: The GraphQL type description + directives: GraphQL directives to apply + include_computed: Whether to include computed fields + + Returns: + The processed BaseModel class with GraphQL metadata + """ + # Get the GraphQL type name + name = name or to_camel_case(cls.__name__) + + # Extract fields using our custom function + # All fields from the Pydantic model are included by default, except strawberry.Private fields + fields = _get_pydantic_fields( + cls=cls, + original_type_annotations={}, + is_input=is_input, + include_computed=include_computed, + ) + + # Get interfaces from inheritance hierarchy + interfaces = _get_interfaces(cls) + + # Create the is_type_of method for proper type resolution + def is_type_of(obj: Any, _info: GraphQLResolveInfo) -> bool: + if (type_cast := get_strawberry_type_cast(obj)) is not None: + return type_cast is cls + return isinstance(obj, cls) + + # Create the GraphQL type definition + cls.__strawberry_definition__ = StrawberryObjectDefinition( # type: ignore + name=name, + is_input=is_input, + is_interface=is_interface, + interfaces=interfaces, + description=description, + directives=directives, + origin=cls, + extend=False, + fields=fields, + is_type_of=is_type_of, + resolve_type=getattr(cls, "resolve_type", None), + ) + + # Add the is_type_of method to the class for testing purposes + cls.is_type_of = is_type_of # type: ignore + + return cls + + +@overload +def type( + cls: type[BaseModel], + *, + name: Optional[str] = None, + description: Optional[str] = None, + directives: Optional[Sequence[object]] = (), + include_computed: bool = False, +) -> type[BaseModel]: ... + + +@overload +def type( + *, + name: Optional[str] = None, + description: Optional[str] = None, + directives: Optional[Sequence[object]] = (), + include_computed: bool = False, +) -> Callable[[type[BaseModel]], type[BaseModel]]: ... + + +def type( + cls: Optional[type[BaseModel]] = None, + *, + name: Optional[str] = None, + description: Optional[str] = None, + directives: Optional[Sequence[object]] = (), + include_computed: bool = False, +) -> Union[type[BaseModel], Callable[[type[BaseModel]], type[BaseModel]]]: + """Decorator to convert a Pydantic BaseModel directly into a GraphQL type. + + This decorator allows you to use Pydantic models directly as GraphQL types + without needing to create a separate wrapper class. + + Args: + cls: The Pydantic BaseModel class to convert + name: The GraphQL type name (defaults to class name) + description: The GraphQL type description + directives: GraphQL directives to apply to the type + include_computed: Whether to include computed fields + + Returns: + The decorated BaseModel class with GraphQL metadata + + Example: + @strawberry.pydantic.type + class User(BaseModel): + name: str + age: int + + # All fields from the Pydantic model will be included in the GraphQL type + + # You can also use strawberry.field() for field-level customization: + @strawberry.pydantic.type + class User(BaseModel): + name: str + age: int = strawberry.field(directives=[SomeDirective()]) + """ + + def wrap(cls: type[BaseModel]) -> type[BaseModel]: + return _process_pydantic_type( + cls, + name=name, + is_input=False, + is_interface=False, + description=description, + directives=directives, + include_computed=include_computed, + ) + + if cls is None: + return wrap + + return wrap(cls) + + +@overload +def input( + cls: type[BaseModel], + *, + name: Optional[str] = None, + description: Optional[str] = None, + directives: Optional[Sequence[object]] = (), +) -> type[BaseModel]: ... + + +@overload +def input( + *, + name: Optional[str] = None, + description: Optional[str] = None, + directives: Optional[Sequence[object]] = (), +) -> Callable[[type[BaseModel]], type[BaseModel]]: ... + + +def input( + cls: Optional[type[BaseModel]] = None, + *, + name: Optional[str] = None, + description: Optional[str] = None, + directives: Optional[Sequence[object]] = (), +) -> Union[type[BaseModel], Callable[[type[BaseModel]], type[BaseModel]]]: + """Decorator to convert a Pydantic BaseModel directly into a GraphQL input type. + + This decorator allows you to use Pydantic models directly as GraphQL input types + without needing to create a separate wrapper class. + + Args: + cls: The Pydantic BaseModel class to convert + name: The GraphQL input type name (defaults to class name) + description: The GraphQL input type description + directives: GraphQL directives to apply to the input type + + Returns: + The decorated BaseModel class with GraphQL input metadata + + Example: + @strawberry.pydantic.input + class CreateUserInput(BaseModel): + name: str + age: int + + # All fields from the Pydantic model will be included in the GraphQL input type + """ + + def wrap(cls: type[BaseModel]) -> type[BaseModel]: + return _process_pydantic_type( + cls, + name=name, + is_input=True, + is_interface=False, + description=description, + directives=directives, + include_computed=False, # Input types don't need computed fields + ) + + if cls is None: + return wrap + + return wrap(cls) + + +@overload +def interface( + cls: type[BaseModel], + *, + name: Optional[str] = None, + description: Optional[str] = None, + directives: Optional[Sequence[object]] = (), + include_computed: bool = False, +) -> type[BaseModel]: ... + + +@overload +def interface( + *, + name: Optional[str] = None, + description: Optional[str] = None, + directives: Optional[Sequence[object]] = (), + include_computed: bool = False, +) -> Callable[[type[BaseModel]], type[BaseModel]]: ... + + +def interface( + cls: Optional[type[BaseModel]] = None, + *, + name: Optional[str] = None, + description: Optional[str] = None, + directives: Optional[Sequence[object]] = (), + include_computed: bool = False, +) -> Union[type[BaseModel], Callable[[type[BaseModel]], type[BaseModel]]]: + """Decorator to convert a Pydantic BaseModel directly into a GraphQL interface. + + This decorator allows you to use Pydantic models directly as GraphQL interfaces + without needing to create a separate wrapper class. + + Args: + cls: The Pydantic BaseModel class to convert + name: The GraphQL interface name (defaults to class name) + description: The GraphQL interface description + directives: GraphQL directives to apply to the interface + include_computed: Whether to include computed fields + + Returns: + The decorated BaseModel class with GraphQL interface metadata + + Example: + @strawberry.pydantic.interface + class Node(BaseModel): + id: str + """ + + def wrap(cls: type[BaseModel]) -> type[BaseModel]: + return _process_pydantic_type( + cls, + name=name, + is_input=False, + is_interface=True, + description=description, + directives=directives, + include_computed=include_computed, + ) + + if cls is None: + return wrap + + return wrap(cls) + + +__all__ = ["input", "interface", "type"] diff --git a/strawberry/schema/schema_converter.py b/strawberry/schema/schema_converter.py index 70865604d0..109b014e58 100644 --- a/strawberry/schema/schema_converter.py +++ b/strawberry/schema/schema_converter.py @@ -1,5 +1,6 @@ from __future__ import annotations +import contextlib import dataclasses import sys from functools import partial, reduce @@ -731,14 +732,27 @@ def extension_resolver( ) -> Any: # parse field arguments into Strawberry input types and convert # field names to Python equivalents - field_args, field_kwargs = get_arguments( - field=field, - source=_source, - info=info, - kwargs=kwargs, - config=self.config, - scalar_registry=self.scalar_registry, - ) + try: + field_args, field_kwargs = get_arguments( + field=field, + source=_source, + info=info, + kwargs=kwargs, + config=self.config, + scalar_registry=self.scalar_registry, + ) + except Exception as exc: + # Try to import Pydantic ValidationError + with contextlib.suppress(ImportError): + from pydantic import ValidationError + + if isinstance( + exc, ValidationError + ) and self._should_convert_validation_error(field): + from strawberry.pydantic import Error + + return Error.from_validation_error(exc) + raise resolver_requested_info = False if "info" in field_kwargs: @@ -799,6 +813,22 @@ async def _async_resolver( _resolver._is_default = not field.base_resolver # type: ignore return _resolver + def _should_convert_validation_error(self, field: StrawberryField) -> bool: + """Check if field return type is a Union containing strawberry.pydantic.Error.""" + from strawberry.types.union import StrawberryUnion + + field_type = field.type + if isinstance(field_type, StrawberryUnion): + # Import Error dynamically to avoid circular imports + try: + from strawberry.pydantic import Error + + return any(union_type is Error for union_type in field_type.types) + except ImportError: + # If strawberry.pydantic doesn't exist or Error isn't available + return False + return False + def from_scalar(self, scalar: type) -> GraphQLScalarType: from strawberry.relay.types import GlobalID diff --git a/tests/experimental/pydantic/schema/test_basic.py b/tests/experimental/pydantic/schema/test_basic.py index 91e6d7317e..b34b69b1d7 100644 --- a/tests/experimental/pydantic/schema/test_basic.py +++ b/tests/experimental/pydantic/schema/test_basic.py @@ -3,7 +3,6 @@ from typing import Optional, Union import pydantic - import strawberry from tests.experimental.pydantic.utils import needs_pydantic_v1 diff --git a/tests/experimental/pydantic/schema/test_computed.py b/tests/experimental/pydantic/schema/test_computed.py index 63e43b03fb..35e21e7d9f 100644 --- a/tests/experimental/pydantic/schema/test_computed.py +++ b/tests/experimental/pydantic/schema/test_computed.py @@ -1,10 +1,10 @@ import textwrap -import pydantic import pytest -from pydantic.version import VERSION as PYDANTIC_VERSION +import pydantic import strawberry +from pydantic.version import VERSION as PYDANTIC_VERSION IS_PYDANTIC_V2: bool = PYDANTIC_VERSION.startswith("2.") diff --git a/tests/experimental/pydantic/schema/test_defaults.py b/tests/experimental/pydantic/schema/test_defaults.py index be761ae87c..1d85f460e0 100644 --- a/tests/experimental/pydantic/schema/test_defaults.py +++ b/tests/experimental/pydantic/schema/test_defaults.py @@ -2,7 +2,6 @@ from typing import Optional import pydantic - import strawberry from strawberry.printer import print_schema from tests.conftest import skip_if_gql_32 diff --git a/tests/experimental/pydantic/schema/test_federation.py b/tests/experimental/pydantic/schema/test_federation.py index db94a8e336..925fdb8f71 100644 --- a/tests/experimental/pydantic/schema/test_federation.py +++ b/tests/experimental/pydantic/schema/test_federation.py @@ -1,8 +1,7 @@ import typing -from pydantic import BaseModel - import strawberry +from pydantic import BaseModel from strawberry.federation.schema_directives import Key diff --git a/tests/experimental/pydantic/schema/test_forward_reference.py b/tests/experimental/pydantic/schema/test_forward_reference.py index ebc94d4b37..23ad750b51 100644 --- a/tests/experimental/pydantic/schema/test_forward_reference.py +++ b/tests/experimental/pydantic/schema/test_forward_reference.py @@ -4,7 +4,6 @@ from typing import Optional import pydantic - import strawberry diff --git a/tests/experimental/pydantic/schema/test_mutation.py b/tests/experimental/pydantic/schema/test_mutation.py index d225d75523..356bbd1438 100644 --- a/tests/experimental/pydantic/schema/test_mutation.py +++ b/tests/experimental/pydantic/schema/test_mutation.py @@ -1,9 +1,9 @@ from typing import Union import pydantic - import strawberry from strawberry.experimental.pydantic._compat import IS_PYDANTIC_V2 +from strawberry.pydantic import Error def test_mutation(): @@ -157,20 +157,14 @@ def create_user(self, input: CreateUserInput) -> UserType: def test_mutation_with_validation_and_error_type(): - class User(pydantic.BaseModel): + # Use the new first-class Pydantic support with automatic validation + @strawberry.pydantic.input + class CreateUserInput(pydantic.BaseModel): name: pydantic.constr(min_length=2) - @strawberry.experimental.pydantic.input(User) - class CreateUserInput: - name: strawberry.auto - - @strawberry.experimental.pydantic.type(User) - class UserType: - name: strawberry.auto - - @strawberry.experimental.pydantic.error_type(User) - class UserError: - name: strawberry.auto + @strawberry.pydantic.type + class UserType(pydantic.BaseModel): + name: str @strawberry.type class Query: @@ -179,19 +173,10 @@ class Query: @strawberry.type class Mutation: @strawberry.mutation - def create_user(self, input: CreateUserInput) -> Union[UserType, UserError]: - try: - data = input.to_pydantic() - except pydantic.ValidationError as e: - args: dict[str, list[str]] = {} - for error in e.errors(): - field = error["loc"][0] # currently doesn't support nested errors - field_errors = args.get(field, []) - field_errors.append(error["msg"]) - args[field] = field_errors - return UserError(**args) - else: - return UserType(name=data.name) + def create_user(self, input: CreateUserInput) -> Union[UserType, Error]: + # If we get here, validation passed + # Convert to UserType with valid data + return UserType(name=input.name) schema = strawberry.Schema(query=Query, mutation=Mutation) @@ -201,8 +186,12 @@ def create_user(self, input: CreateUserInput) -> Union[UserType, UserError]: ... on UserType { name } - ... on UserError { - nameErrors: name + ... on Error { + errors { + type + loc + msg + } } } } @@ -210,14 +199,18 @@ def create_user(self, input: CreateUserInput) -> Union[UserType, UserError]: result = schema.execute_sync(query) - assert result.errors is None + assert result.errors is None # No GraphQL errors assert result.data["createUser"].get("name") is None + # Check that validation error was converted to Error type + assert len(result.data["createUser"]["errors"]) == 1 + assert result.data["createUser"]["errors"][0]["type"] == "string_too_short" + assert result.data["createUser"]["errors"][0]["loc"] == ["name"] + if IS_PYDANTIC_V2: - assert result.data["createUser"]["nameErrors"] == [ - ("String should have at least 2 characters") - ] + assert "at least 2 characters" in result.data["createUser"]["errors"][0]["msg"] else: - assert result.data["createUser"]["nameErrors"] == [ - ("ensure this value has at least 2 characters") - ] + assert ( + "ensure this value has at least 2 characters" + in result.data["createUser"]["errors"][0]["msg"] + ) diff --git a/tests/experimental/pydantic/test_basic.py b/tests/experimental/pydantic/test_basic.py index 99fd440042..4e12c271f5 100644 --- a/tests/experimental/pydantic/test_basic.py +++ b/tests/experimental/pydantic/test_basic.py @@ -2,9 +2,9 @@ from enum import Enum from typing import Annotated, Any, Optional, Union -import pydantic import pytest +import pydantic import strawberry from strawberry.experimental.pydantic.exceptions import MissingFieldsListError from strawberry.schema_directive import Location diff --git a/tests/experimental/pydantic/test_conversion.py b/tests/experimental/pydantic/test_conversion.py index 8216f4f038..9fbdf0e687 100644 --- a/tests/experimental/pydantic/test_conversion.py +++ b/tests/experimental/pydantic/test_conversion.py @@ -6,9 +6,9 @@ from typing import Any, NewType, Optional, TypeVar, Union import pytest -from pydantic import BaseModel, Field, ValidationError import strawberry +from pydantic import BaseModel, Field, ValidationError from strawberry.experimental.pydantic._compat import ( IS_PYDANTIC_V2, CompatModelField, diff --git a/tests/experimental/pydantic/test_error_type.py b/tests/experimental/pydantic/test_error_type.py index ce5893ca02..0b9e34c732 100644 --- a/tests/experimental/pydantic/test_error_type.py +++ b/tests/experimental/pydantic/test_error_type.py @@ -1,8 +1,8 @@ from typing import Optional -import pydantic import pytest +import pydantic import strawberry from strawberry.experimental.pydantic.exceptions import MissingFieldsListError from strawberry.types.base import ( diff --git a/tests/experimental/pydantic/test_fields.py b/tests/experimental/pydantic/test_fields.py index 0184463a37..c6a03d2c27 100644 --- a/tests/experimental/pydantic/test_fields.py +++ b/tests/experimental/pydantic/test_fields.py @@ -1,11 +1,11 @@ import re from typing_extensions import Literal -import pydantic import pytest -from pydantic import BaseModel, ValidationError, conlist +import pydantic import strawberry +from pydantic import BaseModel, ValidationError, conlist from strawberry.experimental.pydantic._compat import IS_PYDANTIC_V1 from strawberry.types.base import StrawberryObjectDefinition, StrawberryOptional from tests.experimental.pydantic.utils import needs_pydantic_v1, needs_pydantic_v2 diff --git a/tests/pydantic/__init__.py b/tests/pydantic/__init__.py new file mode 100644 index 0000000000..e7ba6325e4 --- /dev/null +++ b/tests/pydantic/__init__.py @@ -0,0 +1 @@ +# Test package for Strawberry Pydantic integration diff --git a/tests/pydantic/test_aliases.py b/tests/pydantic/test_aliases.py new file mode 100644 index 0000000000..d760b2847d --- /dev/null +++ b/tests/pydantic/test_aliases.py @@ -0,0 +1,39 @@ +from typing import Annotated + +from inline_snapshot import snapshot + +import pydantic +import strawberry + + +def test_pydantic_field_aliases_in_execution(): + """Test that Pydantic field aliases work in GraphQL execution.""" + + @strawberry.pydantic.type + class User(pydantic.BaseModel): + name: Annotated[str, pydantic.Field(alias="fullName")] + age: Annotated[int, pydantic.Field(alias="yearsOld")] + + @strawberry.type + class Query: + @strawberry.field + def get_user(self) -> User: + # When using aliases, we need to create the User with the aliased field names + return User(fullName="John", yearsOld=30) + + schema = strawberry.Schema(query=Query) + + # Query using the aliased field names + query = """ + query { + getUser { + fullName + yearsOld + } + } + """ + + result = schema.execute_sync(query) + + assert not result.errors + assert result.data == snapshot({"getUser": {"fullName": "John", "yearsOld": 30}}) diff --git a/tests/pydantic/test_description.py b/tests/pydantic/test_description.py new file mode 100644 index 0000000000..662acc2844 --- /dev/null +++ b/tests/pydantic/test_description.py @@ -0,0 +1,26 @@ +from typing import Annotated + +import pydantic +import strawberry + + +def test_pydantic_field_descriptions_in_schema(): + """Test that Pydantic field descriptions appear in the schema.""" + + @strawberry.pydantic.type + class User(pydantic.BaseModel): + name: Annotated[str, pydantic.Field(description="The user's full name")] + age: Annotated[int, pydantic.Field(description="The user's age in years")] + + @strawberry.type + class Query: + @strawberry.field + def get_user(self) -> User: + return User(name="John", age=30) + + schema = strawberry.Schema(query=Query) + + # Check that the schema includes field descriptions + schema_str = str(schema) + assert "The user's full name" in schema_str + assert "The user's age in years" in schema_str diff --git a/tests/pydantic/test_error.py b/tests/pydantic/test_error.py new file mode 100644 index 0000000000..a60cce7107 --- /dev/null +++ b/tests/pydantic/test_error.py @@ -0,0 +1,231 @@ +"""Tests for the generic Pydantic Error type.""" + +from typing import Union + +from inline_snapshot import snapshot + +import pydantic +import strawberry +from strawberry.pydantic import Error + + +def test_error_type_from_validation_error(): + """Test creating Error from ValidationError.""" + + class UserInput(pydantic.BaseModel): + name: pydantic.constr(min_length=2) + age: pydantic.conint(ge=0) + + # Test with multiple validation errors + try: + UserInput(name="A", age=-5) + except pydantic.ValidationError as e: + error = Error.from_validation_error(e) + + assert len(error.errors) == 2 + + # Check first error (name) + assert error.errors[0].type == "string_too_short" + assert error.errors[0].loc == ["name"] + assert "at least 2 characters" in error.errors[0].msg + + # Check second error (age) + assert error.errors[1].type == "greater_than_equal" + assert error.errors[1].loc == ["age"] + assert "greater than or equal to 0" in error.errors[1].msg + + +def test_error_type_with_nested_fields(): + """Test Error type with nested field validation errors.""" + + class AddressInput(pydantic.BaseModel): + street: pydantic.constr(min_length=5) + city: str + zip_code: pydantic.constr(pattern=r"^\d{5}$") + + class UserInput(pydantic.BaseModel): + name: str + address: AddressInput + + try: + UserInput( + name="John", + address={"street": "Oak", "city": "NYC", "zip_code": "ABC"}, + ) + except pydantic.ValidationError as e: + error = Error.from_validation_error(e) + + assert len(error.errors) == 2 + + # Check nested street error + assert error.errors[0].type == "string_too_short" + assert error.errors[0].loc == ["address", "street"] + assert "at least 5 characters" in error.errors[0].msg + + # Check nested zip_code error + assert error.errors[1].type == "string_pattern_mismatch" + assert error.errors[1].loc == ["address", "zip_code"] + + +def test_error_in_mutation_with_union_return(): + """Test using Error in a mutation with union return type.""" + + # Use @strawberry.pydantic.input for automatic validation + @strawberry.pydantic.input + class CreateUserInput(pydantic.BaseModel): + name: pydantic.constr(min_length=2) + age: pydantic.conint(ge=0, le=120) + + @strawberry.type + class CreateUserSuccess: + user_id: int + message: str + + @strawberry.type + class Mutation: + @strawberry.mutation + def create_user( + self, input: CreateUserInput + ) -> Union[CreateUserSuccess, Error]: + # If we get here, validation passed + return CreateUserSuccess( + user_id=1, message=f"User {input.name} created successfully" + ) + + @strawberry.type + class Query: + dummy: str = "dummy" + + schema = strawberry.Schema(query=Query, mutation=Mutation) + + # Test successful creation + result = schema.execute_sync( + """ + mutation { + createUser(input: { name: "John", age: 30 }) { + ... on CreateUserSuccess { + userId + message + } + ... on Error { + errors { + type + loc + msg + } + } + } + } + """ + ) + + assert not result.errors + assert result.data["createUser"]["userId"] == 1 + assert result.data["createUser"]["message"] == "User John created successfully" + + # Test validation error + result = schema.execute_sync( + """ + mutation { + createUser(input: { name: "J", age: -5 }) { + ... on CreateUserSuccess { + userId + message + } + ... on Error { + errors { + type + loc + msg + } + } + } + } + """ + ) + + assert ( + not result.errors + ) # No GraphQL errors, validation errors are converted to Error type + assert len(result.data["createUser"]["errors"]) == 2 + + # Check first error + assert result.data["createUser"]["errors"][0]["type"] == "string_too_short" + assert result.data["createUser"]["errors"][0]["loc"] == ["name"] + assert "at least 2 characters" in result.data["createUser"]["errors"][0]["msg"] + + # Check second error + assert result.data["createUser"]["errors"][1]["type"] == "greater_than_equal" + assert result.data["createUser"]["errors"][1]["loc"] == ["age"] + + +def test_error_graphql_schema(): + """Test that Error generates correct GraphQL schema.""" + + @strawberry.type + class Query: + @strawberry.field + def test_error(self) -> Error: + # Dummy resolver + return Error(errors=[]) + + schema = strawberry.Schema(query=Query) + + assert str(schema) == snapshot( + """\ +type Error { + errors: [ErrorDetail!]! +} + +type ErrorDetail { + type: String! + loc: [String!]! + msg: String! +} + +type Query { + testError: Error! +}\ +""" + ) + + +def test_error_with_single_validation_error(): + """Test Error type with a single validation error.""" + + class EmailInput(pydantic.BaseModel): + email: pydantic.EmailStr + + try: + EmailInput(email="not-an-email") + except pydantic.ValidationError as e: + error = Error.from_validation_error(e) + + assert len(error.errors) == 1 + assert error.errors[0].type in [ + "value_error", + "email", + ] # Depends on Pydantic version + assert error.errors[0].loc == ["email"] + assert "email" in error.errors[0].msg.lower() + + +def test_error_with_list_field_validation(): + """Test Error type with validation errors in list fields.""" + + class TagsInput(pydantic.BaseModel): + tags: list[pydantic.constr(min_length=2)] + + try: + TagsInput(tags=["ok", "a", "good", "b"]) + except pydantic.ValidationError as e: + error = Error.from_validation_error(e) + + assert len(error.errors) == 2 + + # Check errors for short tags + assert error.errors[0].type == "string_too_short" + assert error.errors[0].loc == ["tags", "1"] # Index 1 is "a" + + assert error.errors[1].type == "string_too_short" + assert error.errors[1].loc == ["tags", "3"] # Index 3 is "b" diff --git a/tests/pydantic/test_error_with_pydantic_input.py b/tests/pydantic/test_error_with_pydantic_input.py new file mode 100644 index 0000000000..2e171dc5d0 --- /dev/null +++ b/tests/pydantic/test_error_with_pydantic_input.py @@ -0,0 +1,211 @@ +"""Test Pydantic validation error handling with @strawberry.pydantic.input.""" + +from typing import Union + +from inline_snapshot import snapshot + +import pydantic +import strawberry +from strawberry.pydantic import Error + + +def test_pydantic_input_validation_error_converted_to_error(): + """Test that ValidationError from @strawberry.pydantic.input is converted to Error.""" + + @strawberry.pydantic.input + class CreateUserInput(pydantic.BaseModel): + name: pydantic.constr(min_length=2) + age: pydantic.conint(ge=0, le=120) + + @strawberry.type + class CreateUserSuccess: + id: int + message: str + + @strawberry.type + class Mutation: + @strawberry.mutation + def create_user( + self, input: CreateUserInput + ) -> Union[CreateUserSuccess, Error]: + # If we get here, validation passed + return CreateUserSuccess( + id=1, message=f"User {input.name} created successfully" + ) + + @strawberry.type + class Query: + dummy: str = "dummy" + + schema = strawberry.Schema(query=Query, mutation=Mutation) + + # Test successful creation + result = schema.execute_sync( + """ + mutation { + createUser(input: { name: "John", age: 30 }) { + ... on CreateUserSuccess { + id + message + } + ... on Error { + errors { + type + loc + msg + } + } + } + } + """ + ) + + assert not result.errors + assert result.data["createUser"]["id"] == 1 + assert result.data["createUser"]["message"] == "User John created successfully" + + # Test validation error - should be converted to Error type + result = schema.execute_sync( + """ + mutation { + createUser(input: { name: "J", age: -5 }) { + ... on CreateUserSuccess { + id + message + } + ... on Error { + errors { + type + loc + msg + } + } + } + } + """ + ) + + assert not result.errors # No GraphQL errors + assert result.data == snapshot( + { + "createUser": { + "errors": [ + { + "type": "string_too_short", + "loc": ["name"], + "msg": "String should have at least 2 characters", + }, + { + "type": "greater_than_equal", + "loc": ["age"], + "msg": "Input should be greater than or equal to 0", + }, + ] + } + } + ) + + +def test_pydantic_input_validation_error_without_error_in_union(): + """Test that ValidationError is still raised if Error is not in the return type.""" + + @strawberry.pydantic.input + class CreateUserInput(pydantic.BaseModel): + name: pydantic.constr(min_length=2) + age: pydantic.conint(ge=0) + + @strawberry.type + class CreateUserSuccess: + id: int + message: str + + @strawberry.type + class Mutation: + @strawberry.mutation + def create_user(self, input: CreateUserInput) -> CreateUserSuccess: + # If we get here, validation passed + return CreateUserSuccess( + id=1, message=f"User {input.name} created successfully" + ) + + @strawberry.type + class Query: + dummy: str = "dummy" + + schema = strawberry.Schema(query=Query, mutation=Mutation) + + # Test validation error - should raise GraphQL error + result = schema.execute_sync( + """ + mutation { + createUser(input: { name: "J", age: -5 }) { + id + message + } + } + """ + ) + + assert result.errors + assert len(result.errors) == 1 + assert "validation error" in result.errors[0].message.lower() + + +def test_graphql_schema_with_pydantic_input(): + """Test that the GraphQL schema is correct with Pydantic input.""" + + @strawberry.pydantic.input + class UserInput(pydantic.BaseModel): + name: str + age: int + + @strawberry.type + class UserResult: + success: bool + message: str + + @strawberry.type + class Query: + dummy: str = "dummy" + + @strawberry.type + class Mutation: + @strawberry.field + def create_user(self, input: UserInput) -> Union[UserResult, Error]: + return UserResult(success=True, message="ok") + + schema = strawberry.Schema(query=Query, mutation=Mutation) + + assert str(schema) == snapshot( + """\ +type Error { + errors: [ErrorDetail!]! +} + +type ErrorDetail { + type: String! + loc: [String!]! + msg: String! +} + +type Mutation { + createUser(input: UserInput!): UserResultError! +} + +type Query { + dummy: String! +} + +input UserInput { + name: String! + age: Int! +} + +type UserResult { + success: Boolean! + message: String! +} + +union UserResultError = UserResult | Error\ +""" + ) diff --git a/tests/pydantic/test_execution.py b/tests/pydantic/test_execution.py new file mode 100644 index 0000000000..6c3cc297ab --- /dev/null +++ b/tests/pydantic/test_execution.py @@ -0,0 +1,619 @@ +from typing import Annotated, Optional + +import pytest + +import pydantic +import strawberry + + +def test_basic_query_execution(): + """Test basic query execution with Pydantic types.""" + + @strawberry.pydantic.type + class User(pydantic.BaseModel): + name: str + age: int + + @strawberry.type + class Query: + @strawberry.field + def get_user(self) -> User: + return User(name="John", age=30) + + schema = strawberry.Schema(query=Query) + + query = """ + query { + getUser { + name + age + } + } + """ + + result = schema.execute_sync(query) + + assert not result.errors + assert result.data == {"getUser": {"name": "John", "age": 30}} + + +def test_query_with_optional_fields(): + """Test query execution with optional fields.""" + + @strawberry.pydantic.type + class User(pydantic.BaseModel): + name: str + email: Optional[str] = None + age: Optional[int] = None + + @strawberry.type + class Query: + @strawberry.field + def get_user(self) -> User: + return User(name="John", email="john@example.com") + + schema = strawberry.Schema(query=Query) + + query = """ + query { + getUser { + name + email + age + } + } + """ + + result = schema.execute_sync(query) + + assert not result.errors + assert result.data == { + "getUser": {"name": "John", "email": "john@example.com", "age": None} + } + + +def test_mutation_with_input_types(): + """Test mutation execution with Pydantic input types.""" + + @strawberry.pydantic.input + class CreateUserInput(pydantic.BaseModel): + name: str + age: int + email: Optional[str] = None + + @strawberry.pydantic.type + class User(pydantic.BaseModel): + id: int + name: str + age: int + email: Optional[str] = None + + @strawberry.type + class Mutation: + @strawberry.field + def create_user(self, input: CreateUserInput) -> User: + return User(id=1, name=input.name, age=input.age, email=input.email) + + @strawberry.type + class Query: + @strawberry.field + def dummy(self) -> str: + return "dummy" + + schema = strawberry.Schema(query=Query, mutation=Mutation) + + mutation = """ + mutation { + createUser(input: { + name: "Alice" + age: 25 + email: "alice@example.com" + }) { + id + name + age + email + } + } + """ + + result = schema.execute_sync(mutation) + + assert not result.errors + assert result.data == { + "createUser": { + "id": 1, + "name": "Alice", + "age": 25, + "email": "alice@example.com", + } + } + + +def test_mutation_with_partial_input(): + """Test mutation with partial input (optional fields).""" + + @strawberry.pydantic.input + class UpdateUserInput(pydantic.BaseModel): + name: Optional[str] = None + age: Optional[int] = None + + @strawberry.pydantic.type + class User(pydantic.BaseModel): + id: int + name: str + age: int + + @strawberry.type + class Mutation: + @strawberry.field + def update_user(self, id: int, input: UpdateUserInput) -> User: + # Simulate updating a user + return User(id=id, name=input.name or "Default Name", age=input.age or 18) + + @strawberry.type + class Query: + @strawberry.field + def dummy(self) -> str: + return "dummy" + + schema = strawberry.Schema(query=Query, mutation=Mutation) + + mutation = """ + mutation { + updateUser(id: 1, input: { + name: "Updated Name" + }) { + id + name + age + } + } + """ + + result = schema.execute_sync(mutation) + + assert not result.errors + assert result.data == {"updateUser": {"id": 1, "name": "Updated Name", "age": 18}} + + +def test_nested_pydantic_types(): + """Test nested Pydantic types in queries.""" + + @strawberry.pydantic.type + class Address(pydantic.BaseModel): + street: str + city: str + zipcode: str + + @strawberry.pydantic.type + class User(pydantic.BaseModel): + name: str + age: int + address: Address + + @strawberry.type + class Query: + @strawberry.field + def get_user(self) -> User: + return User( + name="John", + age=30, + address=Address(street="123 Main St", city="Anytown", zipcode="12345"), + ) + + schema = strawberry.Schema(query=Query) + + query = """ + query { + getUser { + name + age + address { + street + city + zipcode + } + } + } + """ + + result = schema.execute_sync(query) + + assert not result.errors + assert result.data == { + "getUser": { + "name": "John", + "age": 30, + "address": {"street": "123 Main St", "city": "Anytown", "zipcode": "12345"}, + } + } + + +def test_list_of_pydantic_types(): + """Test lists of Pydantic types.""" + + @strawberry.pydantic.type + class User(pydantic.BaseModel): + name: str + age: int + + @strawberry.type + class Query: + @strawberry.field + def get_users(self) -> list[User]: + return [ + User(name="John", age=30), + User(name="Jane", age=25), + User(name="Bob", age=35), + ] + + schema = strawberry.Schema(query=Query) + + query = """ + query { + getUsers { + name + age + } + } + """ + + result = schema.execute_sync(query) + + assert not result.errors + assert result.data == { + "getUsers": [ + {"name": "John", "age": 30}, + {"name": "Jane", "age": 25}, + {"name": "Bob", "age": 35}, + ] + } + + +def test_pydantic_field_descriptions_in_schema(): + """Test that Pydantic field descriptions appear in the schema.""" + + @strawberry.pydantic.type + class User(pydantic.BaseModel): + name: Annotated[str, pydantic.Field(description="The user's full name")] + age: Annotated[int, pydantic.Field(description="The user's age in years")] + + @strawberry.type + class Query: + @strawberry.field + def get_user(self) -> User: + return User(name="John", age=30) + + schema = strawberry.Schema(query=Query) + + # Check that the schema includes field descriptions + schema_str = str(schema) + assert "The user's full name" in schema_str + assert "The user's age in years" in schema_str + + +def test_pydantic_field_aliases_in_execution(): + """Test that Pydantic field aliases work in GraphQL execution.""" + + @strawberry.pydantic.type + class User(pydantic.BaseModel): + name: Annotated[str, pydantic.Field(alias="fullName")] + age: Annotated[int, pydantic.Field(alias="yearsOld")] + + @strawberry.type + class Query: + @strawberry.field + def get_user(self) -> User: + # When using aliases, we need to create the User with the aliased field names + return User(fullName="John", yearsOld=30) + + schema = strawberry.Schema(query=Query) + + # Query using the aliased field names + query = """ + query { + getUser { + fullName + yearsOld + } + } + """ + + result = schema.execute_sync(query) + + assert not result.errors + assert result.data == {"getUser": {"fullName": "John", "yearsOld": 30}} + + +def test_pydantic_validation_integration(): + """Test that Pydantic validation works with GraphQL inputs.""" + + @strawberry.pydantic.input + class CreateUserInput(pydantic.BaseModel): + name: str + age: int + email: Annotated[str, pydantic.Field(pattern=r"^[^@]+@[^@]+\.[^@]+$")] + + @strawberry.pydantic.type + class User(pydantic.BaseModel): + name: str + age: int + email: str + + @strawberry.type + class Mutation: + @strawberry.field + def create_user(self, input: CreateUserInput) -> User: + return User(name=input.name, age=input.age, email=input.email) + + @strawberry.type + class Query: + @strawberry.field + def dummy(self) -> str: + return "dummy" + + schema = strawberry.Schema(query=Query, mutation=Mutation) + + # Test with valid input + mutation = """ + mutation { + createUser(input: { + name: "Alice" + age: 25 + email: "alice@example.com" + }) { + name + age + email + } + } + """ + + result = schema.execute_sync(mutation) + + assert not result.errors + assert result.data == { + "createUser": {"name": "Alice", "age": 25, "email": "alice@example.com"} + } + + +def test_complex_pydantic_types_execution(): + """Test complex Pydantic types with various field types.""" + + @strawberry.pydantic.type + class Profile(pydantic.BaseModel): + bio: Optional[str] = None + website: Optional[str] = None + + @strawberry.pydantic.type + class User(pydantic.BaseModel): + id: int + name: str + email: str + is_active: bool + tags: list[str] = [] + profile: Optional[Profile] = None + + @strawberry.type + class Query: + @strawberry.field + def get_user(self) -> User: + return User( + id=1, + name="John Doe", + email="john@example.com", + is_active=True, + tags=["developer", "python", "graphql"], + profile=Profile( + bio="Software developer", website="https://johndoe.com" + ), + ) + + schema = strawberry.Schema(query=Query) + + query = """ + query { + getUser { + id + name + email + isActive + tags + profile { + bio + website + } + } + } + """ + + result = schema.execute_sync(query) + + assert not result.errors + assert result.data == { + "getUser": { + "id": 1, + "name": "John Doe", + "email": "john@example.com", + "isActive": True, + "tags": ["developer", "python", "graphql"], + "profile": {"bio": "Software developer", "website": "https://johndoe.com"}, + } + } + + +def test_pydantic_interface_basic(): + """Test basic Pydantic interface functionality.""" + + @strawberry.pydantic.interface + class Node(pydantic.BaseModel): + id: str + + # Interface requires implementing types for proper execution + @strawberry.pydantic.type + class User(pydantic.BaseModel): + id: str + name: str + + @strawberry.type + class Query: + @strawberry.field + def get_user(self) -> User: + return User(id="user_1", name="John") + + schema = strawberry.Schema(query=Query) + + query = """ + query { + getUser { + id + name + } + } + """ + + result = schema.execute_sync(query) + + assert not result.errors + assert result.data == {"getUser": {"id": "user_1", "name": "John"}} + + +def test_error_handling_with_pydantic_validation(): + """Test error handling when Pydantic validation fails.""" + + @strawberry.pydantic.input + class CreateUserInput(pydantic.BaseModel): + name: str + age: int + + @pydantic.validator("age") + def validate_age(cls, v): + if v < 0: + raise ValueError("Age must be non-negative") + return v + + @strawberry.pydantic.type + class User(pydantic.BaseModel): + name: str + age: int + + @strawberry.type + class Mutation: + @strawberry.field + def create_user(self, input: CreateUserInput) -> User: + return User(name=input.name, age=input.age) + + @strawberry.type + class Query: + @strawberry.field + def dummy(self) -> str: + return "dummy" + + schema = strawberry.Schema(query=Query, mutation=Mutation) + + # Test with invalid input (negative age) + mutation = """ + mutation { + createUser(input: { + name: "Alice" + age: -5 + }) { + name + age + } + } + """ + + result = schema.execute_sync(mutation) + + # Should handle validation error gracefully + # The exact error handling depends on Strawberry's error handling implementation + assert result.errors or result.data is None + + +@pytest.mark.asyncio +async def test_async_execution_with_pydantic(): + """Test async execution with Pydantic types.""" + + @strawberry.pydantic.type + class User(pydantic.BaseModel): + name: str + age: int + + @strawberry.type + class Query: + @strawberry.field + async def get_user(self) -> User: + # Simulate async operation + return User(name="John", age=30) + + schema = strawberry.Schema(query=Query) + + query = """ + query { + getUser { + name + age + } + } + """ + + result = await schema.execute(query) + + assert not result.errors + assert result.data == {"getUser": {"name": "John", "age": 30}} + + +def test_strawberry_private_fields_not_in_schema(): + """Test that strawberry.Private fields are not exposed in GraphQL schema.""" + + @strawberry.pydantic.type + class User(pydantic.BaseModel): + id: int + name: str + password: strawberry.Private[str] + + @strawberry.type + class Query: + @strawberry.field + def get_user(self) -> User: + return User(id=1, name="John", password="secret123") + + schema = strawberry.Schema(query=Query) + + # Check that password field is not in the schema + schema_str = str(schema) + assert "password" not in schema_str + assert "id: Int!" in schema_str + assert "name: String!" in schema_str + + # Test that we can query the exposed fields + query = """ + query { + getUser { + id + name + } + } + """ + + result = schema.execute_sync(query) + + assert not result.errors + assert result.data == {"getUser": {"id": 1, "name": "John"}} + + # Test that querying the private field fails + query_with_private = """ + query { + getUser { + id + name + password + } + } + """ + + result = schema.execute_sync(query_with_private) + assert result.errors + assert "Cannot query field 'password'" in str(result.errors[0]) diff --git a/tests/pydantic/test_fields.py b/tests/pydantic/test_fields.py new file mode 100644 index 0000000000..7100f38ba4 --- /dev/null +++ b/tests/pydantic/test_fields.py @@ -0,0 +1,331 @@ +from typing import Annotated + +import pytest +from inline_snapshot import snapshot + +import pydantic +import strawberry +from strawberry.pydantic.exceptions import UnregisteredTypeException +from strawberry.schema_directive import Location +from strawberry.types.base import get_object_definition + + +def test_pydantic_field_descriptions(): + """Test that Pydantic field descriptions are preserved.""" + + @strawberry.pydantic.type + class User(pydantic.BaseModel): + age: Annotated[int, pydantic.Field(description="The user's age")] + name: Annotated[str, pydantic.Field(description="The user's name")] + + definition = get_object_definition(User, strict=True) + + age_field = next(f for f in definition.fields if f.python_name == "age") + name_field = next(f for f in definition.fields if f.python_name == "name") + + assert age_field.description == "The user's age" + assert name_field.description == "The user's name" + + +def test_pydantic_field_aliases(): + """Test that Pydantic field aliases are used as GraphQL names.""" + + @strawberry.pydantic.type + class User(pydantic.BaseModel): + age: Annotated[int, pydantic.Field(alias="userAge")] + name: Annotated[str, pydantic.Field(alias="userName")] + + definition = get_object_definition(User, strict=True) + + age_field = next(f for f in definition.fields if f.python_name == "age") + name_field = next(f for f in definition.fields if f.python_name == "name") + + assert age_field.graphql_name == "userAge" + assert name_field.graphql_name == "userName" + + +def test_can_use_strawberry_types(): + """Test that Pydantic models can use Strawberry types.""" + + @strawberry.type + class Address: + street: str + city: str + + @strawberry.pydantic.type + class User(pydantic.BaseModel): + name: str + address: Address + + definition = get_object_definition(User, strict=True) + + address_field = next(f for f in definition.fields if f.python_name == "address") + + assert address_field.type is Address + + @strawberry.type + class Query: + @strawberry.field + @staticmethod + def user() -> User: + return User( + name="Rabbit", address=Address(street="123 Main St", city="Wonderland") + ) + + schema = strawberry.Schema(query=Query) + + query = """query { + user { + name + address { + street + city + } + } + }""" + + result = schema.execute_sync(query) + + assert not result.errors + assert result.data == snapshot( + { + "user": { + "name": "Rabbit", + "address": {"street": "123 Main St", "city": "Wonderland"}, + } + } + ) + + +def test_all_models_need_to_marked_as_strawberry_types(): + class Address(pydantic.BaseModel): + street: str + city: str + + with pytest.raises( + UnregisteredTypeException, + match=( + r"Cannot find a Strawberry Type for did you forget to register it\?" + ), + ): + + @strawberry.pydantic.type + class User(pydantic.BaseModel): + name: str + address: Address + + +def test_field_directives_basic(): + """Test that strawberry.field() directives work with Pydantic models using Annotated.""" + + @strawberry.schema_directive(locations=[Location.FIELD_DEFINITION]) + class Sensitive: + reason: str + + @strawberry.pydantic.type + class User(pydantic.BaseModel): + name: str + age: Annotated[int, strawberry.field(directives=[Sensitive(reason="PII")])] + + definition = get_object_definition(User, strict=True) + + name_field = next(f for f in definition.fields if f.python_name == "name") + age_field = next(f for f in definition.fields if f.python_name == "age") + + # Name field should have no directives + assert len(name_field.directives) == 0 + + # Age field should have the Sensitive directive + assert len(age_field.directives) == 1 + assert isinstance(age_field.directives[0], Sensitive) + assert age_field.directives[0].reason == "PII" + + +def test_field_directives_multiple(): + """Test multiple directives on a single field.""" + + @strawberry.schema_directive(locations=[Location.FIELD_DEFINITION]) + class Sensitive: + reason: str + + @strawberry.schema_directive(locations=[Location.FIELD_DEFINITION]) + class Tag: + name: str + + @strawberry.pydantic.type + class User(pydantic.BaseModel): + name: str + email: Annotated[ + str, + strawberry.field(directives=[Sensitive(reason="PII"), Tag(name="contact")]), + ] + + definition = get_object_definition(User, strict=True) + + email_field = next(f for f in definition.fields if f.python_name == "email") + + # Email field should have both directives + assert len(email_field.directives) == 2 + + sensitive_directive = next( + d for d in email_field.directives if isinstance(d, Sensitive) + ) + tag_directive = next(d for d in email_field.directives if isinstance(d, Tag)) + + assert sensitive_directive.reason == "PII" + assert tag_directive.name == "contact" + + +def test_field_directives_with_pydantic_features(): + """Test that strawberry.field() directives work alongside Pydantic field features.""" + + @strawberry.schema_directive(locations=[Location.FIELD_DEFINITION]) + class Range: + min: int + max: int + + @strawberry.pydantic.type + class User(pydantic.BaseModel): + name: Annotated[str, pydantic.Field(description="The user's name")] + age: Annotated[ + int, + pydantic.Field(alias="userAge", description="The user's age"), + strawberry.field(directives=[Range(min=0, max=150)]), + ] + + definition = get_object_definition(User, strict=True) + + name_field = next(f for f in definition.fields if f.python_name == "name") + age_field = next(f for f in definition.fields if f.python_name == "age") + + # Name field should preserve Pydantic description + assert name_field.description == "The user's name" + assert len(name_field.directives) == 0 + + # Age field should have both Pydantic features and Strawberry directive + assert age_field.description == "The user's age" + assert age_field.graphql_name == "userAge" + assert len(age_field.directives) == 1 + assert isinstance(age_field.directives[0], Range) + assert age_field.directives[0].min == 0 + assert age_field.directives[0].max == 150 + + +def test_field_directives_override_description(): + """Test that strawberry.field() description overrides Pydantic description.""" + + @strawberry.pydantic.type + class User(pydantic.BaseModel): + name: Annotated[str, pydantic.Field(description="Pydantic description")] + age: Annotated[ + int, + pydantic.Field(description="Pydantic age description"), + strawberry.field(description="Strawberry description override"), + ] + + definition = get_object_definition(User, strict=True) + + name_field = next(f for f in definition.fields if f.python_name == "name") + age_field = next(f for f in definition.fields if f.python_name == "age") + + # Name field should use Pydantic description + assert name_field.description == "Pydantic description" + + # Age field should use strawberry.field() description override + assert age_field.description == "Strawberry description override" + + +def test_field_directives_with_permissions(): + """Test that strawberry.field() permissions work with Pydantic models.""" + + class IsAuthenticated(strawberry.BasePermission): + message = "User is not authenticated" + + def has_permission(self, source, info, **kwargs): # noqa: ANN003 + return True # Simplified for testing + + @strawberry.pydantic.type + class User(pydantic.BaseModel): + name: str + email: Annotated[str, strawberry.field(permission_classes=[IsAuthenticated])] + + definition = get_object_definition(User, strict=True) + + name_field = next(f for f in definition.fields if f.python_name == "name") + email_field = next(f for f in definition.fields if f.python_name == "email") + + # Name field should have no permissions + assert len(name_field.permission_classes) == 0 + + # Email field should have the permission + assert len(email_field.permission_classes) == 1 + assert email_field.permission_classes[0] == IsAuthenticated + + +def test_field_directives_with_deprecation(): + """Test that strawberry.field() deprecation works with Pydantic models.""" + + @strawberry.pydantic.type + class User(pydantic.BaseModel): + name: str + old_field: Annotated[ + str, strawberry.field(deprecation_reason="Use name instead") + ] + + definition = get_object_definition(User, strict=True) + + name_field = next(f for f in definition.fields if f.python_name == "name") + old_field = next(f for f in definition.fields if f.python_name == "old_field") + + # Name field should not be deprecated + assert name_field.deprecation_reason is None + + # Old field should be deprecated + assert old_field.deprecation_reason == "Use name instead" + + +def test_field_directives_input_types(): + """Test that field directives work with Pydantic input types.""" + + @strawberry.schema_directive(locations=[Location.INPUT_FIELD_DEFINITION]) + class Validate: + pattern: str + + @strawberry.pydantic.input + class CreateUserInput(pydantic.BaseModel): + name: str + email: Annotated[ + str, strawberry.field(directives=[Validate(pattern=r"^[^@]+@[^@]+\.[^@]+")]) + ] + + definition = get_object_definition(CreateUserInput, strict=True) + + name_field = next(f for f in definition.fields if f.python_name == "name") + email_field = next(f for f in definition.fields if f.python_name == "email") + + # Name field should have no directives + assert len(name_field.directives) == 0 + + # Email field should have the validation directive + assert len(email_field.directives) == 1 + assert isinstance(email_field.directives[0], Validate) + assert email_field.directives[0].pattern == r"^[^@]+@[^@]+\.[^@]+" + + +def test_field_directives_graphql_name_override(): + """Test that strawberry.field() can override Pydantic field aliases for GraphQL names.""" + + @strawberry.pydantic.type + class User(pydantic.BaseModel): + name: Annotated[ + str, + pydantic.Field(alias="pydantic_name"), + strawberry.field(name="strawberry_name"), + ] + + definition = get_object_definition(User, strict=True) + + name_field = next(f for f in definition.fields if f.python_name == "name") + + # strawberry.field() graphql_name should override Pydantic alias + assert name_field.graphql_name == "strawberry_name" diff --git a/tests/pydantic/test_generics.py b/tests/pydantic/test_generics.py new file mode 100644 index 0000000000..411d46e54d --- /dev/null +++ b/tests/pydantic/test_generics.py @@ -0,0 +1,136 @@ +import sys +from typing import Generic, TypeVar + +import pytest +from inline_snapshot import snapshot + +import pydantic +import strawberry +from strawberry.types.base import ( + StrawberryList, + StrawberryOptional, + StrawberryTypeVar, + get_object_definition, +) + +T = TypeVar("T") + + +def test_basic_pydantic_generic_fields(): + """Test that pydantic generic models preserve field types correctly.""" + + @strawberry.pydantic.type + class GenericModel(pydantic.BaseModel, Generic[T]): + value: T + name: str = "default" + + definition = get_object_definition(GenericModel, strict=True) + + # Check fields + fields = definition.fields + assert len(fields) == 2 + + value_field = next(f for f in fields if f.python_name == "value") + name_field = next(f for f in fields if f.python_name == "name") + + # The value field should contain a TypeVar (generic parameter) + assert isinstance(value_field.type, StrawberryTypeVar) + assert value_field.type.type_var is T + + # The name field should be concrete + assert name_field.type is str + + +def test_pydantic_generic_with_concrete_type(): + """Test pydantic with a concrete generic instantiation.""" + + class GenericModel(pydantic.BaseModel, Generic[T]): + data: T + + # Create a concrete version by inheriting from GenericModel[int] + @strawberry.pydantic.type + class ConcreteModel(GenericModel[int]): + pass + + definition = get_object_definition(ConcreteModel, strict=True) + + # Verify the field type is concrete + [data_field] = definition.fields + assert data_field.python_name == "data" + assert data_field.type is int + + +def test_pydantic_generic_schema(): + """Test the GraphQL schema generated from pydantic generic types.""" + + @strawberry.pydantic.type + class User(pydantic.BaseModel, Generic[T]): + id: int + data: T + name: str = "default" + + # Create concrete versions + @strawberry.pydantic.type + class UserString(User[str]): + pass + + @strawberry.pydantic.type + class UserInt(User[int]): + pass + + @strawberry.type + class Query: + @strawberry.field + def get_user_string(self) -> UserString: + return UserString(id=1, data="hello", name="test") + + @strawberry.field + def get_user_int(self) -> UserInt: + return UserInt(id=2, data=42, name="test") + + schema = strawberry.Schema(query=Query) + + assert str(schema) == snapshot("""\ +type Query { + getUserString: UserString! + getUserInt: UserInt! +} + +type UserInt { + id: Int! + data: Int! + name: String! +} + +type UserString { + id: Int! + data: String! + name: String! +}\ +""") + + +def test_can_convert_generic_alias_fields_to_strawberry(): + @strawberry.pydantic.type + class Test(pydantic.BaseModel): + list_1d: list[int] + list_2d: list[list[int]] + + fields = get_object_definition(Test, strict=True).fields + assert isinstance(fields[0].type, StrawberryList) + assert isinstance(fields[1].type, StrawberryList) + + +@pytest.mark.skipif( + sys.version_info < (3, 10), + reason="union type expressions were added in python 3.10", +) +def test_can_convert_optional_union_type_expression_fields_to_strawberry(): + @strawberry.pydantic.type + class Test(pydantic.BaseModel): + optional_list: list[int] | None + optional_str: str | None + + fields = get_object_definition(Test, strict=True).fields + assert isinstance(fields[0].type, StrawberryOptional) + assert isinstance(fields[1].type, StrawberryOptional) diff --git a/tests/pydantic/test_inputs.py b/tests/pydantic/test_inputs.py new file mode 100644 index 0000000000..11711be520 --- /dev/null +++ b/tests/pydantic/test_inputs.py @@ -0,0 +1,789 @@ +from typing import Annotated, Optional + +from inline_snapshot import snapshot + +import pydantic +import strawberry +from strawberry.types.base import get_object_definition + + +def test_basic_input_type(): + """Test that @strawberry.pydantic.input works.""" + + @strawberry.pydantic.input + class CreateUserInput(pydantic.BaseModel): + age: int + name: str + + definition = get_object_definition(CreateUserInput, strict=True) + + assert definition.name == "CreateUserInput" + assert definition.is_input is True + assert len(definition.fields) == 2 + + +def test_input_type_with_valid_data(): + """Test input type with various valid data scenarios.""" + + @strawberry.pydantic.input + class UserInput(pydantic.BaseModel): + name: str + age: int + email: str + is_active: bool = True + tags: list[str] = [] + + @strawberry.pydantic.type + class User(pydantic.BaseModel): + id: int + name: str + age: int + email: str + is_active: bool + tags: list[str] + + @strawberry.type + class Mutation: + @strawberry.field + def create_user(self, input: UserInput) -> User: + return User( + id=1, + name=input.name, + age=input.age, + email=input.email, + is_active=input.is_active, + tags=input.tags, + ) + + @strawberry.type + class Query: + @strawberry.field + def dummy(self) -> str: + return "dummy" + + schema = strawberry.Schema(query=Query, mutation=Mutation) + + # Test with all fields provided + mutation = """ + mutation { + createUser(input: { + name: "John Doe" + age: 30 + email: "john@example.com" + isActive: false + tags: ["developer", "python"] + }) { + id + name + age + email + isActive + tags + } + } + """ + + result = schema.execute_sync(mutation) + + assert not result.errors + assert result.data == snapshot( + { + "createUser": { + "id": 1, + "name": "John Doe", + "age": 30, + "email": "john@example.com", + "isActive": False, + "tags": ["developer", "python"], + } + } + ) + + # Test with default values + mutation_defaults = """ + mutation { + createUser(input: { + name: "Jane Doe" + age: 25 + email: "jane@example.com" + }) { + id + name + age + email + isActive + tags + } + } + """ + + result = schema.execute_sync(mutation_defaults) + + assert not result.errors + assert result.data == snapshot( + { + "createUser": { + "id": 1, + "name": "Jane Doe", + "age": 25, + "email": "jane@example.com", + "isActive": True, # default value + "tags": [], # default value + } + } + ) + + +def test_input_type_with_invalid_email(): + """Test input type with invalid email format.""" + + @strawberry.pydantic.input + class UserInput(pydantic.BaseModel): + name: Annotated[str, pydantic.Field(min_length=2, max_length=50)] + age: Annotated[int, pydantic.Field(ge=0, le=150)] + email: Annotated[str, pydantic.Field(pattern=r"^[^@]+@[^@]+\.[^@]+$")] + + @strawberry.pydantic.type + class User(pydantic.BaseModel): + name: str + age: int + email: str + + @strawberry.type + class Mutation: + @strawberry.field + def create_user(self, input: UserInput) -> User: + return User(name=input.name, age=input.age, email=input.email) + + @strawberry.type + class Query: + @strawberry.field + def dummy(self) -> str: + return "dummy" + + schema = strawberry.Schema(query=Query, mutation=Mutation) + + # Test with invalid email + mutation_invalid_email = """ + mutation { + createUser(input: { + name: "John" + age: 30 + email: "invalid-email" + }) { + name + age + email + } + } + """ + + result = schema.execute_sync(mutation_invalid_email) + assert result.errors is not None + assert len(result.errors) == 1 + assert result.errors[0].message == snapshot("""\ +1 validation error for UserInput +email + String should match pattern '^[^@]+@[^@]+\\.[^@]+$' [type=string_pattern_mismatch, input_value='invalid-email', input_type=str] + For further information visit https://errors.pydantic.dev/2.11/v/string_pattern_mismatch\ +""") + + +def test_input_type_with_invalid_name_length(): + """Test input type with name validation errors.""" + + @strawberry.pydantic.input + class UserInput(pydantic.BaseModel): + name: Annotated[str, pydantic.Field(min_length=2, max_length=50)] + age: Annotated[int, pydantic.Field(ge=0, le=150)] + email: Annotated[str, pydantic.Field(pattern=r"^[^@]+@[^@]+\.[^@]+$")] + + @strawberry.pydantic.type + class User(pydantic.BaseModel): + name: str + age: int + email: str + + @strawberry.type + class Mutation: + @strawberry.field + def create_user(self, input: UserInput) -> User: + return User(name=input.name, age=input.age, email=input.email) + + @strawberry.type + class Query: + @strawberry.field + def dummy(self) -> str: + return "dummy" + + schema = strawberry.Schema(query=Query, mutation=Mutation) + + # Test with name too short + mutation_short_name = """ + mutation { + createUser(input: { + name: "J" + age: 30 + email: "john@example.com" + }) { + name + age + email + } + } + """ + + result = schema.execute_sync(mutation_short_name) + assert result.errors is not None + assert len(result.errors) == 1 + assert result.errors[0].message == snapshot("""\ +1 validation error for UserInput +name + String should have at least 2 characters [type=string_too_short, input_value='J', input_type=str] + For further information visit https://errors.pydantic.dev/2.11/v/string_too_short\ +""") + + +def test_input_type_with_invalid_age_range(): + """Test input type with age validation errors.""" + + @strawberry.pydantic.input + class UserInput(pydantic.BaseModel): + name: Annotated[str, pydantic.Field(min_length=2, max_length=50)] + age: Annotated[int, pydantic.Field(ge=0, le=150)] + email: Annotated[str, pydantic.Field(pattern=r"^[^@]+@[^@]+\.[^@]+$")] + + @strawberry.pydantic.type + class User(pydantic.BaseModel): + name: str + age: int + email: str + + @strawberry.type + class Mutation: + @strawberry.field + def create_user(self, input: UserInput) -> User: + return User(name=input.name, age=input.age, email=input.email) + + @strawberry.type + class Query: + @strawberry.field + def dummy(self) -> str: + return "dummy" + + schema = strawberry.Schema(query=Query, mutation=Mutation) + + # Test with age out of range (negative) + mutation_negative_age = """ + mutation { + createUser(input: { + name: "John" + age: -5 + email: "john@example.com" + }) { + name + age + email + } + } + """ + + result = schema.execute_sync(mutation_negative_age) + assert result.errors is not None + assert len(result.errors) == 1 + assert result.errors[0].message == snapshot("""\ +1 validation error for UserInput +age + Input should be greater than or equal to 0 [type=greater_than_equal, input_value=-5, input_type=int] + For further information visit https://errors.pydantic.dev/2.11/v/greater_than_equal\ +""") + + # Test with age out of range (too high) + mutation_high_age = """ + mutation { + createUser(input: { + name: "John" + age: 200 + email: "john@example.com" + }) { + name + age + email + } + } + """ + + result = schema.execute_sync(mutation_high_age) + assert result.errors is not None + assert len(result.errors) == 1 + assert result.errors[0].message == snapshot("""\ +1 validation error for UserInput +age + Input should be less than or equal to 150 [type=less_than_equal, input_value=200, input_type=int] + For further information visit https://errors.pydantic.dev/2.11/v/less_than_equal\ +""") + + +def test_nested_input_types_with_validation(): + """Test nested input types with validation.""" + + @strawberry.pydantic.input + class AddressInput(pydantic.BaseModel): + street: Annotated[str, pydantic.Field(min_length=5)] + city: Annotated[str, pydantic.Field(min_length=2)] + zipcode: Annotated[str, pydantic.Field(pattern=r"^\d{5}$")] + + @strawberry.pydantic.input + class UserInput(pydantic.BaseModel): + name: str + age: Annotated[int, pydantic.Field(ge=18)] # Must be 18 or older + address: AddressInput + + @strawberry.pydantic.type + class Address(pydantic.BaseModel): + street: str + city: str + zipcode: str + + @strawberry.pydantic.type + class User(pydantic.BaseModel): + name: str + age: int + address: Address + + @strawberry.type + class Mutation: + @strawberry.field + def create_user(self, input: UserInput) -> User: + return User( + name=input.name, + age=input.age, + address=Address( + street=input.address.street, + city=input.address.city, + zipcode=input.address.zipcode, + ), + ) + + @strawberry.type + class Query: + @strawberry.field + def dummy(self) -> str: + return "dummy" + + schema = strawberry.Schema(query=Query, mutation=Mutation) + + # Test with valid nested data + mutation_valid = """ + mutation { + createUser(input: { + name: "Alice" + age: 25 + address: { + street: "123 Main Street" + city: "New York" + zipcode: "12345" + } + }) { + name + age + address { + street + city + zipcode + } + } + } + """ + + result = schema.execute_sync(mutation_valid) + + assert not result.errors + assert result.data == snapshot( + { + "createUser": { + "name": "Alice", + "age": 25, + "address": { + "street": "123 Main Street", + "city": "New York", + "zipcode": "12345", + }, + } + } + ) + + # Test with invalid nested data (invalid zipcode) + mutation_invalid_zip = """ + mutation { + createUser(input: { + name: "Bob" + age: 30 + address: { + street: "456 Elm Street" + city: "Boston" + zipcode: "1234" # Too short + } + }) { + name + age + address { + street + city + zipcode + } + } + } + """ + + result = schema.execute_sync(mutation_invalid_zip) + assert result.errors is not None + assert len(result.errors) == 1 + assert result.errors[0].message == snapshot("""\ +1 validation error for AddressInput +zipcode + String should match pattern '^\\d{5}$' [type=string_pattern_mismatch, input_value='1234', input_type=str] + For further information visit https://errors.pydantic.dev/2.11/v/string_pattern_mismatch\ +""") + + # Test with invalid nested data (underage) + mutation_underage = """ + mutation { + createUser(input: { + name: "Charlie" + age: 16 # Under 18 + address: { + street: "789 Oak Street" + city: "Chicago" + zipcode: "60601" + } + }) { + name + age + address { + street + city + zipcode + } + } + } + """ + + result = schema.execute_sync(mutation_underage) + assert result.errors is not None + assert len(result.errors) == 1 + assert result.errors[0].message == snapshot("""\ +1 validation error for UserInput +age + Input should be greater than or equal to 18 [type=greater_than_equal, input_value=16, input_type=int] + For further information visit https://errors.pydantic.dev/2.11/v/greater_than_equal\ +""") + + +def test_input_type_with_custom_validators(): + """Test input types with custom Pydantic validators.""" + + @strawberry.pydantic.input + class RegistrationInput(pydantic.BaseModel): + username: str + password: str + confirm_password: str + age: int + + @pydantic.field_validator("username") + @classmethod + def username_alphanumeric(cls, v: str) -> str: + if not v.isalnum(): + raise ValueError("Username must be alphanumeric") + if len(v) < 3: + raise ValueError("Username must be at least 3 characters long") + return v + + @pydantic.field_validator("password") + @classmethod + def password_strength(cls, v: str) -> str: + if len(v) < 8: + raise ValueError("Password must be at least 8 characters long") + if not any(c.isupper() for c in v): + raise ValueError("Password must contain at least one uppercase letter") + if not any(c.isdigit() for c in v): + raise ValueError("Password must contain at least one digit") + return v + + @pydantic.field_validator("confirm_password") + @classmethod + def passwords_match(cls, v: str, info: pydantic.ValidationInfo) -> str: + if "password" in info.data and v != info.data["password"]: + raise ValueError("Passwords do not match") + return v + + @pydantic.field_validator("age") + @classmethod + def age_requirement(cls, v: int) -> int: + if v < 13: + raise ValueError("Must be at least 13 years old") + return v + + @strawberry.pydantic.type + class User(pydantic.BaseModel): + username: str + age: int + + @strawberry.type + class Mutation: + @strawberry.field + def register(self, input: RegistrationInput) -> User: + return User(username=input.username, age=input.age) + + @strawberry.type + class Query: + @strawberry.field + def dummy(self) -> str: + return "dummy" + + schema = strawberry.Schema(query=Query, mutation=Mutation) + + # Test with valid input + mutation_valid = """ + mutation { + register(input: { + username: "john123" + password: "SecurePass123" + confirmPassword: "SecurePass123" + age: 25 + }) { + username + age + } + } + """ + + result = schema.execute_sync(mutation_valid) + + assert not result.errors + assert result.data == snapshot({"register": {"username": "john123", "age": 25}}) + + # Test with non-alphanumeric username + mutation_invalid_username = """ + mutation { + register(input: { + username: "john@123" + password: "SecurePass123" + confirmPassword: "SecurePass123" + age: 25 + }) { + username + age + } + } + """ + + result = schema.execute_sync(mutation_invalid_username) + assert result.errors is not None + assert len(result.errors) == 1 + assert result.errors[0].message == snapshot("""\ +1 validation error for RegistrationInput +username + Value error, Username must be alphanumeric [type=value_error, input_value='john@123', input_type=str] + For further information visit https://errors.pydantic.dev/2.11/v/value_error\ +""") + + # Test with weak password + mutation_weak_password = """ + mutation { + register(input: { + username: "john123" + password: "weak" + confirmPassword: "weak" + age: 25 + }) { + username + age + } + } + """ + + result = schema.execute_sync(mutation_weak_password) + assert result.errors is not None + assert len(result.errors) == 1 + assert result.errors[0].message == snapshot("""\ +1 validation error for RegistrationInput +password + Value error, Password must be at least 8 characters long [type=value_error, input_value='weak', input_type=str] + For further information visit https://errors.pydantic.dev/2.11/v/value_error\ +""") + + # Test with mismatched passwords + mutation_mismatch_password = """ + mutation { + register(input: { + username: "john123" + password: "SecurePass123" + confirmPassword: "DifferentPass123" + age: 25 + }) { + username + age + } + } + """ + + result = schema.execute_sync(mutation_mismatch_password) + assert result.errors is not None + assert len(result.errors) == 1 + assert result.errors[0].message == snapshot("""\ +1 validation error for RegistrationInput +confirm_password + Value error, Passwords do not match [type=value_error, input_value='DifferentPass123', input_type=str] + For further information visit https://errors.pydantic.dev/2.11/v/value_error\ +""") + + # Test with underage user + mutation_underage = """ + mutation { + register(input: { + username: "kid123" + password: "SecurePass123" + confirmPassword: "SecurePass123" + age: 10 + }) { + username + age + } + } + """ + + result = schema.execute_sync(mutation_underage) + assert result.errors is not None + assert len(result.errors) == 1 + assert result.errors[0].message == snapshot("""\ +1 validation error for RegistrationInput +age + Value error, Must be at least 13 years old [type=value_error, input_value=10, input_type=int] + For further information visit https://errors.pydantic.dev/2.11/v/value_error\ +""") + + +def test_input_type_with_optional_fields_and_validation(): + """Test input types with optional fields and validation.""" + + @strawberry.pydantic.input + class UpdateProfileInput(pydantic.BaseModel): + bio: Annotated[Optional[str], pydantic.Field(None, max_length=200)] + website: Annotated[Optional[str], pydantic.Field(None, pattern=r"^https?://.*")] + age: Annotated[Optional[int], pydantic.Field(None, ge=0, le=150)] + + @strawberry.pydantic.type + class Profile(pydantic.BaseModel): + bio: Optional[str] = None + website: Optional[str] = None + age: Optional[int] = None + + @strawberry.type + class Mutation: + @strawberry.field + def update_profile(self, input: UpdateProfileInput) -> Profile: + return Profile(bio=input.bio, website=input.website, age=input.age) + + @strawberry.type + class Query: + @strawberry.field + def dummy(self) -> str: + return "dummy" + + schema = strawberry.Schema(query=Query, mutation=Mutation) + + # Test with all valid optional fields + mutation_all_fields = """ + mutation { + updateProfile(input: { + bio: "Software developer" + website: "https://example.com" + age: 30 + }) { + bio + website + age + } + } + """ + + result = schema.execute_sync(mutation_all_fields) + + assert not result.errors + assert result.data == snapshot( + { + "updateProfile": { + "bio": "Software developer", + "website": "https://example.com", + "age": 30, + } + } + ) + + # Test with only some fields + mutation_partial = """ + mutation { + updateProfile(input: { + bio: "Just a bio" + }) { + bio + website + age + } + } + """ + + result = schema.execute_sync(mutation_partial) + + assert not result.errors + assert result.data == snapshot( + {"updateProfile": {"bio": "Just a bio", "website": None, "age": None}} + ) + + # Test with invalid website URL + mutation_invalid_url = """ + mutation { + updateProfile(input: { + website: "not-a-url" + }) { + bio + website + age + } + } + """ + + result = schema.execute_sync(mutation_invalid_url) + assert result.errors is not None + assert len(result.errors) == 1 + assert result.errors[0].message == snapshot("""\ +1 validation error for UpdateProfileInput +website + String should match pattern '^https?://.*' [type=string_pattern_mismatch, input_value='not-a-url', input_type=str] + For further information visit https://errors.pydantic.dev/2.11/v/string_pattern_mismatch\ +""") + + # Test with bio too long + long_bio = "x" * 201 + mutation_long_bio = f""" + mutation {{ + updateProfile(input: {{ + bio: "{long_bio}" + }}) {{ + bio + website + age + }} + }} + """ + + result = schema.execute_sync(mutation_long_bio) + assert result.errors is not None + assert len(result.errors) == 1 + assert result.errors[0].message == snapshot("""\ +1 validation error for UpdateProfileInput +bio + String should have at most 200 characters [type=string_too_long, input_value='xxxxxxxxxxxxxxxxxxxxxxxx...xxxxxxxxxxxxxxxxxxxxxxx', input_type=str] + For further information visit https://errors.pydantic.dev/2.11/v/string_too_long\ +""") diff --git a/tests/pydantic/test_interface.py b/tests/pydantic/test_interface.py new file mode 100644 index 0000000000..7436789fb1 --- /dev/null +++ b/tests/pydantic/test_interface.py @@ -0,0 +1,53 @@ +from inline_snapshot import snapshot + +import pydantic +import strawberry +from strawberry.types.base import get_object_definition + + +def test_basic_interface_type(): + """Test that @strawberry.pydantic.interface works.""" + + @strawberry.pydantic.interface + class Node(pydantic.BaseModel): + id: str + + definition = get_object_definition(Node, strict=True) + + assert definition.name == "Node" + assert definition.is_interface is True + assert len(definition.fields) == 1 + + +def test_pydantic_interface_basic(): + """Test basic Pydantic interface functionality.""" + + @strawberry.pydantic.interface + class Node(pydantic.BaseModel): + id: str + + @strawberry.pydantic.type + class User(Node): + name: str + + @strawberry.type + class Query: + @strawberry.field + def get_user(self) -> User: + return User(id="user_1", name="John") + + schema = strawberry.Schema(query=Query) + + query = """ + query { + getUser { + id + name + } + } + """ + + result = schema.execute_sync(query) + + assert not result.errors + assert result.data == snapshot({"getUser": {"id": "user_1", "name": "John"}}) diff --git a/tests/pydantic/test_nested_types.py b/tests/pydantic/test_nested_types.py new file mode 100644 index 0000000000..10c4504286 --- /dev/null +++ b/tests/pydantic/test_nested_types.py @@ -0,0 +1,184 @@ +""" +Nested type tests for Pydantic integration. + +These tests verify that nested Pydantic types work correctly in GraphQL. +""" + +from typing import Optional + +from inline_snapshot import snapshot + +import pydantic +import strawberry + + +def test_nested_pydantic_types(): + """Test nested Pydantic types in queries.""" + + @strawberry.pydantic.type + class Address(pydantic.BaseModel): + street: str + city: str + zipcode: str + + @strawberry.pydantic.type + class User(pydantic.BaseModel): + name: str + age: int + address: Address + + @strawberry.type + class Query: + @strawberry.field + def get_user(self) -> User: + return User( + name="John", + age=30, + address=Address(street="123 Main St", city="Anytown", zipcode="12345"), + ) + + schema = strawberry.Schema(query=Query) + + query = """ + query { + getUser { + name + age + address { + street + city + zipcode + } + } + } + """ + + result = schema.execute_sync(query) + + assert not result.errors + assert result.data == snapshot( + { + "getUser": { + "name": "John", + "age": 30, + "address": { + "street": "123 Main St", + "city": "Anytown", + "zipcode": "12345", + }, + } + } + ) + + +def test_list_of_pydantic_types(): + """Test lists of Pydantic types.""" + + @strawberry.pydantic.type + class User(pydantic.BaseModel): + name: str + age: int + + @strawberry.type + class Query: + @strawberry.field + def get_users(self) -> list[User]: + return [ + User(name="John", age=30), + User(name="Jane", age=25), + User(name="Bob", age=35), + ] + + schema = strawberry.Schema(query=Query) + + query = """ + query { + getUsers { + name + age + } + } + """ + + result = schema.execute_sync(query) + + assert not result.errors + assert result.data == snapshot( + { + "getUsers": [ + {"name": "John", "age": 30}, + {"name": "Jane", "age": 25}, + {"name": "Bob", "age": 35}, + ] + } + ) + + +def test_complex_pydantic_types_execution(): + """Test complex Pydantic types with various field types.""" + + @strawberry.pydantic.type + class Profile(pydantic.BaseModel): + bio: Optional[str] = None + website: Optional[str] = None + + @strawberry.pydantic.type + class User(pydantic.BaseModel): + id: int + name: str + email: str + is_active: bool + tags: list[str] = [] + profile: Optional[Profile] = None + + @strawberry.type + class Query: + @strawberry.field + def get_user(self) -> User: + return User( + id=1, + name="John Doe", + email="john@example.com", + is_active=True, + tags=["developer", "python", "graphql"], + profile=Profile( + bio="Software developer", website="https://johndoe.com" + ), + ) + + schema = strawberry.Schema(query=Query) + + query = """ + query { + getUser { + id + name + email + isActive + tags + profile { + bio + website + } + } + } + """ + + result = schema.execute_sync(query) + + assert not result.errors + assert result.data == snapshot( + { + "getUser": { + "id": 1, + "name": "John Doe", + "email": "john@example.com", + "isActive": True, + "tags": ["developer", "python", "graphql"], + "profile": { + "bio": "Software developer", + "website": "https://johndoe.com", + }, + } + } + ) diff --git a/tests/pydantic/test_private.py b/tests/pydantic/test_private.py new file mode 100644 index 0000000000..619f725343 --- /dev/null +++ b/tests/pydantic/test_private.py @@ -0,0 +1,126 @@ +from inline_snapshot import snapshot + +import pydantic +import strawberry +from strawberry.types.base import get_object_definition + + +def test_strawberry_private_fields(): + """Test that strawberry.Private fields are excluded from the GraphQL schema.""" + + @strawberry.pydantic.type + class User(pydantic.BaseModel): + id: int + name: str + age: int + password: strawberry.Private[str] + + definition = get_object_definition(User, strict=True) + assert definition.name == "User" + + # Should have three fields (id, name, age) - password should be excluded + assert len(definition.fields) == 3 + + field_names = {f.python_name for f in definition.fields} + assert field_names == {"id", "name", "age"} + + # password field should not be in the GraphQL schema + assert "password" not in field_names + + # But the python object should still have the password field + user = User(id=1, name="John", age=30, password="secret") + assert user.id == 1 + assert user.name == "John" + assert user.age == 30 + assert user.password == "secret" + + +def test_strawberry_private_fields_access(): + """Test that strawberry.Private fields can be accessed in Python code.""" + + @strawberry.pydantic.type + class User(pydantic.BaseModel): + id: int + name: str + password: strawberry.Private[str] + + definition = get_object_definition(User, strict=True) + assert definition.name == "User" + + # Should have two fields (id, name) - password should be excluded + assert len(definition.fields) == 2 + + field_names = {f.python_name for f in definition.fields} + assert field_names == {"id", "name"} + + # Test that the private field is still accessible on the instance + user = User(id=1, name="John", password="secret") + assert user.id == 1 + assert user.name == "John" + assert user.password == "secret" + + # Test that we can use the private field in Python logic + def has_password(user: User) -> bool: + return bool(user.password) + + assert has_password(user) is True + + user_no_password = User(id=2, name="Jane", password="") + assert has_password(user_no_password) is False + + +def test_strawberry_private_fields_not_in_schema(): + """Test that strawberry.Private fields are not exposed in GraphQL schema.""" + + @strawberry.pydantic.type + class User(pydantic.BaseModel): + id: int + name: str + password: strawberry.Private[str] + + @strawberry.type + class Query: + @strawberry.field + def get_user(self) -> User: + return User(id=1, name="John", password="secret123") + + schema = strawberry.Schema(query=Query) + + # Check that password field is not in the schema + schema_str = str(schema) + assert "password" not in schema_str + assert "id: Int!" in schema_str + assert "name: String!" in schema_str + + # Test that we can query the exposed fields + query = """ + query { + getUser { + id + name + } + } + """ + + result = schema.execute_sync(query) + + assert not result.errors + assert result.data == snapshot({"getUser": {"id": 1, "name": "John"}}) + + # Test that querying the private field fails + query_with_private = """ + query { + getUser { + id + name + password + } + } + """ + + result = schema.execute_sync(query_with_private) + assert result.errors + assert len(result.errors) == 1 + assert result.errors[0].message == snapshot( + "Cannot query field 'password' on type 'User'." + ) diff --git a/tests/pydantic/test_queries_mutations.py b/tests/pydantic/test_queries_mutations.py new file mode 100644 index 0000000000..aafd14419c --- /dev/null +++ b/tests/pydantic/test_queries_mutations.py @@ -0,0 +1,187 @@ +""" +Query and mutation execution tests for Pydantic integration. + +These tests verify that Pydantic models work correctly in GraphQL queries and mutations. +""" + +from typing import Optional + +from inline_snapshot import snapshot + +import pydantic +import strawberry + + +def test_basic_query_execution(): + """Test basic query execution with Pydantic types.""" + + @strawberry.pydantic.type + class User(pydantic.BaseModel): + name: str + age: int + + @strawberry.type + class Query: + @strawberry.field + def get_user(self) -> User: + return User(name="John", age=30) + + schema = strawberry.Schema(query=Query) + + query = """ + query { + getUser { + name + age + } + } + """ + + result = schema.execute_sync(query) + + assert not result.errors + assert result.data == snapshot({"getUser": {"name": "John", "age": 30}}) + + +def test_query_with_optional_fields(): + """Test query execution with optional fields.""" + + @strawberry.pydantic.type + class User(pydantic.BaseModel): + name: str + email: Optional[str] = None + age: Optional[int] = None + + @strawberry.type + class Query: + @strawberry.field + def get_user(self) -> User: + return User(name="John", email="john@example.com") + + schema = strawberry.Schema(query=Query) + + query = """ + query { + getUser { + name + email + age + } + } + """ + + result = schema.execute_sync(query) + + assert not result.errors + assert result.data == snapshot( + {"getUser": {"name": "John", "email": "john@example.com", "age": None}} + ) + + +def test_mutation_with_input_types(): + """Test mutation execution with Pydantic input types.""" + + @strawberry.pydantic.input + class CreateUserInput(pydantic.BaseModel): + name: str + age: int + email: Optional[str] = None + + @strawberry.pydantic.type + class User(pydantic.BaseModel): + id: int + name: str + age: int + email: Optional[str] = None + + @strawberry.type + class Mutation: + @strawberry.field + def create_user(self, input: CreateUserInput) -> User: + return User(id=1, name=input.name, age=input.age, email=input.email) + + @strawberry.type + class Query: + @strawberry.field + def dummy(self) -> str: + return "dummy" + + schema = strawberry.Schema(query=Query, mutation=Mutation) + + mutation = """ + mutation { + createUser(input: { + name: "Alice" + age: 25 + email: "alice@example.com" + }) { + id + name + age + email + } + } + """ + + result = schema.execute_sync(mutation) + + assert not result.errors + assert result.data == snapshot( + { + "createUser": { + "id": 1, + "name": "Alice", + "age": 25, + "email": "alice@example.com", + } + } + ) + + +def test_mutation_with_partial_input(): + """Test mutation with partial input (optional fields).""" + + @strawberry.pydantic.input + class UpdateUserInput(pydantic.BaseModel): + name: Optional[str] = None + age: Optional[int] = None + + @strawberry.pydantic.type + class User(pydantic.BaseModel): + id: int + name: str + age: int + + @strawberry.type + class Mutation: + @strawberry.field + def update_user(self, id: int, input: UpdateUserInput) -> User: + # Simulate updating a user + return User(id=id, name=input.name or "Default Name", age=input.age or 18) + + @strawberry.type + class Query: + @strawberry.field + def dummy(self) -> str: + return "dummy" + + schema = strawberry.Schema(query=Query, mutation=Mutation) + + mutation = """ + mutation { + updateUser(id: 1, input: { + name: "Updated Name" + }) { + id + name + age + } + } + """ + + result = schema.execute_sync(mutation) + + assert not result.errors + assert result.data == snapshot( + {"updateUser": {"id": 1, "name": "Updated Name", "age": 18}} + ) diff --git a/tests/pydantic/test_type.py b/tests/pydantic/test_type.py new file mode 100644 index 0000000000..f459b5e3c4 --- /dev/null +++ b/tests/pydantic/test_type.py @@ -0,0 +1,136 @@ +from typing import Optional + +from inline_snapshot import snapshot + +import pydantic +import strawberry +from strawberry.types.base import ( + StrawberryOptional, + get_object_definition, +) + + +def test_basic_type_includes_all_fields(): + """Test that @strawberry.pydantic.type includes all fields from the model.""" + + @strawberry.pydantic.type + class User(pydantic.BaseModel): + age: int + password: Optional[str] + + definition = get_object_definition(User, strict=True) + assert definition.name == "User" + + # Should have two fields + assert len(definition.fields) == 2 + + # Find fields by name + age_field = next(f for f in definition.fields if f.python_name == "age") + password_field = next(f for f in definition.fields if f.python_name == "password") + + assert age_field.python_name == "age" + assert age_field.graphql_name is None + assert age_field.type is int + + assert password_field.python_name == "password" + assert password_field.graphql_name is None + assert isinstance(password_field.type, StrawberryOptional) + assert password_field.type.of_type is str + + +def test_basic_type_with_name_override(): + """Test that @strawberry.pydantic.type with name parameter works.""" + + @strawberry.pydantic.type(name="CustomUser") + class User(pydantic.BaseModel): + age: int + + definition = get_object_definition(User, strict=True) + assert definition.name == "CustomUser" + + +def test_basic_type_with_description(): + """Test that @strawberry.pydantic.type with description parameter works.""" + + @strawberry.pydantic.type(description="A user model") + class User(pydantic.BaseModel): + age: int + + definition = get_object_definition(User, strict=True) + assert definition.description == "A user model" + + +def test_is_type_of_method(): + """Test that is_type_of method is added for proper type resolution.""" + + @strawberry.pydantic.type + class User(pydantic.BaseModel): + age: int + name: str + + # Check that is_type_of method exists + assert hasattr(User, "is_type_of") + assert callable(User.is_type_of) + + # Test type checking + user_instance = User(age=25, name="John") + assert User.is_type_of(user_instance, None) is True + + # Test with different type + class Other: + pass + + other_instance = Other() + assert User.is_type_of(other_instance, None) is False + + +def test_schema_generation(): + """Test that the decorated models work in schema generation.""" + + @strawberry.pydantic.type + class User(pydantic.BaseModel): + age: int + name: str + + @strawberry.pydantic.input + class CreateUserInput(pydantic.BaseModel): + age: int + name: str + + @strawberry.type + class Query: + @strawberry.field + def get_user(self) -> User: + return User(age=25, name="John") + + @strawberry.type + class Mutation: + @strawberry.field + def create_user(self, input: CreateUserInput) -> User: + return User(age=input.age, name=input.name) + + # Test that schema can be created successfully + schema = strawberry.Schema(query=Query, mutation=Mutation) + assert schema is not None + + assert str(schema) == snapshot( + """\ +input CreateUserInput { + age: Int! + name: String! +} + +type Mutation { + createUser(input: CreateUserInput!): User! +} + +type Query { + getUser: User! +} + +type User { + age: Int! + name: String! +}\ +""" + )