Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 16 additions & 0 deletions pydantic_xml/compat.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
"""
pydantic compatibility module.
"""

import pydantic as pd
from pydantic._internal._model_construction import ModelMetaclass # noqa
from pydantic.root_model import _RootModelMetaclass as RootModelMetaclass # noqa

PYDANTIC_VERSION = tuple(map(int, pd.__version__.partition('+')[0].split('.')))


def merge_field_infos(*field_infos: pd.fields.FieldInfo) -> pd.fields.FieldInfo:
if PYDANTIC_VERSION >= (2, 12, 0):
return pd.fields.FieldInfo._construct(field_infos) # type: ignore[attr-defined]
else:
return pd.fields.FieldInfo.merge_field_infos(*field_infos)
169 changes: 90 additions & 79 deletions pydantic_xml/fields.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,11 @@
import dataclasses as dc
import typing
from typing import Any, Callable, Optional, Union
from typing import Any, Callable, Dict, Optional, Union

import pydantic as pd
import pydantic_core as pdc
from pydantic._internal._model_construction import ModelMetaclass # noqa
from pydantic.root_model import _RootModelMetaclass as RootModelMetaclass # noqa

from . import config, model, utils
from . import compat, config, model, utils
from .typedefs import EntityLocation
from .utils import NsMap

Expand All @@ -17,6 +15,7 @@
'computed_element',
'computed_entity',
'element',
'extract_field_xml_entity_info',
'wrapped',
'xml_field_serializer',
'xml_field_validator',
Expand All @@ -37,83 +36,79 @@ class XmlEntityInfoP(typing.Protocol):
wrapped: Optional['XmlEntityInfoP']


class XmlEntityInfo(pd.fields.FieldInfo, XmlEntityInfoP):
@dc.dataclass(frozen=True)
class XmlEntityInfo(XmlEntityInfoP):
"""
Field xml meta-information.
"""

__slots__ = ('location', 'path', 'ns', 'nsmap', 'nillable', 'wrapped')
location: Optional[EntityLocation]
path: Optional[str] = None
ns: Optional[str] = None
nsmap: Optional[NsMap] = None
nillable: Optional[bool] = None
wrapped: Optional[XmlEntityInfoP] = None

def __post_init__(self) -> None:
if config.REGISTER_NS_PREFIXES and self.nsmap:
utils.register_nsmap(self.nsmap)

@staticmethod
def merge_field_infos(*field_infos: pd.fields.FieldInfo, **overrides: Any) -> pd.fields.FieldInfo:
location, path, ns, nsmap, nillable, wrapped = None, None, None, None, None, None

for field_info in field_infos:
if isinstance(field_info, XmlEntityInfo):
location = field_info.location if field_info.location is not None else location
path = field_info.path if field_info.path is not None else path
ns = field_info.ns if field_info.ns is not None else ns
nsmap = field_info.nsmap if field_info.nsmap is not None else nsmap
nillable = field_info.nillable if field_info.nillable is not None else nillable
wrapped = field_info.wrapped if field_info.wrapped is not None else wrapped

field_info = pd.fields.FieldInfo.merge_field_infos(*field_infos, **overrides)

xml_entity_info = XmlEntityInfo(
location,
def merge(*entity_infos: XmlEntityInfoP) -> 'XmlEntityInfo':
location: Optional[EntityLocation] = None
path: Optional[str] = None
ns: Optional[str] = None
nsmap: Optional[NsMap] = None
nillable: Optional[bool] = None
wrapped: Optional[XmlEntityInfoP] = None

for entity_info in entity_infos:
if entity_info.location is not None:
location = entity_info.location
if entity_info.wrapped is not None:
wrapped = entity_info.wrapped
if entity_info.path is not None:
path = entity_info.path
if entity_info.ns is not None:
ns = entity_info.ns
if entity_info.nsmap is not None:
nsmap = utils.merge_nsmaps(entity_info.nsmap, nsmap)
if entity_info.nillable is not None:
nillable = entity_info.nillable

return XmlEntityInfo(
location=location,
path=path,
ns=ns,
nsmap=nsmap,
nillable=nillable,
wrapped=wrapped if isinstance(wrapped, XmlEntityInfo) else None,
**field_info._attributes_set,
wrapped=wrapped,
)
xml_entity_info.metadata = field_info.metadata

return xml_entity_info

def __init__(
self,
location: Optional[EntityLocation],
/,
path: Optional[str] = None,
ns: Optional[str] = None,
nsmap: Optional[NsMap] = None,
nillable: Optional[bool] = None,
wrapped: Optional[pd.fields.FieldInfo] = None,
**kwargs: Any,
):
wrapped_metadata: list[Any] = []
if wrapped is not None:
# copy arguments from the wrapped entity to let pydantic know how to process the field
for entity_field_name in utils.get_slots(wrapped):
if entity_field_name in pd.fields._FIELD_ARG_NAMES:
kwargs[entity_field_name] = getattr(wrapped, entity_field_name)
wrapped_metadata = wrapped.metadata

if kwargs.get('serialization_alias') is None:
kwargs['serialization_alias'] = kwargs.get('alias')

if kwargs.get('validation_alias') is None:
kwargs['validation_alias'] = kwargs.get('alias')

super().__init__(**kwargs)
self.metadata.extend(wrapped_metadata)

self.location = location
self.path = path
self.ns = ns
self.nsmap = nsmap
self.nillable = nillable
self.wrapped: Optional[XmlEntityInfoP] = wrapped if isinstance(wrapped, XmlEntityInfo) else None

if config.REGISTER_NS_PREFIXES and nsmap:
utils.register_nsmap(nsmap)


def extract_field_xml_entity_info(field_info: pd.fields.FieldInfo) -> Optional[XmlEntityInfoP]:
entity_info_list = list(filter(lambda meta: isinstance(meta, XmlEntityInfo), field_info.metadata))
if entity_info_list:
entity_info = XmlEntityInfo.merge(*entity_info_list)
else:
entity_info = None

return entity_info


_Unset: Any = pdc.PydanticUndefined


def prepare_field_kwargs(kwargs: Dict[str, Any]) -> Dict[str, Any]:
if kwargs.get('serialization_alias') in (None, pdc.PydanticUndefined):
kwargs['serialization_alias'] = kwargs.get('alias')

if kwargs.get('validation_alias') in (None, pdc.PydanticUndefined):
kwargs['validation_alias'] = kwargs.get('alias')

return kwargs


def attr(
name: Optional[str] = None,
ns: Optional[str] = None,
Expand All @@ -132,12 +127,15 @@ def attr(
:param kwargs: pydantic field arguments. See :py:class:`pydantic.Field`
"""

return XmlEntityInfo(
EntityLocation.ATTRIBUTE,
path=name, ns=ns, default=default, default_factory=default_factory,
**kwargs,
kwargs = prepare_field_kwargs(kwargs)

field_info = pd.fields.FieldInfo(default=default, default_factory=default_factory, **kwargs)
field_info.metadata.append(
XmlEntityInfo(EntityLocation.ATTRIBUTE, path=name, ns=ns),
)

return field_info


def element(
tag: Optional[str] = None,
Expand All @@ -161,12 +159,15 @@ def element(
:param kwargs: pydantic field arguments. See :py:class:`pydantic.Field`
"""

return XmlEntityInfo(
EntityLocation.ELEMENT,
path=tag, ns=ns, nsmap=nsmap, nillable=nillable, default=default, default_factory=default_factory,
**kwargs,
kwargs = prepare_field_kwargs(kwargs)

field_info = pd.fields.FieldInfo(default=default, default_factory=default_factory, **kwargs)
field_info.metadata.append(
XmlEntityInfo(EntityLocation.ELEMENT, path=tag, ns=ns, nsmap=nsmap, nillable=nillable),
)

return field_info


def wrapped(
path: str,
Expand All @@ -190,12 +191,22 @@ def wrapped(
:param kwargs: pydantic field arguments. See :py:class:`pydantic.Field`
"""

return XmlEntityInfo(
EntityLocation.WRAPPED,
path=path, ns=ns, nsmap=nsmap, wrapped=entity, default=default, default_factory=default_factory,
**kwargs,
if entity is None:
wrapped_entity_info = None
field_info = pd.fields.FieldInfo(default=default, default_factory=default_factory, **kwargs)
else:
wrapped_entity_info = extract_field_xml_entity_info(entity)
field_info = compat.merge_field_infos(
pd.fields.FieldInfo(default=default, default_factory=default_factory, **kwargs),
entity,
)

field_info.metadata.append(
XmlEntityInfo(EntityLocation.WRAPPED, path=path, ns=ns, nsmap=nsmap, wrapped=wrapped_entity_info),
)

return field_info


@dc.dataclass
class ComputedXmlEntityInfo(pd.fields.ComputedFieldInfo, XmlEntityInfoP):
Expand Down Expand Up @@ -293,7 +304,7 @@ def computed_element(


def xml_field_validator(
field: str, /, *fields: str
field: str, /, *fields: str,
) -> 'Callable[[model.ValidatorFuncT[model.ModelT]], model.ValidatorFuncT[model.ModelT]]':
"""
Marks the method as a field xml validator.
Expand All @@ -312,7 +323,7 @@ def wrapper(func: model.ValidatorFuncT[model.ModelT]) -> model.ValidatorFuncT[mo


def xml_field_serializer(
field: str, /, *fields: str
field: str, /, *fields: str,
) -> 'Callable[[model.SerializerFuncT[model.ModelT]], model.SerializerFuncT[model.ModelT]]':
"""
Marks the method as a field xml serializer.
Expand Down
3 changes: 1 addition & 2 deletions pydantic_xml/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,9 @@
import pydantic_core as pdc
import typing_extensions as te
from pydantic import BaseModel, RootModel
from pydantic._internal._model_construction import ModelMetaclass # noqa
from pydantic.root_model import _RootModelMetaclass as RootModelMetaclass # noqa

from . import config, errors, utils
from .compat import ModelMetaclass, RootModelMetaclass
from .element import SearchMode, XmlElementReader, XmlElementWriter
from .element.native import ElementT, XmlElement, etree
from .fields import XmlEntityInfo, XmlFieldSerializer, XmlFieldValidator, attr, element, wrapped
Expand Down
17 changes: 3 additions & 14 deletions pydantic_xml/serializers/factories/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import pydantic_xml as pxml
from pydantic_xml import errors, utils
from pydantic_xml.element import XmlElementReader, XmlElementWriter, is_element_nill, make_element_nill
from pydantic_xml.fields import ComputedXmlEntityInfo, XmlEntityInfoP
from pydantic_xml.fields import ComputedXmlEntityInfo, XmlEntityInfoP, extract_field_xml_entity_info
from pydantic_xml.serializers.serializer import SearchMode, Serializer
from pydantic_xml.typedefs import EntityLocation, Location, NsMap
from pydantic_xml.utils import QName, merge_nsmaps, select_ns
Expand Down Expand Up @@ -79,15 +79,10 @@ def from_core_schema(cls, schema: pcs.ModelSchema, ctx: Serializer.Context) -> '
fields_validation_aliases[field_name] = validation_alias

field_info = model_cls.model_fields[field_name]
if isinstance(field_info, pxml.model.XmlEntityInfo):
entity_info = field_info
else:
entity_info = None

field_ctx = ctx.child(
field_name=field_name,
field_alias=field_alias,
entity_info=entity_info,
entity_info=extract_field_xml_entity_info(field_info),
)
fields_serializers[field_name] = Serializer.parse_core_schema(model_field['schema'], field_ctx)

Expand Down Expand Up @@ -234,16 +229,10 @@ def from_core_schema(cls, schema: pcs.ModelSchema, ctx: Serializer.Context) -> '

assert issubclass(model_cls, pxml.BaseXmlModel), "model class must be a BaseXmlModel subclass"

entity_info: Optional[XmlEntityInfoP]
field_info = model_cls.model_fields['root']
if isinstance(field_info, pxml.model.XmlEntityInfo):
entity_info = field_info
else:
entity_info = None

field_ctx = ctx.child(
field_name=None,
entity_info=entity_info,
entity_info=extract_field_xml_entity_info(field_info),
)
root_serializer = Serializer.parse_core_schema(root_schema, field_ctx)

Expand Down
6 changes: 3 additions & 3 deletions tests/test_encoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -282,9 +282,9 @@ def validate_model_before(cls, data: Dict[str, Any]) -> 'TestModel':
}

@model_validator(mode='after')
def validate_model_after(cls, obj: 'TestModel') -> 'TestModel':
obj.field1 = obj.field1.replace(tzinfo=dt.timezone.utc)
return obj
def validate_model_after(self) -> 'TestModel':
self.field1 = self.field1.replace(tzinfo=dt.timezone.utc)
return self

@model_validator(mode='wrap')
def validate_model_wrap(cls, obj: 'TestModel', handler: Callable) -> 'TestModel':
Expand Down
14 changes: 5 additions & 9 deletions tests/test_misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
from helpers import assert_xml_equal

from pydantic_xml import BaseXmlModel, RootXmlModel, attr, element, errors, wrapped
from pydantic_xml.fields import XmlEntityInfo


def test_xml_declaration():
Expand Down Expand Up @@ -385,28 +384,25 @@ def validate_field(cls, v: str, info: pd.FieldValidationInfo):
def test_field_info_merge():
from typing import Annotated

from annotated_types import Ge, Lt

class TestModel(BaseXmlModel, tag='root'):
element1: Annotated[
int,
pd.Field(ge=0),
pd.Field(default=0, lt=100),
element(nillable=True),
] = element(tag='elm', lt=10)
element(lt=5),
] = element(tag='elm')

field_info = TestModel.model_fields['element1']
assert isinstance(field_info, XmlEntityInfo)
assert field_info.metadata == [Ge(ge=0), Lt(lt=10)]
assert field_info.default == 0
assert field_info.nillable == True
assert field_info.path == 'elm'

TestModel.from_xml("<root><elm>0</elm></root>")

with pytest.raises(pd.ValidationError):
TestModel.from_xml("<root><elm>-1</elm></root>")

with pytest.raises(pd.ValidationError):
TestModel.from_xml("<root><elm>5</elm></root>")


def test_get_type_hints():
from typing import get_type_hints
Expand Down