diff --git a/pydantic_xml/compat.py b/pydantic_xml/compat.py new file mode 100644 index 0000000..57c88e8 --- /dev/null +++ b/pydantic_xml/compat.py @@ -0,0 +1,16 @@ +""" +pydantic compatibility module. +""" + +import pydantic as pd +from pydantic._internal._model_construction import ModelMetaclass # noqa +from pydantic.root_model import _RootModelMetaclass as RootModelMetaclass # noqa + +PYDANTIC_VERSION = tuple(map(int, pd.__version__.partition('+')[0].split('.'))) + + +def merge_field_infos(*field_infos: pd.fields.FieldInfo) -> pd.fields.FieldInfo: + if PYDANTIC_VERSION >= (2, 12, 0): + return pd.fields.FieldInfo._construct(field_infos) # type: ignore[attr-defined] + else: + return pd.fields.FieldInfo.merge_field_infos(*field_infos) diff --git a/pydantic_xml/fields.py b/pydantic_xml/fields.py index 8880172..c0cabd3 100644 --- a/pydantic_xml/fields.py +++ b/pydantic_xml/fields.py @@ -1,13 +1,11 @@ import dataclasses as dc import typing -from typing import Any, Callable, Optional, Union +from typing import Any, Callable, Dict, Optional, Union import pydantic as pd import pydantic_core as pdc -from pydantic._internal._model_construction import ModelMetaclass # noqa -from pydantic.root_model import _RootModelMetaclass as RootModelMetaclass # noqa -from . import config, model, utils +from . import compat, config, model, utils from .typedefs import EntityLocation from .utils import NsMap @@ -17,6 +15,7 @@ 'computed_element', 'computed_entity', 'element', + 'extract_field_xml_entity_info', 'wrapped', 'xml_field_serializer', 'xml_field_validator', @@ -37,83 +36,79 @@ class XmlEntityInfoP(typing.Protocol): wrapped: Optional['XmlEntityInfoP'] -class XmlEntityInfo(pd.fields.FieldInfo, XmlEntityInfoP): +@dc.dataclass(frozen=True) +class XmlEntityInfo(XmlEntityInfoP): """ Field xml meta-information. """ - __slots__ = ('location', 'path', 'ns', 'nsmap', 'nillable', 'wrapped') + location: Optional[EntityLocation] + path: Optional[str] = None + ns: Optional[str] = None + nsmap: Optional[NsMap] = None + nillable: Optional[bool] = None + wrapped: Optional[XmlEntityInfoP] = None + + def __post_init__(self) -> None: + if config.REGISTER_NS_PREFIXES and self.nsmap: + utils.register_nsmap(self.nsmap) @staticmethod - def merge_field_infos(*field_infos: pd.fields.FieldInfo, **overrides: Any) -> pd.fields.FieldInfo: - location, path, ns, nsmap, nillable, wrapped = None, None, None, None, None, None - - for field_info in field_infos: - if isinstance(field_info, XmlEntityInfo): - location = field_info.location if field_info.location is not None else location - path = field_info.path if field_info.path is not None else path - ns = field_info.ns if field_info.ns is not None else ns - nsmap = field_info.nsmap if field_info.nsmap is not None else nsmap - nillable = field_info.nillable if field_info.nillable is not None else nillable - wrapped = field_info.wrapped if field_info.wrapped is not None else wrapped - - field_info = pd.fields.FieldInfo.merge_field_infos(*field_infos, **overrides) - - xml_entity_info = XmlEntityInfo( - location, + def merge(*entity_infos: XmlEntityInfoP) -> 'XmlEntityInfo': + location: Optional[EntityLocation] = None + path: Optional[str] = None + ns: Optional[str] = None + nsmap: Optional[NsMap] = None + nillable: Optional[bool] = None + wrapped: Optional[XmlEntityInfoP] = None + + for entity_info in entity_infos: + if entity_info.location is not None: + location = entity_info.location + if entity_info.wrapped is not None: + wrapped = entity_info.wrapped + if entity_info.path is not None: + path = entity_info.path + if entity_info.ns is not None: + ns = entity_info.ns + if entity_info.nsmap is not None: + nsmap = utils.merge_nsmaps(entity_info.nsmap, nsmap) + if entity_info.nillable is not None: + nillable = entity_info.nillable + + return XmlEntityInfo( + location=location, path=path, ns=ns, nsmap=nsmap, nillable=nillable, - wrapped=wrapped if isinstance(wrapped, XmlEntityInfo) else None, - **field_info._attributes_set, + wrapped=wrapped, ) - xml_entity_info.metadata = field_info.metadata - - return xml_entity_info - - def __init__( - self, - location: Optional[EntityLocation], - /, - path: Optional[str] = None, - ns: Optional[str] = None, - nsmap: Optional[NsMap] = None, - nillable: Optional[bool] = None, - wrapped: Optional[pd.fields.FieldInfo] = None, - **kwargs: Any, - ): - wrapped_metadata: list[Any] = [] - if wrapped is not None: - # copy arguments from the wrapped entity to let pydantic know how to process the field - for entity_field_name in utils.get_slots(wrapped): - if entity_field_name in pd.fields._FIELD_ARG_NAMES: - kwargs[entity_field_name] = getattr(wrapped, entity_field_name) - wrapped_metadata = wrapped.metadata - - if kwargs.get('serialization_alias') is None: - kwargs['serialization_alias'] = kwargs.get('alias') - - if kwargs.get('validation_alias') is None: - kwargs['validation_alias'] = kwargs.get('alias') - - super().__init__(**kwargs) - self.metadata.extend(wrapped_metadata) - - self.location = location - self.path = path - self.ns = ns - self.nsmap = nsmap - self.nillable = nillable - self.wrapped: Optional[XmlEntityInfoP] = wrapped if isinstance(wrapped, XmlEntityInfo) else None - - if config.REGISTER_NS_PREFIXES and nsmap: - utils.register_nsmap(nsmap) + + +def extract_field_xml_entity_info(field_info: pd.fields.FieldInfo) -> Optional[XmlEntityInfoP]: + entity_info_list = list(filter(lambda meta: isinstance(meta, XmlEntityInfo), field_info.metadata)) + if entity_info_list: + entity_info = XmlEntityInfo.merge(*entity_info_list) + else: + entity_info = None + + return entity_info _Unset: Any = pdc.PydanticUndefined +def prepare_field_kwargs(kwargs: Dict[str, Any]) -> Dict[str, Any]: + if kwargs.get('serialization_alias') in (None, pdc.PydanticUndefined): + kwargs['serialization_alias'] = kwargs.get('alias') + + if kwargs.get('validation_alias') in (None, pdc.PydanticUndefined): + kwargs['validation_alias'] = kwargs.get('alias') + + return kwargs + + def attr( name: Optional[str] = None, ns: Optional[str] = None, @@ -132,12 +127,15 @@ def attr( :param kwargs: pydantic field arguments. See :py:class:`pydantic.Field` """ - return XmlEntityInfo( - EntityLocation.ATTRIBUTE, - path=name, ns=ns, default=default, default_factory=default_factory, - **kwargs, + kwargs = prepare_field_kwargs(kwargs) + + field_info = pd.fields.FieldInfo(default=default, default_factory=default_factory, **kwargs) + field_info.metadata.append( + XmlEntityInfo(EntityLocation.ATTRIBUTE, path=name, ns=ns), ) + return field_info + def element( tag: Optional[str] = None, @@ -161,12 +159,15 @@ def element( :param kwargs: pydantic field arguments. See :py:class:`pydantic.Field` """ - return XmlEntityInfo( - EntityLocation.ELEMENT, - path=tag, ns=ns, nsmap=nsmap, nillable=nillable, default=default, default_factory=default_factory, - **kwargs, + kwargs = prepare_field_kwargs(kwargs) + + field_info = pd.fields.FieldInfo(default=default, default_factory=default_factory, **kwargs) + field_info.metadata.append( + XmlEntityInfo(EntityLocation.ELEMENT, path=tag, ns=ns, nsmap=nsmap, nillable=nillable), ) + return field_info + def wrapped( path: str, @@ -190,12 +191,22 @@ def wrapped( :param kwargs: pydantic field arguments. See :py:class:`pydantic.Field` """ - return XmlEntityInfo( - EntityLocation.WRAPPED, - path=path, ns=ns, nsmap=nsmap, wrapped=entity, default=default, default_factory=default_factory, - **kwargs, + if entity is None: + wrapped_entity_info = None + field_info = pd.fields.FieldInfo(default=default, default_factory=default_factory, **kwargs) + else: + wrapped_entity_info = extract_field_xml_entity_info(entity) + field_info = compat.merge_field_infos( + pd.fields.FieldInfo(default=default, default_factory=default_factory, **kwargs), + entity, + ) + + field_info.metadata.append( + XmlEntityInfo(EntityLocation.WRAPPED, path=path, ns=ns, nsmap=nsmap, wrapped=wrapped_entity_info), ) + return field_info + @dc.dataclass class ComputedXmlEntityInfo(pd.fields.ComputedFieldInfo, XmlEntityInfoP): @@ -293,7 +304,7 @@ def computed_element( def xml_field_validator( - field: str, /, *fields: str + field: str, /, *fields: str, ) -> 'Callable[[model.ValidatorFuncT[model.ModelT]], model.ValidatorFuncT[model.ModelT]]': """ Marks the method as a field xml validator. @@ -312,7 +323,7 @@ def wrapper(func: model.ValidatorFuncT[model.ModelT]) -> model.ValidatorFuncT[mo def xml_field_serializer( - field: str, /, *fields: str + field: str, /, *fields: str, ) -> 'Callable[[model.SerializerFuncT[model.ModelT]], model.SerializerFuncT[model.ModelT]]': """ Marks the method as a field xml serializer. diff --git a/pydantic_xml/model.py b/pydantic_xml/model.py index 312efd2..1ee18ef 100644 --- a/pydantic_xml/model.py +++ b/pydantic_xml/model.py @@ -5,10 +5,9 @@ import pydantic_core as pdc import typing_extensions as te from pydantic import BaseModel, RootModel -from pydantic._internal._model_construction import ModelMetaclass # noqa -from pydantic.root_model import _RootModelMetaclass as RootModelMetaclass # noqa from . import config, errors, utils +from .compat import ModelMetaclass, RootModelMetaclass from .element import SearchMode, XmlElementReader, XmlElementWriter from .element.native import ElementT, XmlElement, etree from .fields import XmlEntityInfo, XmlFieldSerializer, XmlFieldValidator, attr, element, wrapped diff --git a/pydantic_xml/serializers/factories/model.py b/pydantic_xml/serializers/factories/model.py index 140970f..39fcf4a 100644 --- a/pydantic_xml/serializers/factories/model.py +++ b/pydantic_xml/serializers/factories/model.py @@ -9,7 +9,7 @@ import pydantic_xml as pxml from pydantic_xml import errors, utils from pydantic_xml.element import XmlElementReader, XmlElementWriter, is_element_nill, make_element_nill -from pydantic_xml.fields import ComputedXmlEntityInfo, XmlEntityInfoP +from pydantic_xml.fields import ComputedXmlEntityInfo, XmlEntityInfoP, extract_field_xml_entity_info from pydantic_xml.serializers.serializer import SearchMode, Serializer from pydantic_xml.typedefs import EntityLocation, Location, NsMap from pydantic_xml.utils import QName, merge_nsmaps, select_ns @@ -79,15 +79,10 @@ def from_core_schema(cls, schema: pcs.ModelSchema, ctx: Serializer.Context) -> ' fields_validation_aliases[field_name] = validation_alias field_info = model_cls.model_fields[field_name] - if isinstance(field_info, pxml.model.XmlEntityInfo): - entity_info = field_info - else: - entity_info = None - field_ctx = ctx.child( field_name=field_name, field_alias=field_alias, - entity_info=entity_info, + entity_info=extract_field_xml_entity_info(field_info), ) fields_serializers[field_name] = Serializer.parse_core_schema(model_field['schema'], field_ctx) @@ -234,16 +229,10 @@ def from_core_schema(cls, schema: pcs.ModelSchema, ctx: Serializer.Context) -> ' assert issubclass(model_cls, pxml.BaseXmlModel), "model class must be a BaseXmlModel subclass" - entity_info: Optional[XmlEntityInfoP] field_info = model_cls.model_fields['root'] - if isinstance(field_info, pxml.model.XmlEntityInfo): - entity_info = field_info - else: - entity_info = None - field_ctx = ctx.child( field_name=None, - entity_info=entity_info, + entity_info=extract_field_xml_entity_info(field_info), ) root_serializer = Serializer.parse_core_schema(root_schema, field_ctx) diff --git a/tests/test_encoder.py b/tests/test_encoder.py index 949050f..d8bc011 100644 --- a/tests/test_encoder.py +++ b/tests/test_encoder.py @@ -282,9 +282,9 @@ def validate_model_before(cls, data: Dict[str, Any]) -> 'TestModel': } @model_validator(mode='after') - def validate_model_after(cls, obj: 'TestModel') -> 'TestModel': - obj.field1 = obj.field1.replace(tzinfo=dt.timezone.utc) - return obj + def validate_model_after(self) -> 'TestModel': + self.field1 = self.field1.replace(tzinfo=dt.timezone.utc) + return self @model_validator(mode='wrap') def validate_model_wrap(cls, obj: 'TestModel', handler: Callable) -> 'TestModel': diff --git a/tests/test_misc.py b/tests/test_misc.py index 826bbd8..c97ad12 100644 --- a/tests/test_misc.py +++ b/tests/test_misc.py @@ -6,7 +6,6 @@ from helpers import assert_xml_equal from pydantic_xml import BaseXmlModel, RootXmlModel, attr, element, errors, wrapped -from pydantic_xml.fields import XmlEntityInfo def test_xml_declaration(): @@ -385,28 +384,25 @@ def validate_field(cls, v: str, info: pd.FieldValidationInfo): def test_field_info_merge(): from typing import Annotated - from annotated_types import Ge, Lt - class TestModel(BaseXmlModel, tag='root'): element1: Annotated[ int, pd.Field(ge=0), pd.Field(default=0, lt=100), - element(nillable=True), - ] = element(tag='elm', lt=10) + element(lt=5), + ] = element(tag='elm') field_info = TestModel.model_fields['element1'] - assert isinstance(field_info, XmlEntityInfo) - assert field_info.metadata == [Ge(ge=0), Lt(lt=10)] assert field_info.default == 0 - assert field_info.nillable == True - assert field_info.path == 'elm' TestModel.from_xml("0") with pytest.raises(pd.ValidationError): TestModel.from_xml("-1") + with pytest.raises(pd.ValidationError): + TestModel.from_xml("5") + def test_get_type_hints(): from typing import get_type_hints