14
14
# See the License for the specific language governing permissions and
15
15
# limitations under the License.
16
16
17
- from typing import Any , Optional
17
+ from typing import Any , Optional , get_args , get_origin , Union
18
18
19
19
from ._interfaces import JsonCodec , TypeHandler
20
20
from .orjson_codec import OrJsonCodec
@@ -152,40 +152,67 @@ def encode_parameters(self, *arguments) -> bytes:
152
152
def decode_return_value (self , data : bytes ) -> Any :
153
153
"""
154
154
Decode return value from JSON bytes and validate against self.return_type.
155
+ Supports nested generics and marker-wrapped types.
155
156
"""
156
- try :
157
- if not data :
158
- return None
157
+ if not data :
158
+ return None
159
159
160
- # Step 1: Decode JSON bytes into Python objects
161
- json_data = self ._decode_with_codecs (data )
160
+ # Step 1: Decode JSON bytes to Python object
161
+ json_data = self ._decode_with_codecs (data )
162
162
163
- # Step 2: Reconstruct objects (dataclasses, pydantic, enums, etc. )
164
- obj = self ._reconstruct_objects (json_data )
163
+ # Step 2: Reconstruct marker-based objects (datetime, UUID, set, frozenset, dataclass, pydantic )
164
+ obj = self ._reconstruct_objects (json_data )
165
165
166
- # Step 3: Strict return type validation
167
- if self .return_type :
168
- from typing import get_origin , get_args , Union
166
+ # Step 3: Validate type recursively
167
+ if self .return_type :
168
+ if not self ._validate_type (obj , self .return_type ):
169
+ raise DeserializationException (
170
+ f"Decoded object type { type (obj ).__name__ } does not match expected { self .return_type } "
171
+ )
169
172
170
- origin = get_origin (self .return_type )
171
- args = get_args (self .return_type )
173
+ return obj
172
174
173
- if origin is Union :
174
- if not any (isinstance (obj , arg ) for arg in args ):
175
- raise DeserializationException (
176
- f"Decoded object type { type (obj ).__name__ } not in expected Union types { args } "
177
- )
178
- else :
179
- if not isinstance (obj , self .return_type ):
180
- raise DeserializationException (
181
- f"Decoded object type { type (obj ).__name__ } "
182
- f"does not match expected return_type { self .return_type .__name__ } "
183
- )
175
+ def _validate_type (self , obj : Any , expected_type : type ) -> bool :
176
+ """
177
+ Recursively validate obj against expected_type.
178
+ Supports Union, List, Tuple, Set, frozenset, dataclass, Enum, Pydantic models.
179
+ """
180
+ origin = get_origin (expected_type )
181
+ args = get_args (expected_type )
184
182
185
- return obj
183
+ # Handle Union types
184
+ if origin is Union :
185
+ return any (self ._validate_type (obj , t ) for t in args )
186
186
187
- except Exception as e :
188
- raise DeserializationException (f"Return value decoding failed: { e } " ) from e
187
+ # Handle container types
188
+ if origin in (list , tuple , set , frozenset ):
189
+ if not isinstance (obj , origin ):
190
+ return False
191
+ if args :
192
+ return all (self ._validate_type (item , args [0 ]) for item in obj )
193
+ return True
194
+
195
+ # Dataclass
196
+ if hasattr (expected_type , "__dataclass_fields__" ):
197
+ return hasattr (obj , "__dataclass_fields__" ) and type (obj ) == expected_type
198
+
199
+ # Enum
200
+ import enum
201
+
202
+ if isinstance (expected_type , type ) and issubclass (expected_type , enum .Enum ):
203
+ return isinstance (obj , expected_type )
204
+
205
+ # Pydantic
206
+ try :
207
+ from pydantic import BaseModel
208
+
209
+ if issubclass (expected_type , BaseModel ):
210
+ return isinstance (obj , expected_type )
211
+ except Exception :
212
+ pass
213
+
214
+ # Plain types
215
+ return isinstance (obj , expected_type )
189
216
190
217
# Encoder/Decoder interface compatibility methods
191
218
def encoder (self ):
@@ -327,18 +354,10 @@ def _reconstruct_objects(self, data: Any) -> Any:
327
354
return data
328
355
329
356
if "$date" in data :
330
- from datetime import datetime
357
+ from datetime import datetime , timezone
331
358
332
- # Handle both ISO format with and without timezone
333
- date_str = data ["$date" ]
334
- if date_str .endswith ("Z" ):
335
- # Remove Z and treat as UTC
336
- date_str = date_str [:- 1 ] + "+00:00"
337
- try :
338
- return datetime .fromisoformat (date_str )
339
- except ValueError :
340
- # Fallback for older formats
341
- return datetime .fromisoformat (date_str .replace ("Z" , "+00:00" ))
359
+ dt = datetime .fromisoformat (data ["$date" ].replace ("Z" , "+00:00" ))
360
+ return dt .astimezone (timezone .utc )
342
361
343
362
elif "$uuid" in data :
344
363
from uuid import UUID
@@ -348,6 +367,9 @@ def _reconstruct_objects(self, data: Any) -> Any:
348
367
elif "$set" in data :
349
368
return set (self ._reconstruct_objects (item ) for item in data ["$set" ])
350
369
370
+ elif "$frozenset" in data :
371
+ return frozenset (self ._reconstruct_objects (item ) for item in data ["$frozenset" ])
372
+
351
373
elif "$tuple" in data :
352
374
return tuple (self ._reconstruct_objects (item ) for item in data ["$tuple" ])
353
375
0 commit comments