diff --git a/pydantic_xml/compat.py b/pydantic_xml/compat.py
new file mode 100644
index 0000000..57c88e8
--- /dev/null
+++ b/pydantic_xml/compat.py
@@ -0,0 +1,16 @@
+"""
+pydantic compatibility module.
+"""
+
+import pydantic as pd
+from pydantic._internal._model_construction import ModelMetaclass # noqa
+from pydantic.root_model import _RootModelMetaclass as RootModelMetaclass # noqa
+
+PYDANTIC_VERSION = tuple(map(int, pd.__version__.partition('+')[0].split('.')))
+
+
+def merge_field_infos(*field_infos: pd.fields.FieldInfo) -> pd.fields.FieldInfo:
+ if PYDANTIC_VERSION >= (2, 12, 0):
+ return pd.fields.FieldInfo._construct(field_infos) # type: ignore[attr-defined]
+ else:
+ return pd.fields.FieldInfo.merge_field_infos(*field_infos)
diff --git a/pydantic_xml/fields.py b/pydantic_xml/fields.py
index 8880172..c0cabd3 100644
--- a/pydantic_xml/fields.py
+++ b/pydantic_xml/fields.py
@@ -1,13 +1,11 @@
import dataclasses as dc
import typing
-from typing import Any, Callable, Optional, Union
+from typing import Any, Callable, Dict, Optional, Union
import pydantic as pd
import pydantic_core as pdc
-from pydantic._internal._model_construction import ModelMetaclass # noqa
-from pydantic.root_model import _RootModelMetaclass as RootModelMetaclass # noqa
-from . import config, model, utils
+from . import compat, config, model, utils
from .typedefs import EntityLocation
from .utils import NsMap
@@ -17,6 +15,7 @@
'computed_element',
'computed_entity',
'element',
+ 'extract_field_xml_entity_info',
'wrapped',
'xml_field_serializer',
'xml_field_validator',
@@ -37,83 +36,79 @@ class XmlEntityInfoP(typing.Protocol):
wrapped: Optional['XmlEntityInfoP']
-class XmlEntityInfo(pd.fields.FieldInfo, XmlEntityInfoP):
+@dc.dataclass(frozen=True)
+class XmlEntityInfo(XmlEntityInfoP):
"""
Field xml meta-information.
"""
- __slots__ = ('location', 'path', 'ns', 'nsmap', 'nillable', 'wrapped')
+ location: Optional[EntityLocation]
+ path: Optional[str] = None
+ ns: Optional[str] = None
+ nsmap: Optional[NsMap] = None
+ nillable: Optional[bool] = None
+ wrapped: Optional[XmlEntityInfoP] = None
+
+ def __post_init__(self) -> None:
+ if config.REGISTER_NS_PREFIXES and self.nsmap:
+ utils.register_nsmap(self.nsmap)
@staticmethod
- def merge_field_infos(*field_infos: pd.fields.FieldInfo, **overrides: Any) -> pd.fields.FieldInfo:
- location, path, ns, nsmap, nillable, wrapped = None, None, None, None, None, None
-
- for field_info in field_infos:
- if isinstance(field_info, XmlEntityInfo):
- location = field_info.location if field_info.location is not None else location
- path = field_info.path if field_info.path is not None else path
- ns = field_info.ns if field_info.ns is not None else ns
- nsmap = field_info.nsmap if field_info.nsmap is not None else nsmap
- nillable = field_info.nillable if field_info.nillable is not None else nillable
- wrapped = field_info.wrapped if field_info.wrapped is not None else wrapped
-
- field_info = pd.fields.FieldInfo.merge_field_infos(*field_infos, **overrides)
-
- xml_entity_info = XmlEntityInfo(
- location,
+ def merge(*entity_infos: XmlEntityInfoP) -> 'XmlEntityInfo':
+ location: Optional[EntityLocation] = None
+ path: Optional[str] = None
+ ns: Optional[str] = None
+ nsmap: Optional[NsMap] = None
+ nillable: Optional[bool] = None
+ wrapped: Optional[XmlEntityInfoP] = None
+
+ for entity_info in entity_infos:
+ if entity_info.location is not None:
+ location = entity_info.location
+ if entity_info.wrapped is not None:
+ wrapped = entity_info.wrapped
+ if entity_info.path is not None:
+ path = entity_info.path
+ if entity_info.ns is not None:
+ ns = entity_info.ns
+ if entity_info.nsmap is not None:
+ nsmap = utils.merge_nsmaps(entity_info.nsmap, nsmap)
+ if entity_info.nillable is not None:
+ nillable = entity_info.nillable
+
+ return XmlEntityInfo(
+ location=location,
path=path,
ns=ns,
nsmap=nsmap,
nillable=nillable,
- wrapped=wrapped if isinstance(wrapped, XmlEntityInfo) else None,
- **field_info._attributes_set,
+ wrapped=wrapped,
)
- xml_entity_info.metadata = field_info.metadata
-
- return xml_entity_info
-
- def __init__(
- self,
- location: Optional[EntityLocation],
- /,
- path: Optional[str] = None,
- ns: Optional[str] = None,
- nsmap: Optional[NsMap] = None,
- nillable: Optional[bool] = None,
- wrapped: Optional[pd.fields.FieldInfo] = None,
- **kwargs: Any,
- ):
- wrapped_metadata: list[Any] = []
- if wrapped is not None:
- # copy arguments from the wrapped entity to let pydantic know how to process the field
- for entity_field_name in utils.get_slots(wrapped):
- if entity_field_name in pd.fields._FIELD_ARG_NAMES:
- kwargs[entity_field_name] = getattr(wrapped, entity_field_name)
- wrapped_metadata = wrapped.metadata
-
- if kwargs.get('serialization_alias') is None:
- kwargs['serialization_alias'] = kwargs.get('alias')
-
- if kwargs.get('validation_alias') is None:
- kwargs['validation_alias'] = kwargs.get('alias')
-
- super().__init__(**kwargs)
- self.metadata.extend(wrapped_metadata)
-
- self.location = location
- self.path = path
- self.ns = ns
- self.nsmap = nsmap
- self.nillable = nillable
- self.wrapped: Optional[XmlEntityInfoP] = wrapped if isinstance(wrapped, XmlEntityInfo) else None
-
- if config.REGISTER_NS_PREFIXES and nsmap:
- utils.register_nsmap(nsmap)
+
+
+def extract_field_xml_entity_info(field_info: pd.fields.FieldInfo) -> Optional[XmlEntityInfoP]:
+ entity_info_list = list(filter(lambda meta: isinstance(meta, XmlEntityInfo), field_info.metadata))
+ if entity_info_list:
+ entity_info = XmlEntityInfo.merge(*entity_info_list)
+ else:
+ entity_info = None
+
+ return entity_info
_Unset: Any = pdc.PydanticUndefined
+def prepare_field_kwargs(kwargs: Dict[str, Any]) -> Dict[str, Any]:
+ if kwargs.get('serialization_alias') in (None, pdc.PydanticUndefined):
+ kwargs['serialization_alias'] = kwargs.get('alias')
+
+ if kwargs.get('validation_alias') in (None, pdc.PydanticUndefined):
+ kwargs['validation_alias'] = kwargs.get('alias')
+
+ return kwargs
+
+
def attr(
name: Optional[str] = None,
ns: Optional[str] = None,
@@ -132,12 +127,15 @@ def attr(
:param kwargs: pydantic field arguments. See :py:class:`pydantic.Field`
"""
- return XmlEntityInfo(
- EntityLocation.ATTRIBUTE,
- path=name, ns=ns, default=default, default_factory=default_factory,
- **kwargs,
+ kwargs = prepare_field_kwargs(kwargs)
+
+ field_info = pd.fields.FieldInfo(default=default, default_factory=default_factory, **kwargs)
+ field_info.metadata.append(
+ XmlEntityInfo(EntityLocation.ATTRIBUTE, path=name, ns=ns),
)
+ return field_info
+
def element(
tag: Optional[str] = None,
@@ -161,12 +159,15 @@ def element(
:param kwargs: pydantic field arguments. See :py:class:`pydantic.Field`
"""
- return XmlEntityInfo(
- EntityLocation.ELEMENT,
- path=tag, ns=ns, nsmap=nsmap, nillable=nillable, default=default, default_factory=default_factory,
- **kwargs,
+ kwargs = prepare_field_kwargs(kwargs)
+
+ field_info = pd.fields.FieldInfo(default=default, default_factory=default_factory, **kwargs)
+ field_info.metadata.append(
+ XmlEntityInfo(EntityLocation.ELEMENT, path=tag, ns=ns, nsmap=nsmap, nillable=nillable),
)
+ return field_info
+
def wrapped(
path: str,
@@ -190,12 +191,22 @@ def wrapped(
:param kwargs: pydantic field arguments. See :py:class:`pydantic.Field`
"""
- return XmlEntityInfo(
- EntityLocation.WRAPPED,
- path=path, ns=ns, nsmap=nsmap, wrapped=entity, default=default, default_factory=default_factory,
- **kwargs,
+ if entity is None:
+ wrapped_entity_info = None
+ field_info = pd.fields.FieldInfo(default=default, default_factory=default_factory, **kwargs)
+ else:
+ wrapped_entity_info = extract_field_xml_entity_info(entity)
+ field_info = compat.merge_field_infos(
+ pd.fields.FieldInfo(default=default, default_factory=default_factory, **kwargs),
+ entity,
+ )
+
+ field_info.metadata.append(
+ XmlEntityInfo(EntityLocation.WRAPPED, path=path, ns=ns, nsmap=nsmap, wrapped=wrapped_entity_info),
)
+ return field_info
+
@dc.dataclass
class ComputedXmlEntityInfo(pd.fields.ComputedFieldInfo, XmlEntityInfoP):
@@ -293,7 +304,7 @@ def computed_element(
def xml_field_validator(
- field: str, /, *fields: str
+ field: str, /, *fields: str,
) -> 'Callable[[model.ValidatorFuncT[model.ModelT]], model.ValidatorFuncT[model.ModelT]]':
"""
Marks the method as a field xml validator.
@@ -312,7 +323,7 @@ def wrapper(func: model.ValidatorFuncT[model.ModelT]) -> model.ValidatorFuncT[mo
def xml_field_serializer(
- field: str, /, *fields: str
+ field: str, /, *fields: str,
) -> 'Callable[[model.SerializerFuncT[model.ModelT]], model.SerializerFuncT[model.ModelT]]':
"""
Marks the method as a field xml serializer.
diff --git a/pydantic_xml/model.py b/pydantic_xml/model.py
index 312efd2..1ee18ef 100644
--- a/pydantic_xml/model.py
+++ b/pydantic_xml/model.py
@@ -5,10 +5,9 @@
import pydantic_core as pdc
import typing_extensions as te
from pydantic import BaseModel, RootModel
-from pydantic._internal._model_construction import ModelMetaclass # noqa
-from pydantic.root_model import _RootModelMetaclass as RootModelMetaclass # noqa
from . import config, errors, utils
+from .compat import ModelMetaclass, RootModelMetaclass
from .element import SearchMode, XmlElementReader, XmlElementWriter
from .element.native import ElementT, XmlElement, etree
from .fields import XmlEntityInfo, XmlFieldSerializer, XmlFieldValidator, attr, element, wrapped
diff --git a/pydantic_xml/serializers/factories/model.py b/pydantic_xml/serializers/factories/model.py
index 140970f..39fcf4a 100644
--- a/pydantic_xml/serializers/factories/model.py
+++ b/pydantic_xml/serializers/factories/model.py
@@ -9,7 +9,7 @@
import pydantic_xml as pxml
from pydantic_xml import errors, utils
from pydantic_xml.element import XmlElementReader, XmlElementWriter, is_element_nill, make_element_nill
-from pydantic_xml.fields import ComputedXmlEntityInfo, XmlEntityInfoP
+from pydantic_xml.fields import ComputedXmlEntityInfo, XmlEntityInfoP, extract_field_xml_entity_info
from pydantic_xml.serializers.serializer import SearchMode, Serializer
from pydantic_xml.typedefs import EntityLocation, Location, NsMap
from pydantic_xml.utils import QName, merge_nsmaps, select_ns
@@ -79,15 +79,10 @@ def from_core_schema(cls, schema: pcs.ModelSchema, ctx: Serializer.Context) -> '
fields_validation_aliases[field_name] = validation_alias
field_info = model_cls.model_fields[field_name]
- if isinstance(field_info, pxml.model.XmlEntityInfo):
- entity_info = field_info
- else:
- entity_info = None
-
field_ctx = ctx.child(
field_name=field_name,
field_alias=field_alias,
- entity_info=entity_info,
+ entity_info=extract_field_xml_entity_info(field_info),
)
fields_serializers[field_name] = Serializer.parse_core_schema(model_field['schema'], field_ctx)
@@ -234,16 +229,10 @@ def from_core_schema(cls, schema: pcs.ModelSchema, ctx: Serializer.Context) -> '
assert issubclass(model_cls, pxml.BaseXmlModel), "model class must be a BaseXmlModel subclass"
- entity_info: Optional[XmlEntityInfoP]
field_info = model_cls.model_fields['root']
- if isinstance(field_info, pxml.model.XmlEntityInfo):
- entity_info = field_info
- else:
- entity_info = None
-
field_ctx = ctx.child(
field_name=None,
- entity_info=entity_info,
+ entity_info=extract_field_xml_entity_info(field_info),
)
root_serializer = Serializer.parse_core_schema(root_schema, field_ctx)
diff --git a/tests/test_encoder.py b/tests/test_encoder.py
index 949050f..d8bc011 100644
--- a/tests/test_encoder.py
+++ b/tests/test_encoder.py
@@ -282,9 +282,9 @@ def validate_model_before(cls, data: Dict[str, Any]) -> 'TestModel':
}
@model_validator(mode='after')
- def validate_model_after(cls, obj: 'TestModel') -> 'TestModel':
- obj.field1 = obj.field1.replace(tzinfo=dt.timezone.utc)
- return obj
+ def validate_model_after(self) -> 'TestModel':
+ self.field1 = self.field1.replace(tzinfo=dt.timezone.utc)
+ return self
@model_validator(mode='wrap')
def validate_model_wrap(cls, obj: 'TestModel', handler: Callable) -> 'TestModel':
diff --git a/tests/test_misc.py b/tests/test_misc.py
index 826bbd8..c97ad12 100644
--- a/tests/test_misc.py
+++ b/tests/test_misc.py
@@ -6,7 +6,6 @@
from helpers import assert_xml_equal
from pydantic_xml import BaseXmlModel, RootXmlModel, attr, element, errors, wrapped
-from pydantic_xml.fields import XmlEntityInfo
def test_xml_declaration():
@@ -385,28 +384,25 @@ def validate_field(cls, v: str, info: pd.FieldValidationInfo):
def test_field_info_merge():
from typing import Annotated
- from annotated_types import Ge, Lt
-
class TestModel(BaseXmlModel, tag='root'):
element1: Annotated[
int,
pd.Field(ge=0),
pd.Field(default=0, lt=100),
- element(nillable=True),
- ] = element(tag='elm', lt=10)
+ element(lt=5),
+ ] = element(tag='elm')
field_info = TestModel.model_fields['element1']
- assert isinstance(field_info, XmlEntityInfo)
- assert field_info.metadata == [Ge(ge=0), Lt(lt=10)]
assert field_info.default == 0
- assert field_info.nillable == True
- assert field_info.path == 'elm'
TestModel.from_xml("0")
with pytest.raises(pd.ValidationError):
TestModel.from_xml("-1")
+ with pytest.raises(pd.ValidationError):
+ TestModel.from_xml("5")
+
def test_get_type_hints():
from typing import get_type_hints