From d51e9c43a149448847d3598724b20941a8cebb43 Mon Sep 17 00:00:00 2001 From: Aditya Yadav <166515021+aditya0yadav@users.noreply.github.com> Date: Fri, 4 Jul 2025 01:32:53 +0000 Subject: [PATCH 01/40] gsoc pull request --- src/dubbo/classes.py | 24 +- src/dubbo/client.py | 272 +++++++++++++-- src/dubbo/codec/__init__.py | 19 ++ src/dubbo/codec/dubbo_codec.py | 162 +++++++++ src/dubbo/codec/json_codec/__init__.py | 19 ++ .../codec/json_codec/json_codec_handler.py | 322 ++++++++++++++++++ src/dubbo/codec/json_codec/json_type.py | 274 +++++++++++++++ src/dubbo/proxy/handlers.py | 248 ++++++++------ 8 files changed, 1193 insertions(+), 147 deletions(-) create mode 100644 src/dubbo/codec/__init__.py create mode 100644 src/dubbo/codec/dubbo_codec.py create mode 100644 src/dubbo/codec/json_codec/__init__.py create mode 100644 src/dubbo/codec/json_codec/json_codec_handler.py create mode 100644 src/dubbo/codec/json_codec/json_type.py diff --git a/src/dubbo/classes.py b/src/dubbo/classes.py index 8d87299..e63aff0 100644 --- a/src/dubbo/classes.py +++ b/src/dubbo/classes.py @@ -13,10 +13,12 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
class Codec(ABC):
    """Abstract base class for transport codecs.

    A codec converts Python values to wire bytes and back. Concrete
    implementations (e.g. the JSON codec) are resolved by name through the
    extension loader, so this base class deliberately carries no third-party
    dependencies.

    :param model_type: optional model class the codec may use to validate or
        construct decoded values. Annotated as a plain ``type`` (rather than
        ``pydantic.BaseModel``) so the core abstraction does not force a
        pydantic dependency on every codec.
    :param kwargs: implementation-specific options, accepted so subclasses
        with extra configuration share one constructor signature.
    """

    def __init__(self, model_type: Optional[type] = None, **kwargs):
        self.model_type = model_type

    @abstractmethod
    def encode(self, data: Any) -> bytes:
        """Serialize ``data`` to bytes. Must be implemented by subclasses."""

    @abstractmethod
    def decode(self, data: bytes) -> Any:
        """Deserialize ``data`` from bytes. Must be implemented by subclasses."""


class CodecHelper:
    """Indirection used by the extension loader to obtain the Codec base class."""

    @staticmethod
    def get_class() -> type:
        return Codec
+ import threading -from typing import Optional +from typing import Optional, Callable, List, Type, Union, Any from dubbo.bootstrap import Dubbo from dubbo.classes import MethodDescriptor @@ -31,6 +32,7 @@ SerializingFunction, ) from dubbo.url import URL +from dubbo.codec import DubboTransportService __all__ = ["Client"] @@ -84,68 +86,274 @@ def _initialize(self): def unary( self, - method_name: str, + interface: Optional[Callable] = None, + method_name: Optional[str] = None, + params_types: Optional[List[Type]] = None, + return_type: Optional[Type] = None, + codec: Optional[str] = None, request_serializer: Optional[SerializingFunction] = None, response_deserializer: Optional[DeserializingFunction] = None, ) -> RpcCallable: - return self._callable( - MethodDescriptor( - method_name=method_name, - arg_serialization=(request_serializer, None), - return_serialization=(None, response_deserializer), - rpc_type=RpcTypes.UNARY.value, + """ + Create unary RPC call. + + Supports both automatic mode (via interface) and manual mode (via method_name + params_types + return_type + codec). 
+ """ + + # Validate + if interface is None and method_name is None: + raise ValueError("Either 'interface' or 'method_name' must be provided") + + # Determine the actual method name to call + actual_method_name = method_name or (interface.__name__ if interface else "unary") + + # Build method descriptor (automatic or manual) + if interface: + method_desc = DubboTransportService.create_method_descriptor( + func=interface, + method_name=actual_method_name, + parameter_types=params_types, + return_type=return_type, + interface=interface, ) + else: + # Manual mode fallback: use dummy function for descriptor creation + def dummy(): pass + + method_desc = DubboTransportService.create_method_descriptor( + func=dummy, + method_name=actual_method_name, + parameter_types=params_types or [], + return_type=return_type or Any, + ) + + # Determine serializers if not provided + if request_serializer and response_deserializer: + final_request_serializer = request_serializer + final_response_deserializer = response_deserializer + else: + # Use DubboTransportService to generate serialization functions + final_request_serializer, final_response_deserializer = DubboTransportService.create_serialization_functions( + transport_type=codec or "json", + parameter_types=[p.annotation for p in method_desc.parameters], + return_type=method_desc.return_parameter.annotation, + ) + + # Create the proper MethodDescriptor for the RPC call + # This should match the structure expected by your RpcCallableFactory + rpc_method_descriptor = MethodDescriptor( + method_name=actual_method_name, + arg_serialization=(final_request_serializer, None), # (serializer, deserializer) for arguments + return_serialization=(None, final_response_deserializer), # (serializer, deserializer) for return value + rpc_type=RpcTypes.UNARY.value, ) + # Create and return the RpcCallable + return self._callable(rpc_method_descriptor) + def client_stream( self, - method_name: str, + interface: Optional[Callable] = None, + 
method_name: Optional[str] = None, + params_types: Optional[List[Type]] = None, + return_type: Optional[Type] = None, + codec: Optional[str] = None, request_serializer: Optional[SerializingFunction] = None, response_deserializer: Optional[DeserializingFunction] = None, ) -> RpcCallable: - return self._callable( - MethodDescriptor( - method_name=method_name, - arg_serialization=(request_serializer, None), - return_serialization=(None, response_deserializer), - rpc_type=RpcTypes.CLIENT_STREAM.value, + """ + Create client streaming RPC call. + + Supports both automatic mode (via interface) and manual mode (via method_name + params_types + return_type + codec). + """ + + # Validate + if interface is None and method_name is None: + raise ValueError("Either 'interface' or 'method_name' must be provided") + + # Determine the actual method name to call + actual_method_name = method_name or (interface.__name__ if interface else "client_stream") + + # Build method descriptor (automatic or manual) + if interface: + method_desc = DubboTransportService.create_method_descriptor( + func=interface, + method_name=actual_method_name, + parameter_types=params_types, + return_type=return_type, + interface=interface, ) + else: + # Manual mode fallback: use dummy function for descriptor creation + def dummy(): pass + + method_desc = DubboTransportService.create_method_descriptor( + func=dummy, + method_name=actual_method_name, + parameter_types=params_types or [], + return_type=return_type or Any, + ) + + # Determine serializers if not provided + if request_serializer and response_deserializer: + final_request_serializer = request_serializer + final_response_deserializer = response_deserializer + else: + # Use DubboTransportService to generate serialization functions + final_request_serializer, final_response_deserializer = DubboTransportService.create_serialization_functions( + transport_type=codec or "json", + parameter_types=[p.annotation for p in method_desc.parameters], + 
return_type=method_desc.return_parameter.annotation, + ) + + # Create the proper MethodDescriptor for the RPC call + # This should match the structure expected by your RpcCallableFactory + rpc_method_descriptor = MethodDescriptor( + method_name=actual_method_name, + arg_serialization=(final_request_serializer, None), # (serializer, deserializer) for arguments + return_serialization=(None, final_response_deserializer), # (serializer, deserializer) for return value + rpc_type=RpcTypes.CLIENT_STREAM.value, ) + # Create and return the RpcCallable + return self._callable(rpc_method_descriptor) + def server_stream( self, - method_name: str, + interface: Optional[Callable] = None, + method_name: Optional[str] = None, + params_types: Optional[List[Type]] = None, + return_type: Optional[Type] = None, + codec: Optional[str] = None, request_serializer: Optional[SerializingFunction] = None, response_deserializer: Optional[DeserializingFunction] = None, ) -> RpcCallable: - return self._callable( - MethodDescriptor( - method_name=method_name, - arg_serialization=(request_serializer, None), - return_serialization=(None, response_deserializer), - rpc_type=RpcTypes.SERVER_STREAM.value, + """ + Create server streaming RPC call. + + Supports both automatic mode (via interface) and manual mode (via method_name + params_types + return_type + codec). 
+ """ + + # Validate + if interface is None and method_name is None: + raise ValueError("Either 'interface' or 'method_name' must be provided") + + # Determine the actual method name to call + actual_method_name = method_name or (interface.__name__ if interface else "server_stream") + + # Build method descriptor (automatic or manual) + if interface: + method_desc = DubboTransportService.create_method_descriptor( + func=interface, + method_name=actual_method_name, + parameter_types=params_types, + return_type=return_type, + interface=interface, + ) + else: + # Manual mode fallback: use dummy function for descriptor creation + def dummy(): pass + + method_desc = DubboTransportService.create_method_descriptor( + func=dummy, + method_name=actual_method_name, + parameter_types=params_types or [], + return_type=return_type or Any, + ) + + # Determine serializers if not provided + if request_serializer and response_deserializer: + final_request_serializer = request_serializer + final_response_deserializer = response_deserializer + else: + # Use DubboTransportService to generate serialization functions + final_request_serializer, final_response_deserializer = DubboTransportService.create_serialization_functions( + transport_type=codec or "json", + parameter_types=[p.annotation for p in method_desc.parameters], + return_type=method_desc.return_parameter.annotation, ) + + # Create the proper MethodDescriptor for the RPC call + # This should match the structure expected by your RpcCallableFactory + rpc_method_descriptor = MethodDescriptor( + method_name=actual_method_name, + arg_serialization=(final_request_serializer, None), # (serializer, deserializer) for arguments + return_serialization=(None, final_response_deserializer), # (serializer, deserializer) for return value + rpc_type=RpcTypes.SERVER_STREAM.value, ) + # Create and return the RpcCallable + return self._callable(rpc_method_descriptor) + def bi_stream( self, - method_name: str, + interface: Optional[Callable] = 
None, + method_name: Optional[str] = None, + params_types: Optional[List[Type]] = None, + return_type: Optional[Type] = None, + codec: Optional[str] = None, request_serializer: Optional[SerializingFunction] = None, response_deserializer: Optional[DeserializingFunction] = None, ) -> RpcCallable: - # create method descriptor - return self._callable( - MethodDescriptor( - method_name=method_name, - arg_serialization=(request_serializer, None), - return_serialization=(None, response_deserializer), - rpc_type=RpcTypes.BI_STREAM.value, + """ + Create bidirectional streaming RPC call. + + Supports both automatic mode (via interface) and manual mode (via method_name + params_types + return_type + codec). + """ + + # Validate + if interface is None and method_name is None: + raise ValueError("Either 'interface' or 'method_name' must be provided") + + # Determine the actual method name to call + actual_method_name = method_name or (interface.__name__ if interface else "bi_stream") + + # Build method descriptor (automatic or manual) + if interface: + method_desc = DubboTransportService.create_method_descriptor( + func=interface, + method_name=actual_method_name, + parameter_types=params_types, + return_type=return_type, + interface=interface, + ) + else: + # Manual mode fallback: use dummy function for descriptor creation + def dummy(): pass + + method_desc = DubboTransportService.create_method_descriptor( + func=dummy, + method_name=actual_method_name, + parameter_types=params_types or [], + return_type=return_type or Any, ) + + # Determine serializers if not provided + if request_serializer and response_deserializer: + final_request_serializer = request_serializer + final_response_deserializer = response_deserializer + else: + # Use DubboTransportService to generate serialization functions + final_request_serializer, final_response_deserializer = DubboTransportService.create_serialization_functions( + transport_type=codec or "json", + parameter_types=[p.annotation for p in 
method_desc.parameters], + return_type=method_desc.return_parameter.annotation, + ) + + + rpc_method_descriptor = MethodDescriptor( + method_name=actual_method_name, + arg_serialization=(final_request_serializer, None), + return_serialization=(None, final_response_deserializer), + rpc_type=RpcTypes.BI_STREAM.value, ) + # Create and return the RpcCallable + return self._callable(rpc_method_descriptor) + def _callable(self, method_descriptor: MethodDescriptor) -> RpcCallable: """ - Generate a proxy for the given method + Generate a proxy for the given method. :param method_descriptor: The method descriptor. :return: The proxy. :rtype: RpcCallable @@ -160,4 +368,4 @@ def _callable(self, method_descriptor: MethodDescriptor) -> RpcCallable: url.attributes[common_constants.METHOD_DESCRIPTOR_KEY] = method_descriptor # create proxy - return self._callable_factory.get_callable(self._invoker, url) + return self._callable_factory.get_callable(self._invoker, url) \ No newline at end of file diff --git a/src/dubbo/codec/__init__.py b/src/dubbo/codec/__init__.py new file mode 100644 index 0000000..dfd1b56 --- /dev/null +++ b/src/dubbo/codec/__init__.py @@ -0,0 +1,19 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from .dubbo_codec import DubboTransportService + +__all__ = ['DubboTransportService'] \ No newline at end of file diff --git a/src/dubbo/codec/dubbo_codec.py b/src/dubbo/codec/dubbo_codec.py new file mode 100644 index 0000000..b65cdda --- /dev/null +++ b/src/dubbo/codec/dubbo_codec.py @@ -0,0 +1,162 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
from typing import Any, Type, Optional, Callable, List
from dataclasses import dataclass
import inspect


@dataclass
class ParameterDescriptor:
    """Detailed information about a single method parameter."""

    name: str
    annotation: Any           # resolved type: explicit override, hint, or Any
    is_required: bool = True  # True when the parameter has no default
    default_value: Any = None


@dataclass
class MethodDescriptor:
    """Complete transport-level description of a callable method.

    NOTE(review): this intentionally shadows ``dubbo.classes.MethodDescriptor``
    (a different, serialization-oriented class); importing both in one module
    requires an alias. Consider renaming to ``TransportMethodDescriptor``.
    """

    function: Callable
    name: str
    parameters: List[ParameterDescriptor]
    return_parameter: ParameterDescriptor
    documentation: Optional[str] = None


class DubboTransportService:
    """Factory for transport codecs, (de)serializer pairs and method descriptors.

    Project-internal imports are performed lazily inside the factory methods so
    that importing this module does not drag in pydantic/orjson (needed only by
    the JSON codec) and so descriptor creation stays unit-testable in isolation.
    """

    @staticmethod
    def create_transport_codec(transport_type: str = 'json',
                               parameter_types: Optional[List[Type]] = None,
                               return_type: Optional[Type] = None,
                               **codec_options):
        """Create a combined codec for ``transport_type`` ('json' is built in;
        anything else is resolved through the extension loader)."""
        if transport_type == 'json':
            from dubbo.codec.json_codec import JsonTransportCodec
            return JsonTransportCodec(
                parameter_types=parameter_types,
                return_type=return_type,
                **codec_options,
            )
        from dubbo.classes import CodecHelper
        from dubbo.extension.extension_loader import ExtensionLoader
        codec_class = ExtensionLoader().get_extension(CodecHelper.get_class(), transport_type)
        return codec_class(
            parameter_types=parameter_types,
            return_type=return_type,
            **codec_options,
        )

    @staticmethod
    def create_encoder_decoder_pair(transport_type: str,
                                    parameter_types: Optional[List[Type]] = None,
                                    return_type: Optional[Type] = None,
                                    **codec_options) -> tuple[Any, Any]:
        """Create a (parameter encoder, return-value decoder) pair.

        BUGFIX: the return annotation previously used the *builtin* ``any``
        (``tuple[any,any]``) instead of ``typing.Any``.
        """
        if transport_type == 'json':
            from dubbo.codec.json_codec import JsonTransportEncoder, JsonTransportDecoder
            parameter_encoder = JsonTransportEncoder(parameter_types=parameter_types, **codec_options)
            return_decoder = JsonTransportDecoder(target_type=return_type, **codec_options)
            return parameter_encoder, return_decoder
        from dubbo.classes import CodecHelper
        from dubbo.extension.extension_loader import ExtensionLoader
        codec_class = ExtensionLoader().get_extension(CodecHelper.get_class(), transport_type)
        codec_instance = codec_class(
            parameter_types=parameter_types,
            return_type=return_type,
            **codec_options,
        )
        return codec_instance.get_encoder(), codec_instance.get_decoder()

    @staticmethod
    def create_serialization_functions(transport_type: str,
                                       parameter_types: Optional[List[Type]] = None,
                                       return_type: Optional[Type] = None,
                                       **codec_options) -> tuple[Callable, Callable]:
        """Create (serializer, deserializer) closures for RPC use.

        Kept for backward compatibility with callers expecting plain callables.
        (A leftover debug ``print`` was removed from the original.)
        """
        parameter_encoder, return_decoder = DubboTransportService.create_encoder_decoder_pair(
            transport_type=transport_type,
            parameter_types=parameter_types,
            return_type=return_type,
            **codec_options,
        )

        def serialize_method_parameters(*args) -> bytes:
            return parameter_encoder.encode(args)

        def deserialize_method_return(data: bytes):
            return return_decoder.decode(data)

        return serialize_method_parameters, deserialize_method_return

    @staticmethod
    def create_method_descriptor(func: Callable,
                                 method_name: Optional[str] = None,
                                 parameter_types: Optional[List[Type]] = None,
                                 return_type: Optional[Type] = None,
                                 interface: Optional[Callable] = None) -> MethodDescriptor:
        """Build a :class:`MethodDescriptor` from a function and overrides.

        Type resolution per parameter, in priority order: explicit entry in
        ``parameter_types`` (positional, ``self`` excluded) → signature
        annotation → ``Any``. The return type follows the same pattern.

        :param interface: when given, its signature/name are inspected instead
            of ``func``'s (``func`` is still recorded as the implementation).
        """
        target = interface if interface is not None else func
        name = method_name or target.__name__
        sig = inspect.signature(target)

        declared_types = list(parameter_types or [])
        parameters: List[ParameterDescriptor] = []
        position = 0  # index into declared_types, counting only real params
        for param_name, param in sig.parameters.items():
            if param_name == 'self':
                continue

            if position < len(declared_types):
                annotation = declared_types[position]
            elif param.annotation is not inspect.Parameter.empty:
                annotation = param.annotation
            else:
                annotation = Any

            # `is` comparison: Parameter.empty is a sentinel object.
            required = param.default is inspect.Parameter.empty
            parameters.append(ParameterDescriptor(
                name=param_name,
                annotation=annotation,
                is_required=required,
                default_value=None if required else param.default,
            ))
            position += 1

        if return_type is not None:
            resolved_return_type = return_type
        elif sig.return_annotation is not inspect.Signature.empty:
            resolved_return_type = sig.return_annotation
        else:
            resolved_return_type = Any

        return MethodDescriptor(
            function=func,
            name=name,
            parameters=parameters,
            return_parameter=ParameterDescriptor(name="return_value",
                                                 annotation=resolved_return_type),
            documentation=func.__doc__,
        )
from typing import Any, Type, List, Union, TypeVar, Protocol
from datetime import datetime, date, time
from decimal import Decimal
from pathlib import Path
from uuid import UUID
import json

from .json_type import (
    TypeProviderFactory, SerializationState,
    SerializationException, DeserializationException
)

try:
    import orjson
    HAS_ORJSON = True
except ImportError:
    HAS_ORJSON = False

try:
    import ujson
    HAS_UJSON = True
except ImportError:
    HAS_UJSON = False

try:
    from pydantic import BaseModel, create_model
    HAS_PYDANTIC = True
except ImportError:
    HAS_PYDANTIC = False


class EncodingFunction(Protocol):
    """Structural type of a parameter-encoding callable."""

    def __call__(self, obj: Any) -> bytes: ...


class DecodingFunction(Protocol):
    """Structural type of a return-value-decoding callable."""

    def __call__(self, data: bytes) -> Any: ...


# BUGFIX: the original bound ModelT to BaseModel unconditionally, raising
# NameError at import time whenever pydantic is not installed.
if HAS_PYDANTIC:
    ModelT = TypeVar('ModelT', bound=BaseModel)
else:
    ModelT = TypeVar('ModelT')


def _encode_special_object(obj: Any) -> Any:
    """Single source of truth for encoding non-JSON-native types.

    Emits self-describing ``{"__tag__": ...}`` dicts that
    ``JsonTransportDecoder._reconstruct_objects`` reverses. Previously this
    table was duplicated verbatim in three places (JSONEncoder.default, the
    orjson handler and the ujson handler).
    """
    if isinstance(obj, datetime):
        return {
            "__datetime__": obj.isoformat(),
            "__timezone__": str(obj.tzinfo) if obj.tzinfo else None,
        }
    if isinstance(obj, date):
        return {"__date__": obj.isoformat()}
    if isinstance(obj, time):
        return {"__time__": obj.isoformat()}
    if isinstance(obj, Decimal):
        return {"__decimal__": str(obj)}
    if isinstance(obj, (set, frozenset)):
        key = "__frozenset__" if isinstance(obj, frozenset) else "__set__"
        return {key: list(obj)}
    if isinstance(obj, UUID):
        return {"__uuid__": str(obj)}
    if isinstance(obj, Path):
        return {"__path__": str(obj)}
    # Last resort: stringify, recording the original type for diagnostics.
    return {"__fallback_string__": str(obj), "__original_type__": type(obj).__name__}


class CustomJSONEncoder(json.JSONEncoder):
    """stdlib-json fallback encoder delegating to the shared type table."""

    def default(self, obj):
        return _encode_special_object(obj)


class JsonTransportEncoder:
    """Encodes positional RPC arguments to JSON bytes.

    Modes (chosen from ``parameter_types`` length): single-parameter (payload
    is the value itself), multi-parameter (a pydantic wrapper model with
    ``parameter_i`` fields, when pydantic is available), or a plain list.
    Prefers orjson, then ujson, then stdlib json.
    """

    def __init__(self, parameter_types: List[Type] = None, maximum_depth: int = 100,
                 strict_validation: bool = True, **kwargs):
        self.parameter_types = parameter_types or []
        self.maximum_depth = maximum_depth
        self.strict_validation = strict_validation
        self.type_registry = TypeProviderFactory.create_default_registry()
        self.custom_encoder = CustomJSONEncoder(ensure_ascii=False, separators=(',', ':'))
        self.single_parameter_mode = len(self.parameter_types) == 1
        self.multiple_parameter_mode = len(self.parameter_types) > 1
        if self.multiple_parameter_mode and HAS_PYDANTIC:
            self.parameter_wrapper_model = self._create_parameter_wrapper_model()

    def _create_parameter_wrapper_model(self):
        """Build a pydantic model with one required field per parameter."""
        model_fields = {f"parameter_{i}": (t, ...) for i, t in enumerate(self.parameter_types)}
        return create_model('MethodParametersWrapper', **model_fields)

    def register_type_provider(self, provider) -> None:
        """Register an additional custom-type serialization provider."""
        self.type_registry.register_provider(provider)

    def encode(self, arguments: tuple) -> bytes:
        """Encode the argument tuple to JSON bytes.

        :raises SerializationException: wrapping any underlying failure.
        """
        try:
            if not arguments:
                return self._serialize_to_json_bytes([])

            if self.single_parameter_mode:
                parameter = arguments[0]
                # BUGFIX: the original ran the full provider-based
                # serialization before the BaseModel shortcut and discarded
                # the result; serialize only through the path actually used.
                if HAS_PYDANTIC and isinstance(parameter, BaseModel):
                    dump = (parameter.model_dump() if hasattr(parameter, 'model_dump')
                            else parameter.dict())  # pydantic v2 / v1
                    return self._serialize_to_json_bytes(dump)
                return self._serialize_to_json_bytes(self._serialize_with_state(parameter))

            if self.multiple_parameter_mode and HAS_PYDANTIC:
                wrapper_data = {f"parameter_{i}": arg for i, arg in enumerate(arguments)}
                wrapper_instance = self.parameter_wrapper_model(**wrapper_data)
                return self._serialize_to_json_bytes(wrapper_instance.model_dump())

            return self._serialize_to_json_bytes(
                [self._serialize_with_state(arg) for arg in arguments])
        except Exception as e:
            raise SerializationException(f"Encoding failed: {e}") from e

    def _serialize_with_state(self, obj: Any) -> Any:
        """Serialize with a fresh depth/circularity-tracking state."""
        state = SerializationState(maximum_depth=self.maximum_depth)
        return self._serialize_recursively(obj, state)

    def _serialize_recursively(self, obj: Any, state: SerializationState) -> Any:
        """Recursively convert ``obj`` into JSON-compatible structures."""
        if obj is None or isinstance(obj, (bool, int, float, str)):
            return obj
        if isinstance(obj, (list, tuple)):
            state.validate_circular_reference(obj)
            child_state = state.create_child_state(obj)
            return [self._serialize_recursively(item, child_state) for item in obj]
        if isinstance(obj, dict):
            state.validate_circular_reference(obj)
            child_state = state.create_child_state(obj)
            result = {}
            for key, value in obj.items():
                if not isinstance(key, str):
                    if self.strict_validation:
                        raise SerializationException(
                            f"Dictionary key must be string, got {type(key).__name__}")
                    key = str(key)  # lenient mode: coerce non-string keys
                result[key] = self._serialize_recursively(value, child_state)
            return result

        # Non-container, non-primitive: delegate to a registered provider.
        provider = self.type_registry.find_provider_for_object(obj)
        if provider is None:
            if self.strict_validation:
                raise SerializationException(f"No provider for type {type(obj).__name__}")
            return {"__fallback_string__": str(obj), "__original_type__": type(obj).__name__}
        try:
            serialized = provider.serialize_to_dict(obj, state)
        except Exception as e:
            if self.strict_validation:
                raise SerializationException(
                    f"Provider failed for {type(obj).__name__}: {e}") from e
            return {"__serialization_error__": str(e), "__original_type__": type(obj).__name__}
        return self._serialize_recursively(serialized, state)

    def _serialize_to_json_bytes(self, obj: Any) -> bytes:
        """Dump ``obj`` via the fastest available JSON backend."""
        if HAS_ORJSON:
            try:
                return orjson.dumps(obj, default=self._orjson_default_handler)
            except TypeError:
                pass  # fall through to the next backend
        if HAS_UJSON:
            try:
                return ujson.dumps(obj, ensure_ascii=False,
                                   default=self._ujson_default_handler).encode('utf-8')
            except (TypeError, ValueError):
                pass
            # stdlib json with the custom encoder always succeeds (has fallback)
        return self.custom_encoder.encode(obj).encode('utf-8')

    def _orjson_default_handler(self, obj):
        return _encode_special_object(obj)

    def _ujson_default_handler(self, obj):
        return _encode_special_object(obj)


class JsonTransportDecoder:
    """Decodes JSON bytes back into Python values.

    ``target_type`` may be a single type, a list of types (multi-parameter
    mode, reversed via a pydantic wrapper model), or ``None`` (return the
    reconstructed structure as-is).
    """

    def __init__(self, target_type: Union[Type, List[Type]] = None, **kwargs):
        self.target_type = target_type
        if isinstance(target_type, list):
            self.multiple_parameter_mode = len(target_type) > 1
            self.parameter_types = target_type
            if self.multiple_parameter_mode and HAS_PYDANTIC:
                self.parameter_wrapper_model = self._create_parameter_wrapper_model()
        else:
            self.multiple_parameter_mode = False
            self.parameter_types = [target_type] if target_type else []

    def _create_parameter_wrapper_model(self):
        """Mirror of the encoder's wrapper model: one field per parameter."""
        model_fields = {f"parameter_{i}": (t, ...) for i, t in enumerate(self.parameter_types)}
        return create_model('MethodParametersWrapper', **model_fields)

    def decode(self, data: bytes) -> Any:
        """Decode ``data``; empty input decodes to ``None``.

        :raises DeserializationException: wrapping any underlying failure.
        """
        try:
            if not data:
                return None
            payload = self._reconstruct_objects(self._deserialize_from_json_bytes(data))
            if not self.target_type:
                return payload
            if isinstance(self.target_type, list):
                if self.multiple_parameter_mode and HAS_PYDANTIC:
                    wrapper_instance = self.parameter_wrapper_model(**payload)
                    return tuple(getattr(wrapper_instance, f"parameter_{i}")
                                 for i in range(len(self.parameter_types)))
                return self._decode_to_target_type(payload, self.parameter_types[0])
            return self._decode_to_target_type(payload, self.target_type)
        except Exception as e:
            raise DeserializationException(f"Decoding failed: {e}") from e

    def _deserialize_from_json_bytes(self, data: bytes) -> Any:
        """Parse via the fastest available JSON backend."""
        if HAS_ORJSON:
            try:
                return orjson.loads(data)
            except orjson.JSONDecodeError:
                pass  # fall through; another backend may be more lenient
        if HAS_UJSON:
            try:
                return ujson.loads(data.decode('utf-8'))
            except (ujson.JSONDecodeError, UnicodeDecodeError):
                pass
        return json.loads(data.decode('utf-8'))

    def _decode_to_target_type(self, json_data: Any, target_type: Type) -> Any:
        """Coerce into a builtin target type; other types pass through."""
        if target_type in (str, int, float, bool, list, dict):
            return target_type(json_data)
        return json_data

    def _reconstruct_objects(self, data: Any) -> Any:
        """Reverse ``_encode_special_object``'s tagged-dict encoding, recursively."""
        if not isinstance(data, dict):
            if isinstance(data, list):
                return [self._reconstruct_objects(item) for item in data]
            return data
        if "__datetime__" in data:
            return datetime.fromisoformat(data["__datetime__"])
        if "__date__" in data:
            return date.fromisoformat(data["__date__"])
        if "__time__" in data:
            return time.fromisoformat(data["__time__"])
        if "__decimal__" in data:
            return Decimal(data["__decimal__"])
        if "__set__" in data:
            return set(self._reconstruct_objects(item) for item in data["__set__"])
        if "__frozenset__" in data:
            return frozenset(self._reconstruct_objects(item) for item in data["__frozenset__"])
        if "__uuid__" in data:
            return UUID(data["__uuid__"])
        if "__path__" in data:
            return Path(data["__path__"])
        if "__dataclass__" in data or "__pydantic_model__" in data:
            return data  # left tagged; higher layers rebuild these
        return {key: self._reconstruct_objects(value) for key, value in data.items()}


class JsonTransportCodec:
    """Facade pairing a :class:`JsonTransportEncoder` with a
    :class:`JsonTransportDecoder` for one method's signature."""

    def __init__(self, parameter_types: List[Type] = None, return_type: Type = None,
                 maximum_depth: int = 100, strict_validation: bool = True, **kwargs):
        self.parameter_types = parameter_types or []
        self.return_type = return_type
        self.maximum_depth = maximum_depth
        self.strict_validation = strict_validation
        self._encoder = JsonTransportEncoder(
            parameter_types=parameter_types,
            maximum_depth=maximum_depth,
            strict_validation=strict_validation,
            **kwargs
        )
        self._decoder = JsonTransportDecoder(target_type=return_type, **kwargs)

    def encode_parameters(self, *arguments) -> bytes:
        """Encode call arguments to bytes."""
        return self._encoder.encode(arguments)

    def decode_return_value(self, data: bytes) -> Any:
        """Decode a return value from bytes."""
        return self._decoder.decode(data)

    def get_encoder(self) -> JsonTransportEncoder:
        return self._encoder

    def get_decoder(self) -> JsonTransportDecoder:
        return self._decoder

    def register_type_provider(self, provider) -> None:
        """Forward a custom type provider to the encoder."""
        self._encoder.register_type_provider(provider)
+ +from abc import ABC, abstractmethod +from typing import ( + Any, + Type, + Optional, + List, + Dict, + Set, + Protocol, + runtime_checkable, + Union, +) +from dataclasses import dataclass, fields, is_dataclass, asdict +from datetime import datetime, date, time +from decimal import Decimal +from collections import namedtuple +from pathlib import Path +from uuid import UUID +from enum import Enum +import weakref + +try: + from pydantic import BaseModel + + HAS_PYDANTIC = True +except ImportError: + HAS_PYDANTIC = False + + +class SerializationException(Exception): + """Exception raised during serialization""" + pass + +class DeserializationException(Exception): + """Exception raised during deserialization""" + pass + +class CircularReferenceException(SerializationException): + """Exception raised when circular references are detected""" + pass + +@dataclass(frozen=True) +class SerializationState: + _visited_objects: Set[int] = None + maximum_depth: int = 100 + current_depth: int = 0 + + def __post_init__(self): + if self._visited_objects is None: + object.__setattr__(self, "_visited_objects", set()) + + def validate_circular_reference(self, obj: Any) -> None: + object_id = id(obj) + if object_id in self._visited_objects: + raise CircularReferenceException( + f"Circular reference detected for {type(obj).__name__}" + ) + if self.current_depth >= self.maximum_depth: + raise SerializationException( + f"Maximum serialization depth ({self.maximum_depth}) exceeded" + ) + + def create_child_state(self, obj: Any) -> "SerializationState": + new_visited = self._visited_objects.copy() + new_visited.add(id(obj)) + return SerializationState( + _visited_objects=new_visited, + maximum_depth=self.maximum_depth, + current_depth=self.current_depth + 1, + ) + + +@runtime_checkable +class TypeSerializationProvider(Protocol): + def can_serialize_type(self, obj: Any, obj_type: type) -> bool: ... + + def serialize_to_dict(self, obj: Any, state: SerializationState) -> Any: ... 
+ + +class TypeProviderRegistry: + def __init__(self): + self._type_cache: Dict[type, Optional[TypeSerializationProvider]] = {} + self._providers: List[TypeSerializationProvider] = [] + self._weak_cache = weakref.WeakKeyDictionary() + + def register_provider(self, provider: TypeSerializationProvider) -> None: + self._providers.append(provider) + self._type_cache.clear() + self._weak_cache.clear() + + def find_provider_for_object(self, obj: Any) -> Optional[TypeSerializationProvider]: + obj_type = type(obj) + if obj_type in self._type_cache: + return self._type_cache[obj_type] + provider = None + for p in self._providers: + if p.can_serialize_type(obj, obj_type): + provider = p + break + self._type_cache[obj_type] = provider + return provider + + +class DateTimeSerializationProvider: + def can_serialize_type(self, obj: Any, obj_type: type) -> bool: + return obj_type in (datetime, date, time) + + def serialize_to_dict( + self, obj: Union[datetime, date, time], state: SerializationState + ) -> Dict[str, str]: + if isinstance(obj, datetime): + return { + "__datetime__": obj.isoformat(), + "__timezone__": str(obj.tzinfo) if obj.tzinfo else None, + } + elif isinstance(obj, date): + return {"__date__": obj.isoformat()} + else: + return {"__time__": obj.isoformat()} + + +class DecimalSerializationProvider: + def can_serialize_type(self, obj: Any, obj_type: type) -> bool: + return obj_type is Decimal + + def serialize_to_dict( + self, obj: Decimal, state: SerializationState + ) -> Dict[str, str]: + return {"__decimal__": str(obj)} + + +class CollectionSerializationProvider: + def can_serialize_type(self, obj: Any, obj_type: type) -> bool: + return obj_type in (set, frozenset) + + def serialize_to_dict( + self, obj: Union[set, frozenset], state: SerializationState + ) -> Dict[str, Any]: + safe_items = [] + for item in obj: + if isinstance(item, (str, int, float, bool, type(None))): + safe_items.append(item) + else: + raise SerializationException( + f"Cannot serialize 
{type(item).__name__} in collection. " + f"Collections can only contain JSON-safe types (str, int, float, bool, None)" + ) + return { + "__frozenset__" if isinstance(obj, frozenset) else "__set__": safe_items + } + + +class DataclassSerializationProvider: + def can_serialize_type(self, obj: Any, obj_type: type) -> bool: + return is_dataclass(obj) and not isinstance(obj, type) + + def serialize_to_dict(self, obj: Any, state: SerializationState) -> Dict[str, Any]: + state.validate_circular_reference(obj) + try: + field_data = asdict(obj) + return { + "__dataclass__": f"{obj.__class__.__module__}.{obj.__class__.__qualname__}", + "__field_data__": field_data, + } + except (TypeError, RecursionError): + field_data = {} + for field in fields(obj): + try: + field_data[field.name] = getattr(obj, field.name) + except Exception as e: + raise SerializationException( + f"Cannot serialize field '{field.name}': {e}" + ) + return { + "__dataclass__": f"{obj.__class__.__module__}.{obj.__class__.__qualname__}", + "__field_data__": field_data, + } + + +class NamedTupleSerializationProvider: + def can_serialize_type(self, obj: Any, obj_type: type) -> bool: + return ( + hasattr(obj_type, "_fields") + and hasattr(obj, "_asdict") + and callable(obj._asdict) + ) + + def serialize_to_dict(self, obj: Any, state: SerializationState) -> Dict[str, Any]: + return { + "__namedtuple__": f"{obj.__class__.__module__}.{obj.__class__.__qualname__}", + "__tuple_data__": obj._asdict(), + } + + +class PydanticModelSerializationProvider: + def can_serialize_type(self, obj: Any, obj_type: type) -> bool: + return HAS_PYDANTIC and isinstance(obj, BaseModel) + + def serialize_to_dict( + self, obj: BaseModel, state: SerializationState + ) -> Dict[str, Any]: + state.validate_circular_reference(obj) + if hasattr(obj, "model_dump"): + model_data = obj.model_dump() + else: + model_data = obj.dict() + return { + "__pydantic_model__": f"{obj.__class__.__module__}.{obj.__class__.__qualname__}", + "__model_data__": 
model_data, + } + + +class SimpleTypeSerializationProvider: + def can_serialize_type(self, obj: Any, obj_type: type) -> bool: + return obj_type is UUID or isinstance(obj, (Path, Enum)) + + def serialize_to_dict( + self, obj: Union[UUID, Path, Enum], state: SerializationState + ) -> Dict[str, str]: + if isinstance(obj, UUID): + return {"__uuid__": str(obj)} + elif isinstance(obj, Path): + return {"__path__": str(obj)} + else: + return { + "__enum__": f"{obj.__class__.__module__}.{obj.__class__.__qualname__}", + "__enum_value__": obj.value, + } + + +class TypeProviderFactory: + @staticmethod + def create_default_registry() -> TypeProviderRegistry: + registry = TypeProviderRegistry() + default_providers = [ + DateTimeSerializationProvider(), + DecimalSerializationProvider(), + CollectionSerializationProvider(), + DataclassSerializationProvider(), + NamedTupleSerializationProvider(), + PydanticModelSerializationProvider(), + SimpleTypeSerializationProvider(), + ] + for provider in default_providers: + registry.register_provider(provider) + return registry + + @staticmethod + def create_minimal_registry() -> TypeProviderRegistry: + registry = TypeProviderRegistry() + essential_providers = [ + DateTimeSerializationProvider(), + DecimalSerializationProvider(), + SimpleTypeSerializationProvider(), + ] + for provider in essential_providers: + registry.register_provider(provider) + return registry diff --git a/src/dubbo/proxy/handlers.py b/src/dubbo/proxy/handlers.py index 8c89663..592d70f 100644 --- a/src/dubbo/proxy/handlers.py +++ b/src/dubbo/proxy/handlers.py @@ -14,9 +14,10 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-from typing import Callable, Optional - +import inspect +from typing import Callable, Optional, List, Type, Any, get_type_hints from dubbo.classes import MethodDescriptor +from dubbo.codec import DubboTransportService from dubbo.types import ( DeserializingFunction, RpcTypes, @@ -27,57 +28,96 @@ class RpcMethodHandler: - """ - Rpc method handler - """ - __slots__ = ["_method_descriptor"] def __init__(self, method_descriptor: MethodDescriptor): - """ - Initialize the RpcMethodHandler - :param method_descriptor: the method descriptor. - :type method_descriptor: MethodDescriptor - """ self._method_descriptor = method_descriptor @property def method_descriptor(self) -> MethodDescriptor: - """ - Get the method descriptor - :return: the method descriptor - :rtype: MethodDescriptor - """ return self._method_descriptor + @staticmethod + def get_codec(**kwargs) -> tuple: + return DubboTransportService.create_serialization_functions(**kwargs) + + @classmethod + def _infer_types_from_method(cls, method: Callable) -> tuple: + try: + type_hints = get_type_hints(method) + sig = inspect.signature(method) + method_name = method.__name__ + params = list(sig.parameters.values()) + if params and params[0].name == "self": + params = params[1:] + + params_types = [type_hints.get(p.name, Any) for p in params] + return_type = type_hints.get("return", Any) + return method_name, params_types, return_type + except Exception: + return method.__name__, [Any], Any + + @classmethod + def _create_method_descriptor( + cls, + method: Callable, + method_name: str, + params_types: List[Type], + return_type: Type, + rpc_type: str, + codec: Optional[str] = None, + param_encoder: Optional[DeserializingFunction] = None, + return_decoder: Optional[SerializingFunction] = None, + **kwargs, + ) -> MethodDescriptor: + if param_encoder is None or return_decoder is None: + codec_kwargs = { + "transport_type": codec or "json", + "parameter_types": params_types, + "return_type": return_type, + **kwargs, + } + 
serializer, deserializer = cls.get_codec(**codec_kwargs) + request_deserializer = param_encoder or deserializer + response_serializer = return_decoder or serializer + + return MethodDescriptor( + callable_method=method, + method_name=method_name or method.__name__, + arg_serialization=(None, request_deserializer), + return_serialization=(response_serializer, None), + rpc_type=rpc_type + ) + @classmethod def unary( cls, method: Callable, method_name: Optional[str] = None, + params_types: Optional[List[Type]] = None, + return_type: Optional[Type] = None, + codec: Optional[str] = None, request_deserializer: Optional[DeserializingFunction] = None, response_serializer: Optional[SerializingFunction] = None, + **kwargs, ) -> "RpcMethodHandler": - """ - Create a unary method handler - :param method: the method. - :type method: Callable - :param method_name: the method name. If not provided, the method name will be used. - :type method_name: Optional[str] - :param request_deserializer: the request deserializer. - :type request_deserializer: Optional[DeserializingFunction] - :param response_serializer: the response serializer. - :type response_serializer: Optional[SerializingFunction] - :return: the unary method handler. 
- :rtype: RpcMethodHandler - """ + inferred_name, inferred_param_types, inferred_return_type = cls._infer_types_from_method(method) + resolved_method_name = method_name or inferred_name + resolved_param_types = params_types or inferred_param_types + resolved_return_type = return_type or inferred_return_type + codec = codec or "json" + return cls( - MethodDescriptor( - callable_method=method, - method_name=method_name or method.__name__, - arg_serialization=(None, request_deserializer), - return_serialization=(response_serializer, None), + cls._create_method_descriptor( + method=method, + method_name=resolved_method_name, + params_types=resolved_param_types, + return_type=resolved_return_type, rpc_type=RpcTypes.UNARY.value, + codec=codec, + request_deserializer=request_deserializer, + response_serializer=response_serializer, + **kwargs, ) ) @@ -86,29 +126,30 @@ def client_stream( cls, method: Callable, method_name: Optional[str] = None, + params_types: Optional[List[Type]] = None, + return_type: Optional[Type] = None, + codec: Optional[str] = None, request_deserializer: Optional[DeserializingFunction] = None, response_serializer: Optional[SerializingFunction] = None, - ): - """ - Create a client stream method handler - :param method: the method. - :type method: Callable - :param method_name: the method name. If not provided, the method name will be used. - :type method_name: Optional[str] - :param request_deserializer: the request deserializer. - :type request_deserializer: Optional[DeserializingFunction] - :param response_serializer: the response serializer. - :type response_serializer: Optional[SerializingFunction] - :return: the client stream method handler. 
- :rtype: RpcMethodHandler - """ + **kwargs, + ) -> "RpcMethodHandler": + inferred_name, inferred_param_types, inferred_return_type = cls._infer_types_from_method(method) + resolved_method_name = method_name or inferred_name + resolved_param_types = params_types or inferred_param_types + resolved_return_type = return_type or inferred_return_type + resolved_codec = codec or "json" + return cls( - MethodDescriptor( - callable_method=method, - method_name=method_name or method.__name__, - arg_serialization=(None, request_deserializer), - return_serialization=(response_serializer, None), + cls._create_method_descriptor( + method=method, + method_name=resolved_method_name, + params_types=resolved_param_types, + return_type=resolved_return_type, rpc_type=RpcTypes.CLIENT_STREAM.value, + codec=resolved_codec, + request_deserializer=request_deserializer, + response_serializer=response_serializer, + **kwargs, ) ) @@ -117,29 +158,30 @@ def server_stream( cls, method: Callable, method_name: Optional[str] = None, + params_types: Optional[List[Type]] = None, + return_type: Optional[Type] = None, + codec: Optional[str] = None, request_deserializer: Optional[DeserializingFunction] = None, response_serializer: Optional[SerializingFunction] = None, - ): - """ - Create a server stream method handler - :param method: the method. - :type method: Callable - :param method_name: the method name. If not provided, the method name will be used. - :type method_name: Optional[str] - :param request_deserializer: the request deserializer. - :type request_deserializer: Optional[DeserializingFunction] - :param response_serializer: the response serializer. - :type response_serializer: Optional[SerializingFunction] - :return: the server stream method handler. 
- :rtype: RpcMethodHandler - """ + **kwargs, + ) -> "RpcMethodHandler": + inferred_name, inferred_param_types, inferred_return_type = cls._infer_types_from_method(method) + resolved_method_name = method_name or inferred_name + resolved_param_types = params_types or inferred_param_types + resolved_return_type = return_type or inferred_return_type + resolved_codec = codec or "json" + return cls( - MethodDescriptor( - callable_method=method, - method_name=method_name or method.__name__, - arg_serialization=(None, request_deserializer), - return_serialization=(response_serializer, None), + cls._create_method_descriptor( + method=method, + method_name=resolved_method_name, + params_types=resolved_param_types, + return_type=resolved_return_type, rpc_type=RpcTypes.SERVER_STREAM.value, + codec=resolved_codec, + request_deserializer=request_deserializer, + response_serializer=response_serializer, + **kwargs, ) ) @@ -148,48 +190,38 @@ def bi_stream( cls, method: Callable, method_name: Optional[str] = None, + params_types: Optional[List[Type]] = None, + return_type: Optional[Type] = None, + codec: Optional[str] = None, request_deserializer: Optional[DeserializingFunction] = None, response_serializer: Optional[SerializingFunction] = None, - ): - """ - Create a bidi stream method handler - :param method: the method. - :type method: Callable - :param method_name: the method name. If not provided, the method name will be used. - :type method_name: Optional[str] - :param request_deserializer: the request deserializer. - :type request_deserializer: Optional[DeserializingFunction] - :param response_serializer: the response serializer. - :type response_serializer: Optional[SerializingFunction] - :return: the bidi stream method handler. 
- :rtype: RpcMethodHandler - """ + **kwargs, + ) -> "RpcMethodHandler": + inferred_name, inferred_param_types, inferred_return_type = cls._infer_types_from_method(method) + resolved_method_name = method_name or inferred_name + resolved_param_types = params_types or inferred_param_types + resolved_return_type = return_type or inferred_return_type + resolved_codec = codec or "json" + return cls( - MethodDescriptor( - callable_method=method, - method_name=method_name or method.__name__, - arg_serialization=(None, request_deserializer), - return_serialization=(response_serializer, None), + cls._create_method_descriptor( + method=method, + method_name=resolved_method_name, + params_types=resolved_param_types, + return_type=resolved_return_type, rpc_type=RpcTypes.BI_STREAM.value, + codec=resolved_codec, + request_deserializer=request_deserializer, + response_serializer=response_serializer, + **kwargs, ) ) class RpcServiceHandler: - """ - Rpc service handler - """ - __slots__ = ["_service_name", "_method_handlers"] def __init__(self, service_name: str, method_handlers: list[RpcMethodHandler]): - """ - Initialize the RpcServiceHandler - :param service_name: the name of the service. - :type service_name: str - :param method_handlers: the method handlers. 
- :type method_handlers: List[RpcMethodHandler] - """ self._service_name = service_name self._method_handlers: dict[str, RpcMethodHandler] = {} @@ -199,18 +231,8 @@ def __init__(self, service_name: str, method_handlers: list[RpcMethodHandler]): @property def service_name(self) -> str: - """ - Get the service name - :return: the service name - :rtype: str - """ return self._service_name @property def method_handlers(self) -> dict[str, RpcMethodHandler]: - """ - Get the method handlers - :return: the method handlers - :rtype: Dict[str, RpcMethodHandler] - """ - return self._method_handlers + return self._method_handlers \ No newline at end of file From a785c345d942a997fe368f929a2c7a11f78b4adf Mon Sep 17 00:00:00 2001 From: Aditya Yadav <166515021+aditya0yadav@users.noreply.github.com> Date: Mon, 14 Jul 2025 18:46:10 +0000 Subject: [PATCH 02/40] fix the duplicate code issue in client.py --- src/dubbo/client.py | 214 +++++++++++++------------------------------- 1 file changed, 60 insertions(+), 154 deletions(-) diff --git a/src/dubbo/client.py b/src/dubbo/client.py index 89b85c1..4ad48b9 100644 --- a/src/dubbo/client.py +++ b/src/dubbo/client.py @@ -84,8 +84,9 @@ def _initialize(self): self._initialized = True - def unary( + def _create_rpc_callable( self, + rpc_type: str, interface: Optional[Callable] = None, method_name: Optional[str] = None, params_types: Optional[List[Type]] = None, @@ -93,19 +94,17 @@ def unary( codec: Optional[str] = None, request_serializer: Optional[SerializingFunction] = None, response_deserializer: Optional[DeserializingFunction] = None, + default_method_name: str = "rpc_call", ) -> RpcCallable: """ - Create unary RPC call. - - Supports both automatic mode (via interface) and manual mode (via method_name + params_types + return_type + codec). + Create RPC callable with the specified type. 
""" - # Validate if interface is None and method_name is None: raise ValueError("Either 'interface' or 'method_name' must be provided") # Determine the actual method name to call - actual_method_name = method_name or (interface.__name__ if interface else "unary") + actual_method_name = method_name or (interface.__name__ if interface else default_method_name) # Build method descriptor (automatic or manual) if interface: @@ -140,18 +139,17 @@ def dummy(): pass ) # Create the proper MethodDescriptor for the RPC call - # This should match the structure expected by your RpcCallableFactory rpc_method_descriptor = MethodDescriptor( method_name=actual_method_name, arg_serialization=(final_request_serializer, None), # (serializer, deserializer) for arguments return_serialization=(None, final_response_deserializer), # (serializer, deserializer) for return value - rpc_type=RpcTypes.UNARY.value, + rpc_type=rpc_type, ) # Create and return the RpcCallable return self._callable(rpc_method_descriptor) - def client_stream( + def unary( self, interface: Optional[Callable] = None, method_name: Optional[str] = None, @@ -162,62 +160,49 @@ def client_stream( response_deserializer: Optional[DeserializingFunction] = None, ) -> RpcCallable: """ - Create client streaming RPC call. + Create unary RPC call. Supports both automatic mode (via interface) and manual mode (via method_name + params_types + return_type + codec). 
""" + return self._create_rpc_callable( + rpc_type=RpcTypes.UNARY.value, + interface=interface, + method_name=method_name, + params_types=params_types, + return_type=return_type, + codec=codec, + request_serializer=request_serializer, + response_deserializer=response_deserializer, + default_method_name="unary", + ) - # Validate - if interface is None and method_name is None: - raise ValueError("Either 'interface' or 'method_name' must be provided") - - # Determine the actual method name to call - actual_method_name = method_name or (interface.__name__ if interface else "client_stream") - - # Build method descriptor (automatic or manual) - if interface: - method_desc = DubboTransportService.create_method_descriptor( - func=interface, - method_name=actual_method_name, - parameter_types=params_types, - return_type=return_type, - interface=interface, - ) - else: - # Manual mode fallback: use dummy function for descriptor creation - def dummy(): pass - - method_desc = DubboTransportService.create_method_descriptor( - func=dummy, - method_name=actual_method_name, - parameter_types=params_types or [], - return_type=return_type or Any, - ) - - # Determine serializers if not provided - if request_serializer and response_deserializer: - final_request_serializer = request_serializer - final_response_deserializer = response_deserializer - else: - # Use DubboTransportService to generate serialization functions - final_request_serializer, final_response_deserializer = DubboTransportService.create_serialization_functions( - transport_type=codec or "json", - parameter_types=[p.annotation for p in method_desc.parameters], - return_type=method_desc.return_parameter.annotation, - ) + def client_stream( + self, + interface: Optional[Callable] = None, + method_name: Optional[str] = None, + params_types: Optional[List[Type]] = None, + return_type: Optional[Type] = None, + codec: Optional[str] = None, + request_serializer: Optional[SerializingFunction] = None, + response_deserializer: 
Optional[DeserializingFunction] = None, + ) -> RpcCallable: + """ + Create client streaming RPC call. - # Create the proper MethodDescriptor for the RPC call - # This should match the structure expected by your RpcCallableFactory - rpc_method_descriptor = MethodDescriptor( - method_name=actual_method_name, - arg_serialization=(final_request_serializer, None), # (serializer, deserializer) for arguments - return_serialization=(None, final_response_deserializer), # (serializer, deserializer) for return value + Supports both automatic mode (via interface) and manual mode (via method_name + params_types + return_type + codec). + """ + return self._create_rpc_callable( rpc_type=RpcTypes.CLIENT_STREAM.value, + interface=interface, + method_name=method_name, + params_types=params_types, + return_type=return_type, + codec=codec, + request_serializer=request_serializer, + response_deserializer=response_deserializer, + default_method_name="client_stream", ) - # Create and return the RpcCallable - return self._callable(rpc_method_descriptor) - def server_stream( self, interface: Optional[Callable] = None, @@ -233,58 +218,18 @@ def server_stream( Supports both automatic mode (via interface) and manual mode (via method_name + params_types + return_type + codec). 
""" - - # Validate - if interface is None and method_name is None: - raise ValueError("Either 'interface' or 'method_name' must be provided") - - # Determine the actual method name to call - actual_method_name = method_name or (interface.__name__ if interface else "server_stream") - - # Build method descriptor (automatic or manual) - if interface: - method_desc = DubboTransportService.create_method_descriptor( - func=interface, - method_name=actual_method_name, - parameter_types=params_types, - return_type=return_type, - interface=interface, - ) - else: - # Manual mode fallback: use dummy function for descriptor creation - def dummy(): pass - - method_desc = DubboTransportService.create_method_descriptor( - func=dummy, - method_name=actual_method_name, - parameter_types=params_types or [], - return_type=return_type or Any, - ) - - # Determine serializers if not provided - if request_serializer and response_deserializer: - final_request_serializer = request_serializer - final_response_deserializer = response_deserializer - else: - # Use DubboTransportService to generate serialization functions - final_request_serializer, final_response_deserializer = DubboTransportService.create_serialization_functions( - transport_type=codec or "json", - parameter_types=[p.annotation for p in method_desc.parameters], - return_type=method_desc.return_parameter.annotation, - ) - - # Create the proper MethodDescriptor for the RPC call - # This should match the structure expected by your RpcCallableFactory - rpc_method_descriptor = MethodDescriptor( - method_name=actual_method_name, - arg_serialization=(final_request_serializer, None), # (serializer, deserializer) for arguments - return_serialization=(None, final_response_deserializer), # (serializer, deserializer) for return value + return self._create_rpc_callable( rpc_type=RpcTypes.SERVER_STREAM.value, + interface=interface, + method_name=method_name, + params_types=params_types, + return_type=return_type, + codec=codec, + 
request_serializer=request_serializer, + response_deserializer=response_deserializer, + default_method_name="server_stream", ) - # Create and return the RpcCallable - return self._callable(rpc_method_descriptor) - def bi_stream( self, interface: Optional[Callable] = None, @@ -300,57 +245,18 @@ def bi_stream( Supports both automatic mode (via interface) and manual mode (via method_name + params_types + return_type + codec). """ - - # Validate - if interface is None and method_name is None: - raise ValueError("Either 'interface' or 'method_name' must be provided") - - # Determine the actual method name to call - actual_method_name = method_name or (interface.__name__ if interface else "bi_stream") - - # Build method descriptor (automatic or manual) - if interface: - method_desc = DubboTransportService.create_method_descriptor( - func=interface, - method_name=actual_method_name, - parameter_types=params_types, - return_type=return_type, - interface=interface, - ) - else: - # Manual mode fallback: use dummy function for descriptor creation - def dummy(): pass - - method_desc = DubboTransportService.create_method_descriptor( - func=dummy, - method_name=actual_method_name, - parameter_types=params_types or [], - return_type=return_type or Any, - ) - - # Determine serializers if not provided - if request_serializer and response_deserializer: - final_request_serializer = request_serializer - final_response_deserializer = response_deserializer - else: - # Use DubboTransportService to generate serialization functions - final_request_serializer, final_response_deserializer = DubboTransportService.create_serialization_functions( - transport_type=codec or "json", - parameter_types=[p.annotation for p in method_desc.parameters], - return_type=method_desc.return_parameter.annotation, - ) - - - rpc_method_descriptor = MethodDescriptor( - method_name=actual_method_name, - arg_serialization=(final_request_serializer, None), - return_serialization=(None, final_response_deserializer), 
+ return self._create_rpc_callable( rpc_type=RpcTypes.BI_STREAM.value, + interface=interface, + method_name=method_name, + params_types=params_types, + return_type=return_type, + codec=codec, + request_serializer=request_serializer, + response_deserializer=response_deserializer, + default_method_name="bi_stream", ) - # Create and return the RpcCallable - return self._callable(rpc_method_descriptor) - def _callable(self, method_descriptor: MethodDescriptor) -> RpcCallable: """ Generate a proxy for the given method. From 95321d77bada6b9aa427bdd997f7926b0b19e12d Mon Sep 17 00:00:00 2001 From: Aditya Yadav <166515021+aditya0yadav@users.noreply.github.com> Date: Thu, 17 Jul 2025 21:09:30 +0000 Subject: [PATCH 03/40] fixed with changes --- src/dubbo/proxy/handlers.py | 147 +++++++++++++++++++++++++++++++----- src/dubbo/server.py | 2 +- 2 files changed, 129 insertions(+), 20 deletions(-) diff --git a/src/dubbo/proxy/handlers.py b/src/dubbo/proxy/handlers.py index 592d70f..8539d63 100644 --- a/src/dubbo/proxy/handlers.py +++ b/src/dubbo/proxy/handlers.py @@ -1,19 +1,3 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one or more -# contributor license agreements. See the NOTICE file distributed with -# this work for additional information regarding copyright ownership. -# The ASF licenses this file to You under the Apache License, Version 2.0 -# (the "License"); you may not use this file except in compliance with -# the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- import inspect from typing import Callable, Optional, List, Type, Any, get_type_hints from dubbo.classes import MethodDescriptor @@ -27,33 +11,75 @@ __all__ = ["RpcMethodHandler", "RpcServiceHandler"] +class RpcMethodConfigurationError(Exception): + """ + Raised when RPC method is configured incorrectly. + """ + pass + + class RpcMethodHandler: + """ + Rpc method handler that wraps metadata and serialization logic for a callable. + """ + __slots__ = ["_method_descriptor"] def __init__(self, method_descriptor: MethodDescriptor): + """ + Initialize the RpcMethodHandler + :param method_descriptor: the method descriptor. + :type method_descriptor: MethodDescriptor + """ self._method_descriptor = method_descriptor @property def method_descriptor(self) -> MethodDescriptor: + """ + Get the method descriptor + :return: the method descriptor + :rtype: MethodDescriptor + """ return self._method_descriptor @staticmethod def get_codec(**kwargs) -> tuple: + """ + Get the serialization and deserialization functions based on codec + :param kwargs: codec settings like transport_type, parameter_types, return_type + :return: serializer and deserializer functions + :rtype: Tuple[SerializingFunction, DeserializingFunction] + """ return DubboTransportService.create_serialization_functions(**kwargs) @classmethod def _infer_types_from_method(cls, method: Callable) -> tuple: + """ + Infer method name, parameter types, and return type from a callable + :param method: the method to analyze + :type method: Callable + :return: tuple of method name, parameter types, return type + :rtype: Tuple[str, List[Type], Type] + """ try: type_hints = get_type_hints(method) sig = inspect.signature(method) method_name = method.__name__ params = list(sig.parameters.values()) + if params and params[0].name == "self": - params = params[1:] + raise RpcMethodConfigurationError( + f"Method '{method_name}' appears to be an unbound method with 'self' parameter. 
" + "RPC methods should be bound methods (e.g., instance.method) or standalone functions. " + "If you're registering a class method, ensure you pass a bound method: " + "RpcMethodHandler.unary(instance.method) not RpcMethodHandler.unary(Class.method)" + ) params_types = [type_hints.get(p.name, Any) for p in params] return_type = type_hints.get("return", Any) return method_name, params_types, return_type + except RpcMethodConfigurationError: + raise except Exception: return method.__name__, [Any], Any @@ -70,6 +96,20 @@ def _create_method_descriptor( return_decoder: Optional[SerializingFunction] = None, **kwargs, ) -> MethodDescriptor: + """ + Create a MethodDescriptor with serialization configuration + :param method: the actual function/method + :param method_name: RPC method name + :param params_types: parameter type hints + :param return_type: return type hint + :param rpc_type: type of RPC (unary, stream, etc.) + :param codec: serialization codec (json, pb, etc.) + :param param_encoder: deserialization function + :param return_decoder: serialization function + :param kwargs: additional codec args + :return: MethodDescriptor instance + :rtype: MethodDescriptor + """ if param_encoder is None or return_decoder is None: codec_kwargs = { "transport_type": codec or "json", @@ -101,6 +141,18 @@ def unary( response_serializer: Optional[SerializingFunction] = None, **kwargs, ) -> "RpcMethodHandler": + """ + Register a unary RPC method handler + :param method: the callable function + :param method_name: RPC method name + :param params_types: input types + :param return_type: output type + :param codec: serialization codec + :param request_deserializer: custom deserializer + :param response_serializer: custom serializer + :return: RpcMethodHandler instance + :rtype: RpcMethodHandler + """ inferred_name, inferred_param_types, inferred_return_type = cls._infer_types_from_method(method) resolved_method_name = method_name or inferred_name resolved_param_types = params_types or 
inferred_param_types @@ -133,6 +185,18 @@ def client_stream( response_serializer: Optional[SerializingFunction] = None, **kwargs, ) -> "RpcMethodHandler": + """ + Register a client-streaming RPC method handler + :param method: the callable function + :param method_name: RPC method name + :param params_types: input types + :param return_type: output type + :param codec: serialization codec + :param request_deserializer: custom deserializer + :param response_serializer: custom serializer + :return: RpcMethodHandler instance + :rtype: RpcMethodHandler + """ inferred_name, inferred_param_types, inferred_return_type = cls._infer_types_from_method(method) resolved_method_name = method_name or inferred_name resolved_param_types = params_types or inferred_param_types @@ -165,6 +229,18 @@ def server_stream( response_serializer: Optional[SerializingFunction] = None, **kwargs, ) -> "RpcMethodHandler": + """ + Register a server-streaming RPC method handler + :param method: the callable function + :param method_name: RPC method name + :param params_types: input types + :param return_type: output type + :param codec: serialization codec + :param request_deserializer: custom deserializer + :param response_serializer: custom serializer + :return: RpcMethodHandler instance + :rtype: RpcMethodHandler + """ inferred_name, inferred_param_types, inferred_return_type = cls._infer_types_from_method(method) resolved_method_name = method_name or inferred_name resolved_param_types = params_types or inferred_param_types @@ -197,6 +273,18 @@ def bi_stream( response_serializer: Optional[SerializingFunction] = None, **kwargs, ) -> "RpcMethodHandler": + """ + Register a bidirectional streaming RPC method handler + :param method: the callable function + :param method_name: RPC method name + :param params_types: input types + :param return_type: output type + :param codec: serialization codec + :param request_deserializer: custom deserializer + :param response_serializer: custom serializer + 
:return: RpcMethodHandler instance + :rtype: RpcMethodHandler + """ inferred_name, inferred_param_types, inferred_return_type = cls._infer_types_from_method(method) resolved_method_name = method_name or inferred_name resolved_param_types = params_types or inferred_param_types @@ -219,9 +307,20 @@ def bi_stream( class RpcServiceHandler: + """ + Rpc service handler that maps method names to their corresponding RpcMethodHandler. + """ + __slots__ = ["_service_name", "_method_handlers"] - def __init__(self, service_name: str, method_handlers: list[RpcMethodHandler]): + def __init__(self, service_name: str, method_handlers: List[RpcMethodHandler]): + """ + Initialize the RpcServiceHandler + :param service_name: the name of the service. + :type service_name: str + :param method_handlers: list of RpcMethodHandler instances + :type method_handlers: List[RpcMethodHandler] + """ self._service_name = service_name self._method_handlers: dict[str, RpcMethodHandler] = {} @@ -231,8 +330,18 @@ def __init__(self, service_name: str, method_handlers: list[RpcMethodHandler]): @property def service_name(self) -> str: + """ + Get the service name + :return: the service name + :rtype: str + """ return self._service_name @property def method_handlers(self) -> dict[str, RpcMethodHandler]: - return self._method_handlers \ No newline at end of file + """ + Get the registered RPC method handlers + :return: mapping of method names to handlers + :rtype: Dict[str, RpcMethodHandler] + """ + return self._method_handlers diff --git a/src/dubbo/server.py b/src/dubbo/server.py index b7c4dee..52a3507 100644 --- a/src/dubbo/server.py +++ b/src/dubbo/server.py @@ -81,4 +81,4 @@ def start(self): self._protocol.export(self._url) - self._exported = True + self._exported = True \ No newline at end of file From 109a1e4c41b9018ee7bbe3236d5a8206371852bb Mon Sep 17 00:00:00 2001 From: Aditya Yadav <166515021+aditya0yadav@users.noreply.github.com> Date: Thu, 17 Jul 2025 21:13:55 +0000 Subject: [PATCH 04/40] 
class Codec(ABC):
    """Abstract base class for transport codecs.

    A codec turns Python values into wire bytes and back. Concrete
    implementations (e.g. JSON, protobuf) are looked up through the
    extension registry by name.

    :param model_type: optional type the codec should target when
        encoding/decoding (e.g. a message or model class).
    :param kwargs: extra codec-specific options.
        NOTE(review): kwargs are accepted for subclass flexibility but are
        silently discarded here — subclasses must consume what they need.
    """

    def __init__(self, model_type: Optional[Type[Any]] = None, **kwargs):
        self.model_type = model_type

    @abstractmethod
    def encode(self, data: Any) -> bytes:
        """Serialize *data* to bytes."""
        pass

    @abstractmethod
    def decode(self, data: bytes) -> Any:
        """Deserialize *data* from bytes back into a Python value."""
        pass
return_type=method_desc.return_parameter.annotation, - ) + final_request_serializer, final_response_deserializer = DubboTransportService.create_serialization_functions(codec, parameter_types=params_types, return_type=return_type) + print("final",codec, final_request_serializer, final_response_deserializer) # Create the proper MethodDescriptor for the RPC call rpc_method_descriptor = MethodDescriptor( diff --git a/src/dubbo/codec/dubbo_codec.py b/src/dubbo/codec/dubbo_codec.py index b65cdda..640d397 100644 --- a/src/dubbo/codec/dubbo_codec.py +++ b/src/dubbo/codec/dubbo_codec.py @@ -82,6 +82,7 @@ def create_encoder_decoder_pair(transport_type: str, parameter_types: List[Type] return_type=return_type, **codec_options ) + print("codec_instance", codec_instance.get_encoder(), codec_instance.get_decoder()) return codec_instance.get_encoder(), codec_instance.get_decoder() diff --git a/src/dubbo/codec/protobuf_codec/__init__.py b/src/dubbo/codec/protobuf_codec/__init__.py new file mode 100644 index 0000000..4ddd635 --- /dev/null +++ b/src/dubbo/codec/protobuf_codec/__init__.py @@ -0,0 +1,21 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from .protobuf_codec_handler import ProtobufTransportCodec, ProtobufTransportEncoder, ProtobufTransportDecoder + +__all__ = [ + "ProtobufTransportCodec", "ProtobufTransportEncoder", "ProtobufTransportDecoder" +] \ No newline at end of file diff --git a/src/dubbo/codec/protobuf_codec/protobuf_codec_handler.py b/src/dubbo/codec/protobuf_codec/protobuf_codec_handler.py new file mode 100644 index 0000000..8c75062 --- /dev/null +++ b/src/dubbo/codec/protobuf_codec/protobuf_codec_handler.py @@ -0,0 +1,305 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Any, Type, Protocol, Optional +from abc import ABC, abstractmethod +import json +from dataclasses import dataclass + +# Betterproto imports +try: + import betterproto + HAS_BETTERPROTO = True +except ImportError: + HAS_BETTERPROTO = False + +try: + from pydantic import BaseModel + HAS_PYDANTIC = True +except ImportError: + HAS_PYDANTIC = False + +# Reuse your existing JSON type system +from dubbo.codec.json_codec.json_type import ( + TypeProviderFactory, SerializationState, + SerializationException, DeserializationException +) + + +class ProtobufEncodingFunction(Protocol): + def __call__(self, obj: Any) -> bytes: ... 
class ProtobufDecodingFunction(Protocol):
    """Callable that turns wire bytes back into a Python value."""

    def __call__(self, data: bytes) -> Any: ...


@dataclass
class ProtobufMethodDescriptor:
    """Protobuf-specific descriptor for a single-parameter RPC method.

    :param parameter_type: declared type of the single request parameter.
    :param return_type: declared return type.
    :param protobuf_message_type: concrete betterproto message class, if known.
    :param use_json_fallback: True when the parameter type cannot be carried
        natively by protobuf and the JSON codec must be used instead.
    """

    parameter_type: Type
    return_type: Type
    protobuf_message_type: Optional[Type] = None
    use_json_fallback: bool = False


class ProtobufTypeHandler:
    """Classifies Python types for the protobuf transport.

    Decides whether a type is a betterproto message, a protobuf-friendly
    primitive, or something that must fall back to the JSON codec.
    """

    @staticmethod
    def is_betterproto_message(obj_type: Type) -> bool:
        """Return True if *obj_type* is a betterproto message class."""
        if not HAS_BETTERPROTO:
            return False
        try:
            # betterproto messages are dataclasses subclassing betterproto.Message
            return hasattr(obj_type, "__dataclass_fields__") and issubclass(obj_type, betterproto.Message)
        except (TypeError, AttributeError):
            # issubclass raises TypeError when obj_type is not a class
            return False

    @staticmethod
    def is_betterproto_message_instance(obj: Any) -> bool:
        """Return True if *obj* is an instance of a betterproto message."""
        if not HAS_BETTERPROTO:
            return False
        try:
            return isinstance(obj, betterproto.Message)
        except Exception:
            # was a bare `except:` — never swallow SystemExit/KeyboardInterrupt
            return False

    @staticmethod
    def is_protobuf_compatible(obj_type: Type) -> bool:
        """Return True if *obj_type* can be carried natively by protobuf."""
        return obj_type in (str, int, float, bool, bytes) or ProtobufTypeHandler.is_betterproto_message(obj_type)

    @staticmethod
    def needs_json_fallback(parameter_type: Type) -> bool:
        """Return True when *parameter_type* must be serialized via JSON."""
        return not ProtobufTypeHandler.is_protobuf_compatible(parameter_type)
class ProtobufTransportEncoder:
    """Encodes a single RPC parameter to bytes using betterproto.

    Falls back to the JSON codec for parameter types that protobuf cannot
    represent natively (anything other than primitives and betterproto
    messages).

    :param parameter_type: declared type of the single parameter (optional).
    :raises ImportError: if betterproto is not installed.
    """

    def __init__(self, parameter_type: Optional[Type] = None, **kwargs):
        if not HAS_BETTERPROTO:
            raise ImportError("betterproto library is required for ProtobufTransportEncoder")

        self.parameter_type = parameter_type
        self.descriptor = ProtobufMethodDescriptor(
            parameter_type=parameter_type,
            return_type=None,
            use_json_fallback=ProtobufTypeHandler.needs_json_fallback(parameter_type) if parameter_type else False,
        )

        if self.descriptor.use_json_fallback:
            # Local import avoids a hard import-time dependency on the JSON codec module.
            from dubbo.codec.json_codec.json_codec_handler import JsonTransportEncoder

            self.json_fallback_encoder = JsonTransportEncoder([parameter_type], **kwargs)

    def encode(self, parameter: Any) -> bytes:
        """Encode a single parameter (or a 1-tuple wrapping it) to bytes.

        :param parameter: the value to encode; None yields an empty payload.
        :raises SerializationException: if the value cannot be encoded or
            more than one positional argument was supplied.
        """
        try:
            if parameter is None:
                return b""

            # RPC call sites often pass arguments as a tuple; unwrap a 1-tuple.
            if isinstance(parameter, tuple):
                if not parameter:
                    return b""
                if len(parameter) == 1:
                    return self._encode_single_parameter(parameter[0])
                raise SerializationException(
                    f"Multiple parameters not supported. Got tuple with {len(parameter)} elements, expected 1."
                )

            return self._encode_single_parameter(parameter)
        except SerializationException:
            # Already descriptive — re-wrapping would bury the message.
            raise
        except Exception as e:
            raise SerializationException(f"Protobuf encoding failed: {e}") from e

    def _encode_single_parameter(self, parameter: Any) -> bytes:
        """Encode one value, choosing message / primitive / JSON-fallback paths."""
        # Already a betterproto message instance: serialize directly.
        if ProtobufTypeHandler.is_betterproto_message_instance(parameter):
            return bytes(parameter)

        # Declared type is a betterproto message: accept instances or dicts.
        if self.parameter_type and ProtobufTypeHandler.is_betterproto_message(self.parameter_type):
            if isinstance(parameter, self.parameter_type):
                return bytes(parameter)
            if isinstance(parameter, dict):
                try:
                    return bytes(self.parameter_type().from_dict(parameter))
                except Exception as e:
                    raise SerializationException(f"Cannot convert dict to {self.parameter_type}: {e}") from e
            raise SerializationException(f"Cannot convert {type(parameter)} to {self.parameter_type}")

        if isinstance(parameter, (str, int, float, bool, bytes)):
            return self._encode_primitive(parameter)

        if self.descriptor.use_json_fallback:
            return self.json_fallback_encoder.encode((parameter,))

        raise SerializationException(f"Cannot encode {type(parameter)} as protobuf")

    def _encode_primitive(self, value: Any) -> bytes:
        """Encode a primitive by wrapping it in a small JSON envelope.

        NOTE(review): this is not protobuf wire format — a wrapper protobuf
        message would be needed for true protobuf primitives. The decoder's
        `_decode_primitive` mirrors this envelope.
        """
        try:
            return json.dumps({"value": value, "type": type(value).__name__}).encode("utf-8")
        except Exception as e:
            raise SerializationException(f"Failed to encode primitive {value}: {e}") from e
class ProtobufTransportDecoder:
    """Decodes bytes into a single return value using betterproto.

    Mirrors ProtobufTransportEncoder: betterproto messages are parsed
    natively, primitives are unwrapped from the JSON envelope, and
    unsupported types go through the JSON fallback codec.

    :param target_type: expected Python type of the decoded value (optional).
    :raises ImportError: if betterproto is not installed.
    """

    def __init__(self, target_type: Optional[Type] = None, **kwargs):
        if not HAS_BETTERPROTO:
            raise ImportError("betterproto library is required for ProtobufTransportDecoder")

        self.target_type = target_type
        self.use_json_fallback = ProtobufTypeHandler.needs_json_fallback(target_type) if target_type else False

        if self.use_json_fallback:
            # Local import avoids a hard import-time dependency on the JSON codec module.
            from dubbo.codec.json_codec.json_codec_handler import JsonTransportDecoder

            self.json_fallback_decoder = JsonTransportDecoder(target_type, **kwargs)

    def decode(self, data: bytes) -> Any:
        """Decode *data*; an empty payload yields None.

        :raises DeserializationException: on any decoding failure.
        """
        try:
            if not data:
                return None
            if not self.target_type:
                return self._decode_without_type_info(data)
            return self._decode_single_parameter(data, self.target_type)
        except DeserializationException:
            # Already descriptive — re-wrapping would bury the message.
            raise
        except Exception as e:
            raise DeserializationException(f"Protobuf decoding failed: {e}") from e

    def _decode_single_parameter(self, data: bytes, target_type: Type) -> Any:
        """Decode according to the declared *target_type*."""
        if ProtobufTypeHandler.is_betterproto_message(target_type):
            try:
                return target_type().parse(data)
            except Exception as e:
                if self.use_json_fallback:
                    return self.json_fallback_decoder.decode(data)
                raise DeserializationException(f"Failed to parse betterproto message: {e}") from e

        if target_type in (str, int, float, bool, bytes):
            return self._decode_primitive(data, target_type)

        if self.use_json_fallback:
            return self.json_fallback_decoder.decode(data)

        raise DeserializationException(f"Cannot decode to {target_type} from protobuf")

    def _decode_primitive(self, data: bytes, target_type: Type) -> Any:
        """Unwrap a primitive from the JSON envelope written by the encoder."""
        try:
            parsed = json.loads(data.decode("utf-8"))
            value = parsed.get("value")

            # `is` comparisons: these are identity checks against the builtin types.
            if target_type is str:
                return str(value)
            if target_type is int:
                return int(value)
            if target_type is float:
                return float(value)
            if target_type is bool:
                return bool(value)
            if target_type is bytes:
                return bytes(value) if isinstance(value, (list, bytes)) else str(value).encode()
            return value
        except Exception as e:
            raise DeserializationException(f"Failed to decode primitive: {e}") from e

    def _decode_without_type_info(self, data: bytes) -> Any:
        """Best-effort decode with no type hint: try JSON, else return raw bytes."""
        try:
            return json.loads(data.decode("utf-8"))
        except ValueError:
            # was a bare `except:` — ValueError covers both UnicodeDecodeError
            # and json.JSONDecodeError without hiding unrelated failures
            return data
class ProtobufTransportCodec:
    """Facade pairing a ProtobufTransportEncoder and ProtobufTransportDecoder
    for a single-parameter RPC method.

    :param parameter_type: declared request parameter type.
    :param return_type: declared return type.
    :param kwargs: passed through to both the encoder and the decoder.
    :raises ImportError: if betterproto is not installed.
    """

    def __init__(self, parameter_type: Optional[Type] = None, return_type: Optional[Type] = None, **kwargs):
        if not HAS_BETTERPROTO:
            raise ImportError("betterproto library is required for ProtobufTransportCodec")

        self.parameter_type = parameter_type
        self.return_type = return_type
        self._encoder = ProtobufTransportEncoder(parameter_type=parameter_type, **kwargs)
        self._decoder = ProtobufTransportDecoder(target_type=return_type, **kwargs)

    def encode_parameter(self, argument: Any) -> bytes:
        """Encode the single request parameter."""
        return self._encoder.encode(argument)

    def encode_parameters(self, arguments: tuple) -> bytes:
        """Legacy entry point taking a tuple of arguments (must hold 0 or 1).

        :raises SerializationException: if more than one argument is given.
        """
        if not arguments:
            return b""
        if len(arguments) == 1:
            return self._encoder.encode(arguments[0])
        raise SerializationException(
            f"Multiple parameters not supported. Got {len(arguments)} arguments, expected 1."
        )

    def decode_return_value(self, data: bytes) -> Any:
        """Decode the response payload into the declared return type."""
        return self._decoder.decode(data)

    def get_encoder(self) -> "ProtobufTransportEncoder":
        """Return the underlying encoder."""
        return self._encoder

    def get_decoder(self) -> "ProtobufTransportDecoder":
        """Return the underlying decoder."""
        return self._decoder


def create_protobuf_codec(**kwargs) -> ProtobufTransportCodec:
    """Factory for ProtobufTransportCodec; keyword arguments pass through."""
    return ProtobufTransportCodec(**kwargs)
+ """ + + def __init__(self): + """ + Initialize the extension loader. + + Load all the registries from the registries module. + """ + if not hasattr(self, "_initialized"): # Ensure __init__ runs only once + self._registries = {} + for name in registries_module.registries: + registry = getattr(registries_module, name) + self._registries[registry.interface] = registry.impls + self._initialized = True + + def get_extension(self, interface: Any, impl_name: str) -> Any: + """ + Get the extension implementation for the interface. + + :param interface: Interface class. + :type interface: Any + :param impl_name: Implementation name. + :type impl_name: str + :return: Extension implementation class. + :rtype: Any + :raises ExtensionError: If the interface or implementation is not found. + """ + # Get the registry for the interface + impls = self._registries.get(interface) + print("value is ", impls, interface) + if not impls: + raise ExtensionError(f"Interface '{interface.__name__}' is not supported.") + + # Get the full name of the implementation + full_name = impls.get(impl_name) + if not full_name: + raise ExtensionError(f"Implementation '{impl_name}' for interface '{interface.__name__}' is not exist.") + + try: + # Split the full name into module and class + module_name, class_name = full_name.rsplit(".", 1) + + # Load the module and get the class + module = importlib.import_module(module_name) + subclass = getattr(module, class_name) + + # Return the subclass + return subclass + except Exception as e: + raise ExtensionError( + f"Failed to load extension '{impl_name}' for interface '{interface.__name__}'. 
\nDetail: {e}" + ) @dataclass @@ -39,7 +119,7 @@ class ExtendedRegistry: impls: dict[str, Any] -# All Extension Registries +# All Extension Registries - FIXED: Added codecRegistry to the list registries = [ "registryFactoryRegistry", "loadBalanceRegistry", @@ -47,6 +127,7 @@ class ExtendedRegistry: "compressorRegistry", "decompressorRegistry", "transporterRegistry", + "codecRegistry", ] # RegistryFactory registry @@ -84,7 +165,6 @@ class ExtendedRegistry: }, ) - # Decompressor registry decompressorRegistry = ExtendedRegistry( interface=Decompressor, @@ -95,7 +175,6 @@ class ExtendedRegistry: }, ) - # Transporter registry transporterRegistry = ExtendedRegistry( interface=Transporter, @@ -103,3 +182,12 @@ class ExtendedRegistry: "aio": "dubbo.remoting.aio.aio_transporter.AioTransporter", }, ) + +# Codec Registry +codecRegistry = ExtendedRegistry( + interface=Codec, + impls={ + "json": "dubbo.codec.json_codec.JsonTransportCodec", + "protobuf": "dubbo.codec.protobuf_codec.ProtobufTransportCodec", + }, +) \ No newline at end of file From 9cd492e97ee0e10845bbfcad2e5c0223f5f3aeeb Mon Sep 17 00:00:00 2001 From: Aditya Yadav <166515021+aditya0yadav@users.noreply.github.com> Date: Sat, 23 Aug 2025 20:15:00 +0000 Subject: [PATCH 06/40] remove some unneccary debug logic --- src/dubbo/codec/dubbo_codec.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/src/dubbo/codec/dubbo_codec.py b/src/dubbo/codec/dubbo_codec.py index 640d397..8758684 100644 --- a/src/dubbo/codec/dubbo_codec.py +++ b/src/dubbo/codec/dubbo_codec.py @@ -82,8 +82,6 @@ def create_encoder_decoder_pair(transport_type: str, parameter_types: List[Type] return_type=return_type, **codec_options ) - print("codec_instance", codec_instance.get_encoder(), codec_instance.get_decoder()) - return codec_instance.get_encoder(), codec_instance.get_decoder() @staticmethod @@ -104,7 +102,6 @@ def serialize_method_parameters(*args) -> bytes: def deserialize_method_return(data: bytes): return return_decoder.decode(data) - 
print(type(serialize_method_parameters),type(deserialize_method_return)) return serialize_method_parameters, deserialize_method_return @staticmethod From 6e13e9f65d07bf910545bb892e21c266a6577dcd Mon Sep 17 00:00:00 2001 From: Aditya Yadav <166515021+aditya0yadav@users.noreply.github.com> Date: Sun, 24 Aug 2025 18:13:13 +0000 Subject: [PATCH 07/40] add the license template in the handlers.py --- src/dubbo/proxy/handlers.py | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/src/dubbo/proxy/handlers.py b/src/dubbo/proxy/handlers.py index 8539d63..aa5004f 100644 --- a/src/dubbo/proxy/handlers.py +++ b/src/dubbo/proxy/handlers.py @@ -1,3 +1,19 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
def test_json_single_parameter_roundtrip():
    """A single int value survives an encode/decode round trip."""
    codec = JsonTransportCodec(parameter_types=[int], return_type=int)

    payload = codec.encode_parameters(42)
    assert isinstance(payload, bytes)

    assert codec.decode_return_value(payload) == 42


def test_json_multiple_parameters_roundtrip():
    """Two positional parameters encode, and a simulated reply decodes."""
    codec = JsonTransportCodec(parameter_types=[str, int], return_type=str)

    payload = codec.encode_parameters("hello", 123)
    assert isinstance(payload, bytes)

    # Simulate the server side encoding its string return value.
    reply = codec.get_encoder().encode(("world",))
    assert codec.decode_return_value(reply) == "world"


def test_json_complex_types():
    """datetime, Decimal, and set values keep their types through the codec."""
    codec = JsonTransportCodec(parameter_types=[dict], return_type=dict)

    original = {
        "name": "Alice",
        "when": datetime(2025, 8, 27, 12, 30),
        "price": Decimal("19.99"),
        "ids": {uuid4(), uuid4()},
    }

    payload = codec.encode_parameters(original)
    assert isinstance(payload, bytes)

    restored = codec.decode_return_value(payload)
    assert isinstance(restored, dict)
    assert restored["name"] == "Alice"
    assert isinstance(restored["price"], Decimal)
    assert isinstance(restored["when"], datetime)
    assert isinstance(restored["ids"], set)
def test_protobuf_roundtrip_message():
    """A betterproto request encodes; a reply parses back to the right type."""
    codec = ProtobufTransportCodec(parameter_type=GreeterRequest, return_type=GreeterReply)

    # Encode a request message instance.
    encoded = codec.encode_parameter(GreeterRequest(name="Alice"))
    assert isinstance(encoded, bytes)

    # Fake a server reply on the wire.
    reply_bytes = bytes(GreeterReply(message="Hello Alice"))

    decoded = codec.decode_return_value(reply_bytes)
    assert isinstance(decoded, GreeterReply)
    assert decoded.message == "Hello Alice"


def test_protobuf_from_dict():
    """A plain dict is converted to the declared request message type."""
    from dubbo.codec.protobuf_codec import ProtobufTransportDecoder

    codec = ProtobufTransportCodec(parameter_type=GreeterRequest, return_type=GreeterReply)

    # Dict instead of a message instance.
    encoded = codec.encode_parameter({"name": "Bob"})
    assert isinstance(encoded, bytes)

    # BUG FIX: the codec's own decoder targets GreeterReply, so decoding the
    # request bytes with it would yield the wrong message type. Decode the
    # request with a decoder aimed at GreeterRequest instead.
    req = ProtobufTransportDecoder(target_type=GreeterRequest).decode(encoded)
    assert isinstance(req, GreeterRequest)
    assert req.name == "Bob"


def test_protobuf_primitive_fallback():
    """Primitive str parameters round-trip through the JSON envelope path."""
    codec = ProtobufTransportCodec(parameter_type=str, return_type=str)

    encoded = codec.encode_parameter("simple string")
    assert isinstance(encoded, bytes)

    decoded = codec.decode_return_value(encoded)
    assert isinstance(decoded, str)
    assert decoded == "simple string"
a/tests/json/json_test.py +++ b/tests/json/json_test.py @@ -1,3 +1,19 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + import pytest from datetime import datetime from decimal import Decimal diff --git a/tests/protobuf/generated/__init__.py b/tests/protobuf/generated/__init__.py index e69de29..4f1421a 100644 --- a/tests/protobuf/generated/__init__.py +++ b/tests/protobuf/generated/__init__.py @@ -0,0 +1,15 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
\ No newline at end of file diff --git a/tests/protobuf/generated/protobuf_test.py b/tests/protobuf/generated/protobuf_test.py index d8aa649..38bd7f6 100644 --- a/tests/protobuf/generated/protobuf_test.py +++ b/tests/protobuf/generated/protobuf_test.py @@ -1,3 +1,19 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + # Generated by the protocol buffer compiler. DO NOT EDIT! # sources: greet.proto # plugin: python-betterproto diff --git a/tests/protobuf/greet.proto b/tests/protobuf/greet.proto index 5b453a7..9c16bbc 100644 --- a/tests/protobuf/greet.proto +++ b/tests/protobuf/greet.proto @@ -1,3 +1,19 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + syntax = "proto3"; package protobuf_test; diff --git a/tests/protobuf/protobuf_test.py b/tests/protobuf/protobuf_test.py index 86677ee..a65492d 100644 --- a/tests/protobuf/protobuf_test.py +++ b/tests/protobuf/protobuf_test.py @@ -1,3 +1,19 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ import pytest from dubbo.codec.protobuf_codec import ProtobufTransportCodec from generated.protobuf_test import GreeterReply, GreeterRequest From 46e660c34b9f604c8d01ce0a3967dc1ae2f8c52f Mon Sep 17 00:00:00 2001 From: Aditya Yadav <166515021+aditya0yadav@users.noreply.github.com> Date: Sun, 31 Aug 2025 22:32:39 +0000 Subject: [PATCH 10/40] fixed some naming convenion , clean the code and fix some bug --- src/dubbo/client.py | 224 ++++++++++++----------------------- tests/json/json_type_test.py | 87 ++++++++++++++ 2 files changed, 165 insertions(+), 146 deletions(-) create mode 100644 tests/json/json_type_test.py diff --git a/src/dubbo/client.py b/src/dubbo/client.py index 7dc215a..f79d4af 100644 --- a/src/dubbo/client.py +++ b/src/dubbo/client.py @@ -8,14 +8,15 @@ # # http://www.apache.org/licenses/LICENSE-2.0 # -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
import threading -from typing import Optional, Callable, List, Type, Union, Any +import inspect +from typing import Optional, Callable, List, Type, Any, get_type_hints from dubbo.bootstrap import Dubbo from dubbo.classes import MethodDescriptor @@ -63,11 +64,17 @@ def _initialize(self): return # get the protocol - protocol = extensionLoader.get_extension(Protocol, self._reference.protocol)() + protocol = extensionLoader.get_extension( + Protocol, self._reference.protocol + )() registry_config = self._dubbo.registry_config - self._protocol = RegistryProtocol(registry_config, protocol) if self._dubbo.registry_config else protocol + self._protocol = ( + RegistryProtocol(registry_config, protocol) + if registry_config + else protocol + ) # build url reference_url = self._reference.to_url() @@ -84,6 +91,28 @@ def _initialize(self): self._initialized = True + @classmethod + def _infer_types_from_interface(cls, interface: Callable) -> tuple: + """ + Infer method name, parameter types, and return type from a callable. + """ + try: + type_hints = get_type_hints(interface) + sig = inspect.signature(interface) + method_name = interface.__name__ + params = list(sig.parameters.values()) + + # skip 'self' for bound methods + if params and params[0].name == "self": + params = params[1:] + + param_types = [type_hints.get(p.name, Any) for p in params] + return_type = type_hints.get("return", Any) + + return method_name, param_types, return_type + except Exception: + return interface.__name__, [Any], Any + def _create_rpc_callable( self, rpc_type: str, @@ -99,176 +128,79 @@ def _create_rpc_callable( """ Create RPC callable with the specified type. 
""" - # Validate if interface is None and method_name is None: raise ValueError("Either 'interface' or 'method_name' must be provided") - # Determine the actual method name to call - actual_method_name = method_name or (interface.__name__ if interface else default_method_name) - - # Build method descriptor (automatic or manual) + # Start with explicit values + m_name = method_name + p_types = params_types + r_type = return_type + + # Infer from interface if needed if interface: - method_desc = DubboTransportService.create_method_descriptor( - func=interface, - method_name=actual_method_name, - parameter_types=params_types, - return_type=return_type, - interface=interface, - ) + if p_types is None or r_type is None or m_name is None: + inf_name, inf_params, inf_return = self._infer_types_from_interface( + interface + ) + m_name = m_name or inf_name + p_types = p_types or inf_params + r_type = r_type or inf_return + + # Fallback to default + m_name = m_name or default_method_name + + # Determine serializers + if request_serializer and response_deserializer: + req_ser = request_serializer + res_deser = response_deserializer else: - # Manual mode fallback: use dummy function for descriptor creation - def dummy(): pass - - method_desc = DubboTransportService.create_method_descriptor( - func=dummy, - method_name=actual_method_name, - parameter_types=params_types or [], - return_type=return_type or Any, + req_ser, res_deser = DubboTransportService.create_serialization_functions( + codec or "json", # fallback to json + parameter_types=p_types, + return_type=r_type, ) - # Determine serializers if not provided - if request_serializer and response_deserializer: - final_request_serializer = request_serializer - final_response_deserializer = response_deserializer - else: - # Use DubboTransportService to generate serialization functions - final_request_serializer, final_response_deserializer = DubboTransportService.create_serialization_functions(codec, 
parameter_types=params_types, return_type=return_type) - print("final",codec, final_request_serializer, final_response_deserializer) - - # Create the proper MethodDescriptor for the RPC call - rpc_method_descriptor = MethodDescriptor( - method_name=actual_method_name, - arg_serialization=(final_request_serializer, None), # (serializer, deserializer) for arguments - return_serialization=(None, final_response_deserializer), # (serializer, deserializer) for return value + # Create MethodDescriptor + descriptor = MethodDescriptor( + method_name=m_name, + arg_serialization=(req_ser, None), + return_serialization=(None, res_deser), rpc_type=rpc_type, ) - # Create and return the RpcCallable - return self._callable(rpc_method_descriptor) - - def unary( - self, - interface: Optional[Callable] = None, - method_name: Optional[str] = None, - params_types: Optional[List[Type]] = None, - return_type: Optional[Type] = None, - codec: Optional[str] = None, - request_serializer: Optional[SerializingFunction] = None, - response_deserializer: Optional[DeserializingFunction] = None, - ) -> RpcCallable: - """ - Create unary RPC call. + return self._callable(descriptor) - Supports both automatic mode (via interface) and manual mode (via method_name + params_types + return_type + codec). 
- """ + def unary(self, **kwargs) -> RpcCallable: return self._create_rpc_callable( - rpc_type=RpcTypes.UNARY.value, - interface=interface, - method_name=method_name, - params_types=params_types, - return_type=return_type, - codec=codec, - request_serializer=request_serializer, - response_deserializer=response_deserializer, - default_method_name="unary", + rpc_type=RpcTypes.UNARY.value, default_method_name="unary", **kwargs ) - def client_stream( - self, - interface: Optional[Callable] = None, - method_name: Optional[str] = None, - params_types: Optional[List[Type]] = None, - return_type: Optional[Type] = None, - codec: Optional[str] = None, - request_serializer: Optional[SerializingFunction] = None, - response_deserializer: Optional[DeserializingFunction] = None, - ) -> RpcCallable: - """ - Create client streaming RPC call. - - Supports both automatic mode (via interface) and manual mode (via method_name + params_types + return_type + codec). - """ + def client_stream(self, **kwargs) -> RpcCallable: return self._create_rpc_callable( rpc_type=RpcTypes.CLIENT_STREAM.value, - interface=interface, - method_name=method_name, - params_types=params_types, - return_type=return_type, - codec=codec, - request_serializer=request_serializer, - response_deserializer=response_deserializer, default_method_name="client_stream", + **kwargs, ) - def server_stream( - self, - interface: Optional[Callable] = None, - method_name: Optional[str] = None, - params_types: Optional[List[Type]] = None, - return_type: Optional[Type] = None, - codec: Optional[str] = None, - request_serializer: Optional[SerializingFunction] = None, - response_deserializer: Optional[DeserializingFunction] = None, - ) -> RpcCallable: - """ - Create server streaming RPC call. - - Supports both automatic mode (via interface) and manual mode (via method_name + params_types + return_type + codec). 
- """ + def server_stream(self, **kwargs) -> RpcCallable: return self._create_rpc_callable( rpc_type=RpcTypes.SERVER_STREAM.value, - interface=interface, - method_name=method_name, - params_types=params_types, - return_type=return_type, - codec=codec, - request_serializer=request_serializer, - response_deserializer=response_deserializer, default_method_name="server_stream", + **kwargs, ) - def bi_stream( - self, - interface: Optional[Callable] = None, - method_name: Optional[str] = None, - params_types: Optional[List[Type]] = None, - return_type: Optional[Type] = None, - codec: Optional[str] = None, - request_serializer: Optional[SerializingFunction] = None, - response_deserializer: Optional[DeserializingFunction] = None, - ) -> RpcCallable: - """ - Create bidirectional streaming RPC call. - - Supports both automatic mode (via interface) and manual mode (via method_name + params_types + return_type + codec). - """ + def bi_stream(self, **kwargs) -> RpcCallable: return self._create_rpc_callable( - rpc_type=RpcTypes.BI_STREAM.value, - interface=interface, - method_name=method_name, - params_types=params_types, - return_type=return_type, - codec=codec, - request_serializer=request_serializer, - response_deserializer=response_deserializer, - default_method_name="bi_stream", + rpc_type=RpcTypes.BI_STREAM.value, default_method_name="bi_stream", **kwargs ) def _callable(self, method_descriptor: MethodDescriptor) -> RpcCallable: """ Generate a proxy for the given method. - :param method_descriptor: The method descriptor. - :return: The proxy. 
- :rtype: RpcCallable """ - # get invoker - url = self._invoker.get_url() - - # clone url - url = url.copy() - url.parameters[common_constants.METHOD_KEY] = method_descriptor.get_method_name() - # set method descriptor + url = self._invoker.get_url().copy() + url.parameters[common_constants.METHOD_KEY] = ( + method_descriptor.get_method_name() + ) url.attributes[common_constants.METHOD_DESCRIPTOR_KEY] = method_descriptor - - # create proxy - return self._callable_factory.get_callable(self._invoker, url) \ No newline at end of file + return self._callable_factory.get_callable(self._invoker, url) diff --git a/tests/json/json_type_test.py b/tests/json/json_type_test.py new file mode 100644 index 0000000..bb7300a --- /dev/null +++ b/tests/json/json_type_test.py @@ -0,0 +1,87 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import pytest +from pathlib import Path +from uuid import UUID +from decimal import Decimal +from datetime import datetime, date, time +from dataclasses import dataclass +from enum import Enum +from pydantic import BaseModel + +from dubbo.codec.json_codec.json_codec_handler import JsonTransportCodec + +# Optional dataclass and enum examples +@dataclass +class SampleDataClass: + field1: int + field2: str + +class Color(Enum): + RED = "red" + GREEN = "green" + +class SamplePydanticModel(BaseModel): + name: str + value: int + +# List of test cases: (input_value, expected_type_after_decoding) +test_cases = [ + ("simple string", str), + (12345, int), + (12.34, float), + (True, bool), + (datetime(2025, 8, 27, 13, 0, 0), datetime), + (date(2025, 8, 27), date), + (time(13, 0, 0), time), + (Decimal("123.45"), Decimal), + (set([1, 2, 3]), set), + (frozenset(["a", "b"]), frozenset), + (UUID("12345678-1234-5678-1234-567812345678"), UUID), + (Path("/tmp/file.txt"), Path), + (SampleDataClass(1, "abc"), SampleDataClass), + (Color.RED, Color), + (SamplePydanticModel(name="test", value=42), SamplePydanticModel) +] + +@pytest.mark.parametrize("value,expected_type", test_cases) +def test_json_codec_roundtrip(value, expected_type): + codec = JsonTransportCodec(parameter_types=[type(value)], return_type=type(value)) + + # Encode + encoded = codec.encode_parameters(value) + assert isinstance(encoded, bytes) + + # Decode + decoded = codec.decode_return_value(encoded) + + # For pydantic models, compare dict representation + if hasattr(value, "dict") and callable(value.dict): + assert decoded.dict() == value.dict() + # For dataclass, compare asdict + elif hasattr(value, "__dataclass_fields__"): + from dataclasses import asdict + assert asdict(decoded) == asdict(value) + # For sets/frozensets, compare as sets + elif isinstance(value, (set, frozenset)): + assert decoded == value + # For enum + elif isinstance(value, Enum): + assert decoded.value == value.value + else: + assert decoded == 
value + From 1ff70836b4f452575c81864a5fe2f4d8a07d10b5 Mon Sep 17 00:00:00 2001 From: Aditya Yadav <166515021+aditya0yadav@users.noreply.github.com> Date: Mon, 1 Sep 2025 17:45:57 +0000 Subject: [PATCH 11/40] completed the hessian with pyhessian and manual serialization --- .../hessian_codec/hessian_codec_handler.py | 185 ++++++++++++++++++ .../codec/hessian_codec/manual_hessian.py | 120 ++++++++++++ 2 files changed, 305 insertions(+) create mode 100644 src/dubbo/codec/hessian_codec/hessian_codec_handler.py create mode 100644 src/dubbo/codec/hessian_codec/manual_hessian.py diff --git a/src/dubbo/codec/hessian_codec/hessian_codec_handler.py b/src/dubbo/codec/hessian_codec/hessian_codec_handler.py new file mode 100644 index 0000000..239cc64 --- /dev/null +++ b/src/dubbo/codec/hessian_codec/hessian_codec_handler.py @@ -0,0 +1,185 @@ +import io +import struct +import logging +from typing import Any, List, Dict, Optional, Union, Type +from abc import ABC + +try: + from pyhessian import Hessian2Input, Hessian2Output + + _HAS_PYHESSIAN = True +except ImportError: + _HAS_PYHESSIAN = False + from .manual_hessian import Hessian2Input, Hessian2Output + + + +class HessianTypeError(Exception): + """Exception raised for type validation errors in Hessian RPC.""" + pass + + +class HessianRpcError(Exception): + """Base exception for Hessian RPC errors.""" + pass + + +class TypeValidator: + """Type validation utilities for Hessian RPC.""" + + @staticmethod + def validate_type(value: Any, expected_type: Union[str, Type]) -> bool: + """Validate if a value matches the expected type.""" + if isinstance(expected_type, str): + return TypeValidator._validate_string_type(value, expected_type) + else: + return isinstance(value, expected_type) + + @staticmethod + def _validate_string_type(value: Any, type_string: str) -> bool: + """Validate value against string type specification.""" + type_mapping = { + 'str': str, 'int': int, 'float': float, 'bool': bool, + 'list': list, 'dict': dict, 
'bytes': bytes, + 'none': type(None), 'any': object + } + + if type_string.lower() in type_mapping: + expected_type = type_mapping[type_string.lower()] + return isinstance(value, expected_type) + + if '[' in type_string and ']' in type_string: + base_type = type_string.split('[')[0].lower() + if base_type == 'list': + return isinstance(value, list) + elif base_type == 'dict': + return isinstance(value, dict) + + return True + + @staticmethod + def validate_parameters(args: tuple, param_types: List[Union[str, Type]]) -> None: + """Validate method parameters against expected types.""" + if len(args) != len(param_types): + raise HessianTypeError( + f"Parameter count mismatch: expected {len(param_types)}, got {len(args)}" + ) + + for i, (arg, expected_type) in enumerate(zip(args, param_types)): + if not TypeValidator.validate_type(arg, expected_type): + raise HessianTypeError( + f"Parameter {i} type mismatch: expected {expected_type}, got {type(arg).__name__}" + ) + + +class TypeProvider: + """Provides type information and serialization hints for custom objects.""" + + def __init__(self): + self._type_registry: Dict[str, Type] = {} + + def register_type(self, type_name: str, type_class: Type) -> None: + """Register a custom type for serialization.""" + self._type_registry[type_name] = type_class + + def get_type(self, type_name: str) -> Optional[Type]: + """Get a registered type by name.""" + return self._type_registry.get(type_name) + + + +class HessianTransportEncoder: + """Encodes Python objects into Hessian binary format with type validation.""" + + def __init__(self, parameter_types: Optional[List[Union[str, Type]]] = None): + self.parameter_types = parameter_types or [] + self._type_provider: Optional[TypeProvider] = None + self._logger = logging.getLogger(__name__) + + def encode(self, arguments: tuple) -> bytes: + if self.parameter_types: + TypeValidator.validate_parameters(arguments, self.parameter_types) + + try: + output_stream = io.BytesIO() + 
hessian_output = Hessian2Output(output_stream) + + for arg in arguments: + hessian_output.write_object(arg) + + return output_stream.getvalue() + except Exception as e: + self._logger.error(f"Encoding error: {e}") + raise HessianRpcError(f"Failed to encode parameters: {e}") + + def register_type_provider(self, provider: TypeProvider) -> None: + """Register a type provider for custom type handling.""" + self._type_provider = provider + + +class HessianTransportDecoder: + """Decodes Hessian binary format into Python objects with type validation.""" + + def __init__(self, target_type: Optional[Union[str, Type]] = None): + self.target_type = target_type + self._type_provider: Optional[TypeProvider] = None + self._logger = logging.getLogger(__name__) + + def decode(self, data: bytes) -> Any: + try: + input_stream = io.BytesIO(data) + hessian_input = Hessian2Input(input_stream) + result = hessian_input.read_object() + + if self.target_type: + if not TypeValidator.validate_type(result, self.target_type): + raise HessianTypeError( + f"Return type mismatch: expected {self.target_type}, got {type(result).__name__}" + ) + + return result + except Exception as e: + self._logger.error(f"Decoding error: {e}") + raise HessianRpcError(f"Failed to decode data: {e}") + + def register_type_provider(self, provider: TypeProvider) -> None: + """Register a type provider for custom type handling.""" + self._type_provider = provider + + +class HessianTransportCodec(ABC): + """High-level encoder/decoder wrapper for Hessian RPC with enhanced features.""" + + def __init__(self, parameter_types: Optional[List[Union[str, Type]]] = None, + return_type: Optional[Union[str, Type]] = None): + self.parameter_types = parameter_types or [] + self.return_type = return_type + self._encoder = HessianTransportEncoder(parameter_types) + self._decoder = HessianTransportDecoder(return_type) + self._logger = logging.getLogger(__name__) + + def encode_parameters(self, *arguments) -> bytes: + """Encode method 
parameters to Hessian binary format.""" + return self._encoder.encode(arguments) + + def decode_return_value(self, data: bytes) -> Any: + """Decode return value from Hessian binary format.""" + return self._decoder.decode(data) + + def validate_call(self, *arguments) -> None: + """Validate method call parameters without encoding.""" + if self.parameter_types: + TypeValidator.validate_parameters(arguments, self.parameter_types) + + def register_type_provider(self, provider: TypeProvider) -> None: + """Register a type provider for both encoder and decoder.""" + self._encoder.register_type_provider(provider) + self._decoder.register_type_provider(provider) + + def get_encoder(self) -> HessianTransportEncoder: + """Get the encoder instance.""" + return self._encoder + + def get_decoder(self) -> HessianTransportDecoder: + """Get the decoder instance.""" + return self._decoder diff --git a/src/dubbo/codec/hessian_codec/manual_hessian.py b/src/dubbo/codec/hessian_codec/manual_hessian.py new file mode 100644 index 0000000..7bf6487 --- /dev/null +++ b/src/dubbo/codec/hessian_codec/manual_hessian.py @@ -0,0 +1,120 @@ +import io +import struct +from typing import Any + +class Hessian2Output: + def __init__(self, stream: io.BytesIO): + self.stream = stream + + def write_object(self, obj: Any): + if obj is None: + self.stream.write(b'N') # null + elif isinstance(obj, bool): + self.stream.write(b'T' if obj else b'F') + elif isinstance(obj, int): + if -16 <= obj <= 47: + self.stream.write(bytes([0x90 + obj])) + elif -2048 <= obj <= 2047: + self.stream.write(bytes([0xc8 + (obj >> 8), obj & 0xff])) + else: + self.stream.write(b'I' + struct.pack(">i", obj)) + elif isinstance(obj, float): + self.stream.write(b'D' + struct.pack(">d", obj)) + elif isinstance(obj, str): + encoded = obj.encode("utf-8") + length = len(encoded) + if length <= 31: + self.stream.write(bytes([length]) + encoded) + else: + self.stream.write(b'S' + struct.pack(">H", length) + encoded) + elif isinstance(obj, 
bytes): + length = len(obj) + self.stream.write(b'B' + struct.pack(">H", length) + obj) + elif isinstance(obj, list): + self.stream.write(b'V') + for item in obj: + self.write_object(item) + self.stream.write(b'Z') + elif isinstance(obj, dict): + self.stream.write(b'M') + for k, v in obj.items(): + self.write_object(k) + self.write_object(v) + self.stream.write(b'Z') + else: + raise TypeError(f"Unsupported type for Hessian encoding: {type(obj)}") + + +class Hessian2Input: + def __init__(self, stream: io.BytesIO): + self.stream = stream + + def _read(self, n: int) -> bytes: + data = self.stream.read(n) + if len(data) != n: + raise ValueError(f"Expected {n} bytes, got {len(data)}") + return data + + def _read_byte(self) -> int: + data = self.stream.read(1) + if not data: + raise ValueError("Unexpected end of stream") + return data[0] + + def read_object(self) -> Any: + tag = self._read_byte() + + if tag == ord('N'): + return None + + elif tag == ord('T'): + return True + elif tag == ord('F'): + return False + + elif tag == ord('I'): + return struct.unpack(">i", self._read(4))[0] + elif 0x90 <= tag <= 0xbf: + return tag - 0x90 + elif 0xc0 <= tag <= 0xcf: + return ((tag - 0xc8) << 8) + self._read_byte() + elif 0xd0 <= tag <= 0xd7: + return ((tag - 0xd4) << 16) + (self._read_byte() << 8) + self._read_byte() + + elif tag == ord('D'): + return struct.unpack(">d", self._read(8))[0] + + elif tag == ord('S'): + length = struct.unpack(">H", self._read(2))[0] + return self._read(length).decode("utf-8") + elif 0x00 <= tag <= 0x1f: + return self._read(tag).decode("utf-8") + + elif tag == ord('B'): + length = struct.unpack(">H", self._read(2))[0] + return self._read(length) + + elif tag == ord('V'): + arr = [] + while True: + peek = self.stream.read(1) + if peek == b'Z': + break + self.stream.seek(-1, 1) + arr.append(self.read_object()) + return arr + + elif tag == ord('M'): + d = {} + while True: + peek = self.stream.read(1) + if peek == b'Z': + break + self.stream.seek(-1, 1) + 
k = self.read_object() + v = self.read_object() + d[k] = v + return d + + else: + raise ValueError(f"Unknown Hessian tag: {hex(tag)}") From dfa1b0793a1be133fcbed932301f7890e3cf61af Mon Sep 17 00:00:00 2001 From: Aditya Yadav <166515021+aditya0yadav@users.noreply.github.com> Date: Mon, 1 Sep 2025 18:40:42 +0000 Subject: [PATCH 12/40] upgraded the manual , automatic fallback to manual if the type cant be handled by the pyhessian --- .../hessian_codec/hessian_codec_handler.py | 180 +++++++-- .../codec/hessian_codec/manual_hessian.py | 350 +++++++++++++++--- 2 files changed, 447 insertions(+), 83 deletions(-) diff --git a/src/dubbo/codec/hessian_codec/hessian_codec_handler.py b/src/dubbo/codec/hessian_codec/hessian_codec_handler.py index 239cc64..851b1ab 100644 --- a/src/dubbo/codec/hessian_codec/hessian_codec_handler.py +++ b/src/dubbo/codec/hessian_codec/hessian_codec_handler.py @@ -1,17 +1,19 @@ import io -import struct import logging from typing import Any, List, Dict, Optional, Union, Type from abc import ABC +from datetime import datetime, date, time +from decimal import Decimal +import uuid + +# Manual serializer for types pyhessian can't handle +from .manual_hessian import Hessian2Output as output, Hessian2Input as input try: from pyhessian import Hessian2Input, Hessian2Output - _HAS_PYHESSIAN = True except ImportError: _HAS_PYHESSIAN = False - from .manual_hessian import Hessian2Input, Hessian2Output - class HessianTypeError(Exception): @@ -40,7 +42,9 @@ def _validate_string_type(value: Any, type_string: str) -> bool: """Validate value against string type specification.""" type_mapping = { 'str': str, 'int': int, 'float': float, 'bool': bool, - 'list': list, 'dict': dict, 'bytes': bytes, + 'list': list, 'dict': dict, 'bytes': bytes, 'tuple': tuple, + 'set': set, 'frozenset': frozenset, 'datetime': datetime, + 'date': date, 'time': time, 'decimal': Decimal, 'uuid': uuid.UUID, 'none': type(None), 'any': object } @@ -87,9 +91,53 @@ def get_type(self, type_name: 
str) -> Optional[Type]: return self._type_registry.get(type_name) +class HybridHessianHandler: + """Handles decision-making between pyhessian and manual serialization.""" + + # Types that pyhessian typically can't handle well + MANUAL_TYPES = { + set, frozenset, tuple, Decimal, uuid.UUID, date, time, + complex, range, memoryview, bytearray + } + + # Custom objects (non-builtin types) + BUILTIN_TYPES = { + str, int, float, bool, list, dict, bytes, type(None), + datetime # pyhessian can handle datetime + } + + @classmethod + def should_use_manual(cls, obj: Any) -> bool: + """Determine if an object should use manual serialization.""" + obj_type = type(obj) + + # Check if it's a type we know needs manual handling + if obj_type in cls.MANUAL_TYPES: + return True + + # Check if it's a custom class (not a builtin) + if obj_type not in cls.BUILTIN_TYPES and hasattr(obj, '__dict__'): + return True + + # For containers, check their contents + if isinstance(obj, (list, tuple)): + return any(cls.should_use_manual(item) for item in obj) + elif isinstance(obj, dict): + return any(cls.should_use_manual(k) or cls.should_use_manual(v) + for k, v in obj.items()) + elif isinstance(obj, (set, frozenset)): + return True # sets always need manual handling + + return False + + @classmethod + def contains_manual_types(cls, args: tuple) -> bool: + """Check if any argument requires manual serialization.""" + return any(cls.should_use_manual(arg) for arg in args) + class HessianTransportEncoder: - """Encodes Python objects into Hessian binary format with type validation.""" + """Encodes Python objects using hybrid pyhessian/manual approach.""" def __init__(self, parameter_types: Optional[List[Union[str, Type]]] = None): self.parameter_types = parameter_types or [] @@ -97,20 +145,49 @@ def __init__(self, parameter_types: Optional[List[Union[str, Type]]] = None): self._logger = logging.getLogger(__name__) def encode(self, arguments: tuple) -> bytes: + """Encode arguments using the most 
appropriate serializer.""" if self.parameter_types: TypeValidator.validate_parameters(arguments, self.parameter_types) try: - output_stream = io.BytesIO() - hessian_output = Hessian2Output(output_stream) - - for arg in arguments: - hessian_output.write_object(arg) - - return output_stream.getvalue() + # Decide which serializer to use + use_manual = not _HAS_PYHESSIAN or HybridHessianHandler.contains_manual_types(arguments) + + if use_manual: + return self._encode_manual(arguments) + else: + return self._encode_pyhessian(arguments) + except Exception as e: self._logger.error(f"Encoding error: {e}") + # Fallback to manual if pyhessian fails + if not use_manual and _HAS_PYHESSIAN: + self._logger.info("Falling back to manual serialization") + try: + return self._encode_manual(arguments) + except Exception as fallback_error: + raise HessianRpcError(f"Both serializers failed. Last error: {fallback_error}") raise HessianRpcError(f"Failed to encode parameters: {e}") + + def _encode_pyhessian(self, arguments: tuple) -> bytes: + """Encode using pyhessian library.""" + output_stream = io.BytesIO() + hessian_output = Hessian2Output(output_stream) + + for arg in arguments: + hessian_output.write_object(arg) + + return output_stream.getvalue() + + def _encode_manual(self, arguments: tuple) -> bytes: + """Encode using manual serializer.""" + output_stream = io.BytesIO() + hessian_output = output(output_stream) + + for arg in arguments: + hessian_output.write_object(arg) + + return output_stream.getvalue() def register_type_provider(self, provider: TypeProvider) -> None: """Register a type provider for custom type handling.""" @@ -118,29 +195,67 @@ def register_type_provider(self, provider: TypeProvider) -> None: class HessianTransportDecoder: - """Decodes Hessian binary format into Python objects with type validation.""" + """Decodes Hessian binary format using hybrid approach.""" - def __init__(self, target_type: Optional[Union[str, Type]] = None): + def __init__(self, 
target_type: Optional[Union[str, Type]] = None, prefer_manual: bool = False): self.target_type = target_type + self.prefer_manual = prefer_manual self._type_provider: Optional[TypeProvider] = None self._logger = logging.getLogger(__name__) def decode(self, data: bytes) -> Any: + """Decode data using the most appropriate deserializer.""" try: - input_stream = io.BytesIO(data) - hessian_input = Hessian2Input(input_stream) - result = hessian_input.read_object() - - if self.target_type: - if not TypeValidator.validate_type(result, self.target_type): - raise HessianTypeError( - f"Return type mismatch: expected {self.target_type}, got {type(result).__name__}" - ) - - return result + # Try the preferred method first + if self.prefer_manual or not _HAS_PYHESSIAN: + return self._decode_manual(data) + else: + return self._decode_pyhessian(data) + except Exception as e: self._logger.error(f"Decoding error: {e}") + # Fallback to the other method + if not self.prefer_manual and _HAS_PYHESSIAN: + self._logger.info("Falling back to manual deserialization") + try: + return self._decode_manual(data) + except Exception as fallback_error: + raise HessianRpcError(f"Both deserializers failed. Last error: {fallback_error}") + elif self.prefer_manual and _HAS_PYHESSIAN: + self._logger.info("Falling back to pyhessian deserialization") + try: + return self._decode_pyhessian(data) + except Exception as fallback_error: + raise HessianRpcError(f"Both deserializers failed. 
Last error: {fallback_error}") raise HessianRpcError(f"Failed to decode data: {e}") + + def _decode_pyhessian(self, data: bytes) -> Any: + """Decode using pyhessian library.""" + input_stream = io.BytesIO(data) + hessian_input = Hessian2Input(input_stream) + result = hessian_input.read_object() + + if self.target_type: + if not TypeValidator.validate_type(result, self.target_type): + raise HessianTypeError( + f"Return type mismatch: expected {self.target_type}, got {type(result).__name__}" + ) + + return result + + def _decode_manual(self, data: bytes) -> Any: + """Decode using manual deserializer.""" + input_stream = io.BytesIO(data) + hessian_input = input(input_stream) + result = hessian_input.read_object() + + if self.target_type: + if not TypeValidator.validate_type(result, self.target_type): + raise HessianTypeError( + f"Return type mismatch: expected {self.target_type}, got {type(result).__name__}" + ) + + return result def register_type_provider(self, provider: TypeProvider) -> None: """Register a type provider for custom type handling.""" @@ -148,14 +263,16 @@ def register_type_provider(self, provider: TypeProvider) -> None: class HessianTransportCodec(ABC): - """High-level encoder/decoder wrapper for Hessian RPC with enhanced features.""" + """High-level encoder/decoder wrapper with hybrid serialization.""" def __init__(self, parameter_types: Optional[List[Union[str, Type]]] = None, - return_type: Optional[Union[str, Type]] = None): + return_type: Optional[Union[str, Type]] = None, + prefer_manual: bool = False): self.parameter_types = parameter_types or [] self.return_type = return_type + self.prefer_manual = prefer_manual self._encoder = HessianTransportEncoder(parameter_types) - self._decoder = HessianTransportDecoder(return_type) + self._decoder = HessianTransportDecoder(return_type, prefer_manual) self._logger = logging.getLogger(__name__) def encode_parameters(self, *arguments) -> bytes: @@ -183,3 +300,8 @@ def get_encoder(self) -> 
HessianTransportEncoder: def get_decoder(self) -> HessianTransportDecoder: """Get the decoder instance.""" return self._decoder + + def set_manual_preference(self, prefer_manual: bool)-> None: + """Set preference for manual serialization.""" + self.prefer_manual = prefer_manual + self._decoder.prefer_manual = prefer_manual diff --git a/src/dubbo/codec/hessian_codec/manual_hessian.py b/src/dubbo/codec/hessian_codec/manual_hessian.py index 7bf6487..c0906fb 100644 --- a/src/dubbo/codec/hessian_codec/manual_hessian.py +++ b/src/dubbo/codec/hessian_codec/manual_hessian.py @@ -1,53 +1,194 @@ import io import struct -from typing import Any +from typing import Any, Dict, Type +from datetime import datetime, date, time +from decimal import Decimal +import uuid + class Hessian2Output: + + TYPE_MARKERS = { + 'set': 0xF0, + 'frozenset': 0xF1, + 'tuple': 0xF2, + 'decimal': 0xF3, + 'uuid': 0xF4, + 'date': 0xF5, + 'time': 0xF6, + 'complex': 0xF7, + 'range': 0xF8, + 'bytearray': 0xF9, + 'custom_object': 0xFA + } + def __init__(self, stream: io.BytesIO): self.stream = stream def write_object(self, obj: Any): + """Write any Python object to Hessian format.""" if obj is None: - self.stream.write(b'N') # null + self._write_null() elif isinstance(obj, bool): - self.stream.write(b'T' if obj else b'F') + self._write_bool(obj) elif isinstance(obj, int): - if -16 <= obj <= 47: - self.stream.write(bytes([0x90 + obj])) - elif -2048 <= obj <= 2047: - self.stream.write(bytes([0xc8 + (obj >> 8), obj & 0xff])) - else: - self.stream.write(b'I' + struct.pack(">i", obj)) + self._write_int(obj) elif isinstance(obj, float): - self.stream.write(b'D' + struct.pack(">d", obj)) + self._write_float(obj) elif isinstance(obj, str): - encoded = obj.encode("utf-8") - length = len(encoded) - if length <= 31: - self.stream.write(bytes([length]) + encoded) - else: - self.stream.write(b'S' + struct.pack(">H", length) + encoded) + self._write_string(obj) elif isinstance(obj, bytes): - length = len(obj) - 
self.stream.write(b'B' + struct.pack(">H", length) + obj) + self._write_bytes(obj) + elif isinstance(obj, bytearray): + self._write_bytearray(obj) elif isinstance(obj, list): - self.stream.write(b'V') - for item in obj: - self.write_object(item) - self.stream.write(b'Z') + self._write_list(obj) + elif isinstance(obj, tuple): + self._write_tuple(obj) elif isinstance(obj, dict): - self.stream.write(b'M') - for k, v in obj.items(): - self.write_object(k) - self.write_object(v) - self.stream.write(b'Z') + self._write_dict(obj) + elif isinstance(obj, set): + self._write_set(obj) + elif isinstance(obj, frozenset): + self._write_frozenset(obj) + elif isinstance(obj, Decimal): + self._write_decimal(obj) + elif isinstance(obj, uuid.UUID): + self._write_uuid(obj) + elif isinstance(obj, datetime): + self._write_datetime(obj) + elif isinstance(obj, date): + self._write_date(obj) + elif isinstance(obj, time): + self._write_time(obj) + elif isinstance(obj, complex): + self._write_complex(obj) + elif isinstance(obj, range): + self._write_range(obj) + else: + self._write_custom_object(obj) + + def _write_null(self): + self.stream.write(b'N') + + def _write_bool(self, obj: bool): + self.stream.write(b'T' if obj else b'F') + + def _write_int(self, obj: int): + if -16 <= obj <= 47: + self.stream.write(bytes([0x90 + obj])) + elif -2048 <= obj <= 2047: + self.stream.write(bytes([0xc8 + (obj >> 8), obj & 0xff])) + else: + self.stream.write(b'I' + struct.pack(">i", obj)) + + def _write_float(self, obj: float): + self.stream.write(b'D' + struct.pack(">d", obj)) + + def _write_string(self, obj: str): + encoded = obj.encode("utf-8") + length = len(encoded) + if length <= 31: + self.stream.write(bytes([length]) + encoded) + else: + self.stream.write(b'S' + struct.pack(">H", length) + encoded) + + def _write_bytes(self, obj: bytes): + length = len(obj) + self.stream.write(b'B' + struct.pack(">H", length) + obj) + + def _write_bytearray(self, obj: bytearray): + 
self.stream.write(bytes([self.TYPE_MARKERS['bytearray']])) + self._write_bytes(bytes(obj)) + + def _write_list(self, obj: list): + self.stream.write(b'V') + for item in obj: + self.write_object(item) + self.stream.write(b'Z') + + def _write_tuple(self, obj: tuple): + self.stream.write(bytes([self.TYPE_MARKERS['tuple']])) + self.stream.write(struct.pack(">I", len(obj))) + for item in obj: + self.write_object(item) + + def _write_dict(self, obj: dict): + self.stream.write(b'M') + for k, v in obj.items(): + self.write_object(k) + self.write_object(v) + self.stream.write(b'Z') + + def _write_set(self, obj: set): + self.stream.write(bytes([self.TYPE_MARKERS['set']])) + self.stream.write(struct.pack(">I", len(obj))) + for item in obj: + self.write_object(item) + + def _write_frozenset(self, obj: frozenset): + self.stream.write(bytes([self.TYPE_MARKERS['frozenset']])) + self.stream.write(struct.pack(">I", len(obj))) + for item in obj: + self.write_object(item) + + def _write_decimal(self, obj: Decimal): + self.stream.write(bytes([self.TYPE_MARKERS['decimal']])) + self._write_string(str(obj)) + + def _write_uuid(self, obj: uuid.UUID): + self.stream.write(bytes([self.TYPE_MARKERS['uuid']])) + self._write_bytes(obj.bytes) + + def _write_datetime(self, obj: datetime): + # Use standard Hessian date format + timestamp = int(obj.timestamp() * 1000) + self.stream.write(b'd' + struct.pack(">q", timestamp)) + + def _write_date(self, obj: date): + self.stream.write(bytes([self.TYPE_MARKERS['date']])) + # Store as year, month, day + self.stream.write(struct.pack(">HBB", obj.year, obj.month, obj.day)) + + def _write_time(self, obj: time): + self.stream.write(bytes([self.TYPE_MARKERS['time']])) + # Store as hour, minute, second, microsecond + self.stream.write(struct.pack(">BBBI", obj.hour, obj.minute, obj.second, obj.microsecond)) + + def _write_complex(self, obj: complex): + self.stream.write(bytes([self.TYPE_MARKERS['complex']])) + self.stream.write(struct.pack(">dd", obj.real, 
obj.imag)) + + def _write_range(self, obj: range): + self.stream.write(bytes([self.TYPE_MARKERS['range']])) + self.stream.write(struct.pack(">iii", obj.start, obj.stop, obj.step)) + + def _write_custom_object(self, obj: Any): + """Serialize custom objects using their __dict__.""" + self.stream.write(bytes([self.TYPE_MARKERS['custom_object']])) + + # Write class name + class_name = f"{obj.__class__.__module__}.{obj.__class__.__name__}" + self._write_string(class_name) + + # Write object data + if hasattr(obj, '__dict__'): + self._write_dict(obj.__dict__) else: - raise TypeError(f"Unsupported type for Hessian encoding: {type(obj)}") + # Fallback: try to convert to string + self._write_string(str(obj)) class Hessian2Input: + """Enhanced manual Hessian deserializer supporting additional Python types.""" + def __init__(self, stream: io.BytesIO): self.stream = stream + self.type_registry: Dict[str, Type] = {} + + def register_type(self, class_path: str, cls: Type): + """Register a custom type for deserialization.""" + self.type_registry[class_path] = cls def _read(self, n: int) -> bytes: data = self.stream.read(n) @@ -62,16 +203,16 @@ def _read_byte(self) -> int: return data[0] def read_object(self) -> Any: + """Read any Python object from Hessian format.""" tag = self._read_byte() - + + # Standard Hessian types if tag == ord('N'): return None - elif tag == ord('T'): return True elif tag == ord('F'): return False - elif tag == ord('I'): return struct.unpack(">i", self._read(4))[0] elif 0x90 <= tag <= 0xbf: @@ -80,41 +221,142 @@ def read_object(self) -> Any: return ((tag - 0xc8) << 8) + self._read_byte() elif 0xd0 <= tag <= 0xd7: return ((tag - 0xd4) << 16) + (self._read_byte() << 8) + self._read_byte() - elif tag == ord('D'): return struct.unpack(">d", self._read(8))[0] - elif tag == ord('S'): length = struct.unpack(">H", self._read(2))[0] return self._read(length).decode("utf-8") elif 0x00 <= tag <= 0x1f: return self._read(tag).decode("utf-8") - elif tag == ord('B'): 
length = struct.unpack(">H", self._read(2))[0] return self._read(length) - elif tag == ord('V'): - arr = [] - while True: - peek = self.stream.read(1) - if peek == b'Z': - break - self.stream.seek(-1, 1) - arr.append(self.read_object()) - return arr - + return self._read_list() elif tag == ord('M'): - d = {} - while True: - peek = self.stream.read(1) - if peek == b'Z': - break - self.stream.seek(-1, 1) - k = self.read_object() - v = self.read_object() - d[k] = v - return d + return self._read_dict() + elif tag == ord('d'): + # Standard datetime + timestamp_ms = struct.unpack(">q", self._read(8))[0] + return datetime.fromtimestamp(timestamp_ms / 1000) + # Custom types + elif tag == Hessian2Output.TYPE_MARKERS['set']: + return self._read_set() + elif tag == Hessian2Output.TYPE_MARKERS['frozenset']: + return self._read_frozenset() + elif tag == Hessian2Output.TYPE_MARKERS['tuple']: + return self._read_tuple() + elif tag == Hessian2Output.TYPE_MARKERS['decimal']: + return self._read_decimal() + elif tag == Hessian2Output.TYPE_MARKERS['uuid']: + return self._read_uuid() + elif tag == Hessian2Output.TYPE_MARKERS['date']: + return self._read_date() + elif tag == Hessian2Output.TYPE_MARKERS['time']: + return self._read_time() + elif tag == Hessian2Output.TYPE_MARKERS['complex']: + return self._read_complex() + elif tag == Hessian2Output.TYPE_MARKERS['range']: + return self._read_range() + elif tag == Hessian2Output.TYPE_MARKERS['bytearray']: + return self._read_bytearray() + elif tag == Hessian2Output.TYPE_MARKERS['custom_object']: + return self._read_custom_object() else: raise ValueError(f"Unknown Hessian tag: {hex(tag)}") + + def _read_list(self) -> list: + arr = [] + while True: + peek = self.stream.read(1) + if peek == b'Z': + break + self.stream.seek(-1, 1) + arr.append(self.read_object()) + return arr + + def _read_dict(self) -> dict: + d = {} + while True: + peek = self.stream.read(1) + if peek == b'Z': + break + self.stream.seek(-1, 1) + k = self.read_object() + v 
= self.read_object() + d[k] = v + return d + + def _read_set(self) -> set: + length = struct.unpack(">I", self._read(4))[0] + items = set() + for _ in range(length): + items.add(self.read_object()) + return items + + def _read_frozenset(self) -> frozenset: + length = struct.unpack(">I", self._read(4))[0] + items = [] + for _ in range(length): + items.append(self.read_object()) + return frozenset(items) + + def _read_tuple(self) -> tuple: + length = struct.unpack(">I", self._read(4))[0] + items = [] + for _ in range(length): + items.append(self.read_object()) + return tuple(items) + + def _read_decimal(self) -> Decimal: + string_val = self.read_object() # Read the string representation + return Decimal(string_val) + + def _read_uuid(self) -> uuid.UUID: + uuid_bytes = self.read_object() # Read the bytes + return uuid.UUID(bytes=uuid_bytes) + + def _read_date(self) -> date: + year, month, day = struct.unpack(">HBB", self._read(4)) + return date(year, month, day) + + def _read_time(self) -> time: + hour, minute, second, microsecond = struct.unpack(">BBBI", self._read(7)) + return time(hour, minute, second, microsecond) + + def _read_complex(self) -> complex: + real, imag = struct.unpack(">dd", self._read(16)) + return complex(real, imag) + + def _read_range(self) -> range: + start, stop, step = struct.unpack(">iii", self._read(12)) + return range(start, stop, step) + + def _read_bytearray(self) -> bytearray: + bytes_data = self.read_object() # Read the bytes + return bytearray(bytes_data) + + def _read_custom_object(self) -> Any: + """Deserialize custom objects.""" + class_name = self.read_object() # Read class name + obj_data = self.read_object() # Read object data + + # Try to reconstruct the object + if class_name in self.type_registry: + cls = self.type_registry[class_name] + if isinstance(obj_data, dict): + # Try to create object and set attributes + try: + obj = cls.__new__(cls) + obj.__dict__.update(obj_data) + return obj + except Exception: + # Fallback: return 
dict with class info + return {"__class__": class_name, "__data__": obj_data} + else: + # Data is not a dict, return as-is with class info + return {"__class__": class_name, "__data__": obj_data} + else: + # Unknown class, return dict representation + return {"__class__": class_name, "__data__": obj_data} From ed8c30be51cc949dd96d2bc6459afabac4727a50 Mon Sep 17 00:00:00 2001 From: Aditya Yadav <166515021+aditya0yadav@users.noreply.github.com> Date: Mon, 1 Sep 2025 18:42:18 +0000 Subject: [PATCH 13/40] addding the neccasy license --- src/dubbo/codec/hessian_codec/__init__.py | 19 +++++++++++++++++++ .../hessian_codec/hessian_codec_handler.py | 16 ++++++++++++++++ .../codec/hessian_codec/manual_hessian.py | 16 ++++++++++++++++ 3 files changed, 51 insertions(+) create mode 100644 src/dubbo/codec/hessian_codec/__init__.py diff --git a/src/dubbo/codec/hessian_codec/__init__.py b/src/dubbo/codec/hessian_codec/__init__.py new file mode 100644 index 0000000..f40d835 --- /dev/null +++ b/src/dubbo/codec/hessian_codec/__init__.py @@ -0,0 +1,19 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from .hessian_codec_handler import HessianTransportCodec, HessianTransportDecoder, HessianTransportEncoder + +__all__ = ["HessianTransportCodec", "HessianTransportDecoder", "HessianTransportEncoder"] \ No newline at end of file diff --git a/src/dubbo/codec/hessian_codec/hessian_codec_handler.py b/src/dubbo/codec/hessian_codec/hessian_codec_handler.py index 851b1ab..5e06abc 100644 --- a/src/dubbo/codec/hessian_codec/hessian_codec_handler.py +++ b/src/dubbo/codec/hessian_codec/hessian_codec_handler.py @@ -1,3 +1,19 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + import io import logging from typing import Any, List, Dict, Optional, Union, Type diff --git a/src/dubbo/codec/hessian_codec/manual_hessian.py b/src/dubbo/codec/hessian_codec/manual_hessian.py index c0906fb..4aefcdd 100644 --- a/src/dubbo/codec/hessian_codec/manual_hessian.py +++ b/src/dubbo/codec/hessian_codec/manual_hessian.py @@ -1,3 +1,19 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. 
+# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + import io import struct from typing import Any, Dict, Type From da70485a36acc6d1170e6e208dff68f6595d630f Mon Sep 17 00:00:00 2001 From: Aditya Yadav <166515021+aditya0yadav@users.noreply.github.com> Date: Tue, 2 Sep 2025 16:38:09 +0000 Subject: [PATCH 14/40] remove the hessian implementation --- src/dubbo/codec/hessian_codec/__init__.py | 19 - .../hessian_codec/hessian_codec_handler.py | 323 --------------- .../codec/hessian_codec/manual_hessian.py | 378 ------------------ 3 files changed, 720 deletions(-) delete mode 100644 src/dubbo/codec/hessian_codec/__init__.py delete mode 100644 src/dubbo/codec/hessian_codec/hessian_codec_handler.py delete mode 100644 src/dubbo/codec/hessian_codec/manual_hessian.py diff --git a/src/dubbo/codec/hessian_codec/__init__.py b/src/dubbo/codec/hessian_codec/__init__.py deleted file mode 100644 index f40d835..0000000 --- a/src/dubbo/codec/hessian_codec/__init__.py +++ /dev/null @@ -1,19 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one or more -# contributor license agreements. See the NOTICE file distributed with -# this work for additional information regarding copyright ownership. -# The ASF licenses this file to You under the Apache License, Version 2.0 -# (the "License"); you may not use this file except in compliance with -# the License. 
You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from .hessian_codec_handler import HessianTransportCodec, HessianTransportDecoder, HessianTransportEncoder - -__all__ = ["HessianTransportCodec", "HessianTransportDecoder", "HessianTransportEncoder"] \ No newline at end of file diff --git a/src/dubbo/codec/hessian_codec/hessian_codec_handler.py b/src/dubbo/codec/hessian_codec/hessian_codec_handler.py deleted file mode 100644 index 5e06abc..0000000 --- a/src/dubbo/codec/hessian_codec/hessian_codec_handler.py +++ /dev/null @@ -1,323 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one or more -# contributor license agreements. See the NOTICE file distributed with -# this work for additional information regarding copyright ownership. -# The ASF licenses this file to You under the Apache License, Version 2.0 -# (the "License"); you may not use this file except in compliance with -# the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -import io -import logging -from typing import Any, List, Dict, Optional, Union, Type -from abc import ABC -from datetime import datetime, date, time -from decimal import Decimal -import uuid - -# Manual serializer for types pyhessian can't handle -from .manual_hessian import Hessian2Output as output, Hessian2Input as input - -try: - from pyhessian import Hessian2Input, Hessian2Output - _HAS_PYHESSIAN = True -except ImportError: - _HAS_PYHESSIAN = False - - -class HessianTypeError(Exception): - """Exception raised for type validation errors in Hessian RPC.""" - pass - - -class HessianRpcError(Exception): - """Base exception for Hessian RPC errors.""" - pass - - -class TypeValidator: - """Type validation utilities for Hessian RPC.""" - - @staticmethod - def validate_type(value: Any, expected_type: Union[str, Type]) -> bool: - """Validate if a value matches the expected type.""" - if isinstance(expected_type, str): - return TypeValidator._validate_string_type(value, expected_type) - else: - return isinstance(value, expected_type) - - @staticmethod - def _validate_string_type(value: Any, type_string: str) -> bool: - """Validate value against string type specification.""" - type_mapping = { - 'str': str, 'int': int, 'float': float, 'bool': bool, - 'list': list, 'dict': dict, 'bytes': bytes, 'tuple': tuple, - 'set': set, 'frozenset': frozenset, 'datetime': datetime, - 'date': date, 'time': time, 'decimal': Decimal, 'uuid': uuid.UUID, - 'none': type(None), 'any': object - } - - if type_string.lower() in type_mapping: - expected_type = type_mapping[type_string.lower()] - return isinstance(value, expected_type) - - if '[' in type_string and ']' in type_string: - base_type = type_string.split('[')[0].lower() - if base_type == 'list': - return isinstance(value, list) - elif base_type == 'dict': - return isinstance(value, dict) - - return True - - @staticmethod - def validate_parameters(args: tuple, param_types: List[Union[str, Type]]) -> None: - """Validate method 
parameters against expected types.""" - if len(args) != len(param_types): - raise HessianTypeError( - f"Parameter count mismatch: expected {len(param_types)}, got {len(args)}" - ) - - for i, (arg, expected_type) in enumerate(zip(args, param_types)): - if not TypeValidator.validate_type(arg, expected_type): - raise HessianTypeError( - f"Parameter {i} type mismatch: expected {expected_type}, got {type(arg).__name__}" - ) - - -class TypeProvider: - """Provides type information and serialization hints for custom objects.""" - - def __init__(self): - self._type_registry: Dict[str, Type] = {} - - def register_type(self, type_name: str, type_class: Type) -> None: - """Register a custom type for serialization.""" - self._type_registry[type_name] = type_class - - def get_type(self, type_name: str) -> Optional[Type]: - """Get a registered type by name.""" - return self._type_registry.get(type_name) - - -class HybridHessianHandler: - """Handles decision-making between pyhessian and manual serialization.""" - - # Types that pyhessian typically can't handle well - MANUAL_TYPES = { - set, frozenset, tuple, Decimal, uuid.UUID, date, time, - complex, range, memoryview, bytearray - } - - # Custom objects (non-builtin types) - BUILTIN_TYPES = { - str, int, float, bool, list, dict, bytes, type(None), - datetime # pyhessian can handle datetime - } - - @classmethod - def should_use_manual(cls, obj: Any) -> bool: - """Determine if an object should use manual serialization.""" - obj_type = type(obj) - - # Check if it's a type we know needs manual handling - if obj_type in cls.MANUAL_TYPES: - return True - - # Check if it's a custom class (not a builtin) - if obj_type not in cls.BUILTIN_TYPES and hasattr(obj, '__dict__'): - return True - - # For containers, check their contents - if isinstance(obj, (list, tuple)): - return any(cls.should_use_manual(item) for item in obj) - elif isinstance(obj, dict): - return any(cls.should_use_manual(k) or cls.should_use_manual(v) - for k, v in 
obj.items()) - elif isinstance(obj, (set, frozenset)): - return True # sets always need manual handling - - return False - - @classmethod - def contains_manual_types(cls, args: tuple) -> bool: - """Check if any argument requires manual serialization.""" - return any(cls.should_use_manual(arg) for arg in args) - - -class HessianTransportEncoder: - """Encodes Python objects using hybrid pyhessian/manual approach.""" - - def __init__(self, parameter_types: Optional[List[Union[str, Type]]] = None): - self.parameter_types = parameter_types or [] - self._type_provider: Optional[TypeProvider] = None - self._logger = logging.getLogger(__name__) - - def encode(self, arguments: tuple) -> bytes: - """Encode arguments using the most appropriate serializer.""" - if self.parameter_types: - TypeValidator.validate_parameters(arguments, self.parameter_types) - - try: - # Decide which serializer to use - use_manual = not _HAS_PYHESSIAN or HybridHessianHandler.contains_manual_types(arguments) - - if use_manual: - return self._encode_manual(arguments) - else: - return self._encode_pyhessian(arguments) - - except Exception as e: - self._logger.error(f"Encoding error: {e}") - # Fallback to manual if pyhessian fails - if not use_manual and _HAS_PYHESSIAN: - self._logger.info("Falling back to manual serialization") - try: - return self._encode_manual(arguments) - except Exception as fallback_error: - raise HessianRpcError(f"Both serializers failed. 
Last error: {fallback_error}") - raise HessianRpcError(f"Failed to encode parameters: {e}") - - def _encode_pyhessian(self, arguments: tuple) -> bytes: - """Encode using pyhessian library.""" - output_stream = io.BytesIO() - hessian_output = Hessian2Output(output_stream) - - for arg in arguments: - hessian_output.write_object(arg) - - return output_stream.getvalue() - - def _encode_manual(self, arguments: tuple) -> bytes: - """Encode using manual serializer.""" - output_stream = io.BytesIO() - hessian_output = output(output_stream) - - for arg in arguments: - hessian_output.write_object(arg) - - return output_stream.getvalue() - - def register_type_provider(self, provider: TypeProvider) -> None: - """Register a type provider for custom type handling.""" - self._type_provider = provider - - -class HessianTransportDecoder: - """Decodes Hessian binary format using hybrid approach.""" - - def __init__(self, target_type: Optional[Union[str, Type]] = None, prefer_manual: bool = False): - self.target_type = target_type - self.prefer_manual = prefer_manual - self._type_provider: Optional[TypeProvider] = None - self._logger = logging.getLogger(__name__) - - def decode(self, data: bytes) -> Any: - """Decode data using the most appropriate deserializer.""" - try: - # Try the preferred method first - if self.prefer_manual or not _HAS_PYHESSIAN: - return self._decode_manual(data) - else: - return self._decode_pyhessian(data) - - except Exception as e: - self._logger.error(f"Decoding error: {e}") - # Fallback to the other method - if not self.prefer_manual and _HAS_PYHESSIAN: - self._logger.info("Falling back to manual deserialization") - try: - return self._decode_manual(data) - except Exception as fallback_error: - raise HessianRpcError(f"Both deserializers failed. 
Last error: {fallback_error}") - elif self.prefer_manual and _HAS_PYHESSIAN: - self._logger.info("Falling back to pyhessian deserialization") - try: - return self._decode_pyhessian(data) - except Exception as fallback_error: - raise HessianRpcError(f"Both deserializers failed. Last error: {fallback_error}") - raise HessianRpcError(f"Failed to decode data: {e}") - - def _decode_pyhessian(self, data: bytes) -> Any: - """Decode using pyhessian library.""" - input_stream = io.BytesIO(data) - hessian_input = Hessian2Input(input_stream) - result = hessian_input.read_object() - - if self.target_type: - if not TypeValidator.validate_type(result, self.target_type): - raise HessianTypeError( - f"Return type mismatch: expected {self.target_type}, got {type(result).__name__}" - ) - - return result - - def _decode_manual(self, data: bytes) -> Any: - """Decode using manual deserializer.""" - input_stream = io.BytesIO(data) - hessian_input = input(input_stream) - result = hessian_input.read_object() - - if self.target_type: - if not TypeValidator.validate_type(result, self.target_type): - raise HessianTypeError( - f"Return type mismatch: expected {self.target_type}, got {type(result).__name__}" - ) - - return result - - def register_type_provider(self, provider: TypeProvider) -> None: - """Register a type provider for custom type handling.""" - self._type_provider = provider - - -class HessianTransportCodec(ABC): - """High-level encoder/decoder wrapper with hybrid serialization.""" - - def __init__(self, parameter_types: Optional[List[Union[str, Type]]] = None, - return_type: Optional[Union[str, Type]] = None, - prefer_manual: bool = False): - self.parameter_types = parameter_types or [] - self.return_type = return_type - self.prefer_manual = prefer_manual - self._encoder = HessianTransportEncoder(parameter_types) - self._decoder = HessianTransportDecoder(return_type, prefer_manual) - self._logger = logging.getLogger(__name__) - - def encode_parameters(self, *arguments) -> bytes: 
- """Encode method parameters to Hessian binary format.""" - return self._encoder.encode(arguments) - - def decode_return_value(self, data: bytes) -> Any: - """Decode return value from Hessian binary format.""" - return self._decoder.decode(data) - - def validate_call(self, *arguments) -> None: - """Validate method call parameters without encoding.""" - if self.parameter_types: - TypeValidator.validate_parameters(arguments, self.parameter_types) - - def register_type_provider(self, provider: TypeProvider) -> None: - """Register a type provider for both encoder and decoder.""" - self._encoder.register_type_provider(provider) - self._decoder.register_type_provider(provider) - - def get_encoder(self) -> HessianTransportEncoder: - """Get the encoder instance.""" - return self._encoder - - def get_decoder(self) -> HessianTransportDecoder: - """Get the decoder instance.""" - return self._decoder - - def set_manual_preference(self, prefer_manual: bool)-> None: - """Set preference for manual serialization.""" - self.prefer_manual = prefer_manual - self._decoder.prefer_manual = prefer_manual diff --git a/src/dubbo/codec/hessian_codec/manual_hessian.py b/src/dubbo/codec/hessian_codec/manual_hessian.py deleted file mode 100644 index 4aefcdd..0000000 --- a/src/dubbo/codec/hessian_codec/manual_hessian.py +++ /dev/null @@ -1,378 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one or more -# contributor license agreements. See the NOTICE file distributed with -# this work for additional information regarding copyright ownership. -# The ASF licenses this file to You under the Apache License, Version 2.0 -# (the "License"); you may not use this file except in compliance with -# the License. 
You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import io -import struct -from typing import Any, Dict, Type -from datetime import datetime, date, time -from decimal import Decimal -import uuid - - -class Hessian2Output: - - TYPE_MARKERS = { - 'set': 0xF0, - 'frozenset': 0xF1, - 'tuple': 0xF2, - 'decimal': 0xF3, - 'uuid': 0xF4, - 'date': 0xF5, - 'time': 0xF6, - 'complex': 0xF7, - 'range': 0xF8, - 'bytearray': 0xF9, - 'custom_object': 0xFA - } - - def __init__(self, stream: io.BytesIO): - self.stream = stream - - def write_object(self, obj: Any): - """Write any Python object to Hessian format.""" - if obj is None: - self._write_null() - elif isinstance(obj, bool): - self._write_bool(obj) - elif isinstance(obj, int): - self._write_int(obj) - elif isinstance(obj, float): - self._write_float(obj) - elif isinstance(obj, str): - self._write_string(obj) - elif isinstance(obj, bytes): - self._write_bytes(obj) - elif isinstance(obj, bytearray): - self._write_bytearray(obj) - elif isinstance(obj, list): - self._write_list(obj) - elif isinstance(obj, tuple): - self._write_tuple(obj) - elif isinstance(obj, dict): - self._write_dict(obj) - elif isinstance(obj, set): - self._write_set(obj) - elif isinstance(obj, frozenset): - self._write_frozenset(obj) - elif isinstance(obj, Decimal): - self._write_decimal(obj) - elif isinstance(obj, uuid.UUID): - self._write_uuid(obj) - elif isinstance(obj, datetime): - self._write_datetime(obj) - elif isinstance(obj, date): - self._write_date(obj) - elif isinstance(obj, time): - self._write_time(obj) - elif isinstance(obj, complex): - self._write_complex(obj) - elif isinstance(obj, 
range): - self._write_range(obj) - else: - self._write_custom_object(obj) - - def _write_null(self): - self.stream.write(b'N') - - def _write_bool(self, obj: bool): - self.stream.write(b'T' if obj else b'F') - - def _write_int(self, obj: int): - if -16 <= obj <= 47: - self.stream.write(bytes([0x90 + obj])) - elif -2048 <= obj <= 2047: - self.stream.write(bytes([0xc8 + (obj >> 8), obj & 0xff])) - else: - self.stream.write(b'I' + struct.pack(">i", obj)) - - def _write_float(self, obj: float): - self.stream.write(b'D' + struct.pack(">d", obj)) - - def _write_string(self, obj: str): - encoded = obj.encode("utf-8") - length = len(encoded) - if length <= 31: - self.stream.write(bytes([length]) + encoded) - else: - self.stream.write(b'S' + struct.pack(">H", length) + encoded) - - def _write_bytes(self, obj: bytes): - length = len(obj) - self.stream.write(b'B' + struct.pack(">H", length) + obj) - - def _write_bytearray(self, obj: bytearray): - self.stream.write(bytes([self.TYPE_MARKERS['bytearray']])) - self._write_bytes(bytes(obj)) - - def _write_list(self, obj: list): - self.stream.write(b'V') - for item in obj: - self.write_object(item) - self.stream.write(b'Z') - - def _write_tuple(self, obj: tuple): - self.stream.write(bytes([self.TYPE_MARKERS['tuple']])) - self.stream.write(struct.pack(">I", len(obj))) - for item in obj: - self.write_object(item) - - def _write_dict(self, obj: dict): - self.stream.write(b'M') - for k, v in obj.items(): - self.write_object(k) - self.write_object(v) - self.stream.write(b'Z') - - def _write_set(self, obj: set): - self.stream.write(bytes([self.TYPE_MARKERS['set']])) - self.stream.write(struct.pack(">I", len(obj))) - for item in obj: - self.write_object(item) - - def _write_frozenset(self, obj: frozenset): - self.stream.write(bytes([self.TYPE_MARKERS['frozenset']])) - self.stream.write(struct.pack(">I", len(obj))) - for item in obj: - self.write_object(item) - - def _write_decimal(self, obj: Decimal): - 
self.stream.write(bytes([self.TYPE_MARKERS['decimal']])) - self._write_string(str(obj)) - - def _write_uuid(self, obj: uuid.UUID): - self.stream.write(bytes([self.TYPE_MARKERS['uuid']])) - self._write_bytes(obj.bytes) - - def _write_datetime(self, obj: datetime): - # Use standard Hessian date format - timestamp = int(obj.timestamp() * 1000) - self.stream.write(b'd' + struct.pack(">q", timestamp)) - - def _write_date(self, obj: date): - self.stream.write(bytes([self.TYPE_MARKERS['date']])) - # Store as year, month, day - self.stream.write(struct.pack(">HBB", obj.year, obj.month, obj.day)) - - def _write_time(self, obj: time): - self.stream.write(bytes([self.TYPE_MARKERS['time']])) - # Store as hour, minute, second, microsecond - self.stream.write(struct.pack(">BBBI", obj.hour, obj.minute, obj.second, obj.microsecond)) - - def _write_complex(self, obj: complex): - self.stream.write(bytes([self.TYPE_MARKERS['complex']])) - self.stream.write(struct.pack(">dd", obj.real, obj.imag)) - - def _write_range(self, obj: range): - self.stream.write(bytes([self.TYPE_MARKERS['range']])) - self.stream.write(struct.pack(">iii", obj.start, obj.stop, obj.step)) - - def _write_custom_object(self, obj: Any): - """Serialize custom objects using their __dict__.""" - self.stream.write(bytes([self.TYPE_MARKERS['custom_object']])) - - # Write class name - class_name = f"{obj.__class__.__module__}.{obj.__class__.__name__}" - self._write_string(class_name) - - # Write object data - if hasattr(obj, '__dict__'): - self._write_dict(obj.__dict__) - else: - # Fallback: try to convert to string - self._write_string(str(obj)) - - -class Hessian2Input: - """Enhanced manual Hessian deserializer supporting additional Python types.""" - - def __init__(self, stream: io.BytesIO): - self.stream = stream - self.type_registry: Dict[str, Type] = {} - - def register_type(self, class_path: str, cls: Type): - """Register a custom type for deserialization.""" - self.type_registry[class_path] = cls - - def 
_read(self, n: int) -> bytes: - data = self.stream.read(n) - if len(data) != n: - raise ValueError(f"Expected {n} bytes, got {len(data)}") - return data - - def _read_byte(self) -> int: - data = self.stream.read(1) - if not data: - raise ValueError("Unexpected end of stream") - return data[0] - - def read_object(self) -> Any: - """Read any Python object from Hessian format.""" - tag = self._read_byte() - - # Standard Hessian types - if tag == ord('N'): - return None - elif tag == ord('T'): - return True - elif tag == ord('F'): - return False - elif tag == ord('I'): - return struct.unpack(">i", self._read(4))[0] - elif 0x90 <= tag <= 0xbf: - return tag - 0x90 - elif 0xc0 <= tag <= 0xcf: - return ((tag - 0xc8) << 8) + self._read_byte() - elif 0xd0 <= tag <= 0xd7: - return ((tag - 0xd4) << 16) + (self._read_byte() << 8) + self._read_byte() - elif tag == ord('D'): - return struct.unpack(">d", self._read(8))[0] - elif tag == ord('S'): - length = struct.unpack(">H", self._read(2))[0] - return self._read(length).decode("utf-8") - elif 0x00 <= tag <= 0x1f: - return self._read(tag).decode("utf-8") - elif tag == ord('B'): - length = struct.unpack(">H", self._read(2))[0] - return self._read(length) - elif tag == ord('V'): - return self._read_list() - elif tag == ord('M'): - return self._read_dict() - elif tag == ord('d'): - # Standard datetime - timestamp_ms = struct.unpack(">q", self._read(8))[0] - return datetime.fromtimestamp(timestamp_ms / 1000) - - # Custom types - elif tag == Hessian2Output.TYPE_MARKERS['set']: - return self._read_set() - elif tag == Hessian2Output.TYPE_MARKERS['frozenset']: - return self._read_frozenset() - elif tag == Hessian2Output.TYPE_MARKERS['tuple']: - return self._read_tuple() - elif tag == Hessian2Output.TYPE_MARKERS['decimal']: - return self._read_decimal() - elif tag == Hessian2Output.TYPE_MARKERS['uuid']: - return self._read_uuid() - elif tag == Hessian2Output.TYPE_MARKERS['date']: - return self._read_date() - elif tag == 
Hessian2Output.TYPE_MARKERS['time']: - return self._read_time() - elif tag == Hessian2Output.TYPE_MARKERS['complex']: - return self._read_complex() - elif tag == Hessian2Output.TYPE_MARKERS['range']: - return self._read_range() - elif tag == Hessian2Output.TYPE_MARKERS['bytearray']: - return self._read_bytearray() - elif tag == Hessian2Output.TYPE_MARKERS['custom_object']: - return self._read_custom_object() - else: - raise ValueError(f"Unknown Hessian tag: {hex(tag)}") - - def _read_list(self) -> list: - arr = [] - while True: - peek = self.stream.read(1) - if peek == b'Z': - break - self.stream.seek(-1, 1) - arr.append(self.read_object()) - return arr - - def _read_dict(self) -> dict: - d = {} - while True: - peek = self.stream.read(1) - if peek == b'Z': - break - self.stream.seek(-1, 1) - k = self.read_object() - v = self.read_object() - d[k] = v - return d - - def _read_set(self) -> set: - length = struct.unpack(">I", self._read(4))[0] - items = set() - for _ in range(length): - items.add(self.read_object()) - return items - - def _read_frozenset(self) -> frozenset: - length = struct.unpack(">I", self._read(4))[0] - items = [] - for _ in range(length): - items.append(self.read_object()) - return frozenset(items) - - def _read_tuple(self) -> tuple: - length = struct.unpack(">I", self._read(4))[0] - items = [] - for _ in range(length): - items.append(self.read_object()) - return tuple(items) - - def _read_decimal(self) -> Decimal: - string_val = self.read_object() # Read the string representation - return Decimal(string_val) - - def _read_uuid(self) -> uuid.UUID: - uuid_bytes = self.read_object() # Read the bytes - return uuid.UUID(bytes=uuid_bytes) - - def _read_date(self) -> date: - year, month, day = struct.unpack(">HBB", self._read(4)) - return date(year, month, day) - - def _read_time(self) -> time: - hour, minute, second, microsecond = struct.unpack(">BBBI", self._read(7)) - return time(hour, minute, second, microsecond) - - def _read_complex(self) -> 
complex: - real, imag = struct.unpack(">dd", self._read(16)) - return complex(real, imag) - - def _read_range(self) -> range: - start, stop, step = struct.unpack(">iii", self._read(12)) - return range(start, stop, step) - - def _read_bytearray(self) -> bytearray: - bytes_data = self.read_object() # Read the bytes - return bytearray(bytes_data) - - def _read_custom_object(self) -> Any: - """Deserialize custom objects.""" - class_name = self.read_object() # Read class name - obj_data = self.read_object() # Read object data - - # Try to reconstruct the object - if class_name in self.type_registry: - cls = self.type_registry[class_name] - if isinstance(obj_data, dict): - # Try to create object and set attributes - try: - obj = cls.__new__(cls) - obj.__dict__.update(obj_data) - return obj - except Exception: - # Fallback: return dict with class info - return {"__class__": class_name, "__data__": obj_data} - else: - # Data is not a dict, return as-is with class info - return {"__class__": class_name, "__data__": obj_data} - else: - # Unknown class, return dict representation - return {"__class__": class_name, "__data__": obj_data} From fc6e5dd6313432b355f82d2b537a563297e08377 Mon Sep 17 00:00:00 2001 From: Aditya Yadav <166515021+aditya0yadav@users.noreply.github.com> Date: Wed, 3 Sep 2025 15:34:39 +0000 Subject: [PATCH 15/40] fixing the issue --- src/dubbo/client.py | 120 ++++++++++++++---------------- src/dubbo/codec/dubbo_codec.py | 132 +++++++++++++++++++++------------ 2 files changed, 139 insertions(+), 113 deletions(-) diff --git a/src/dubbo/client.py b/src/dubbo/client.py index f79d4af..2862228 100644 --- a/src/dubbo/client.py +++ b/src/dubbo/client.py @@ -8,15 +8,14 @@ # # http://www.apache.org/licenses/LICENSE-2.0 # -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an "AS IS" BASIS, +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is 
distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import threading -import inspect -from typing import Optional, Callable, List, Type, Any, get_type_hints +from typing import Optional, List, Type from dubbo.bootstrap import Dubbo from dubbo.classes import MethodDescriptor @@ -91,78 +90,33 @@ def _initialize(self): self._initialized = True - @classmethod - def _infer_types_from_interface(cls, interface: Callable) -> tuple: - """ - Infer method name, parameter types, and return type from a callable. - """ - try: - type_hints = get_type_hints(interface) - sig = inspect.signature(interface) - method_name = interface.__name__ - params = list(sig.parameters.values()) - - # skip 'self' for bound methods - if params and params[0].name == "self": - params = params[1:] - - param_types = [type_hints.get(p.name, Any) for p in params] - return_type = type_hints.get("return", Any) - - return method_name, param_types, return_type - except Exception: - return interface.__name__, [Any], Any - def _create_rpc_callable( self, rpc_type: str, - interface: Optional[Callable] = None, - method_name: Optional[str] = None, - params_types: Optional[List[Type]] = None, - return_type: Optional[Type] = None, + method_name: str, + params_types: List[Type], + return_type: Type, codec: Optional[str] = None, request_serializer: Optional[SerializingFunction] = None, response_deserializer: Optional[DeserializingFunction] = None, - default_method_name: str = "rpc_call", ) -> RpcCallable: """ Create RPC callable with the specified type. 
""" - if interface is None and method_name is None: - raise ValueError("Either 'interface' or 'method_name' must be provided") - - # Start with explicit values - m_name = method_name - p_types = params_types - r_type = return_type - - # Infer from interface if needed - if interface: - if p_types is None or r_type is None or m_name is None: - inf_name, inf_params, inf_return = self._infer_types_from_interface( - interface - ) - m_name = m_name or inf_name - p_types = p_types or inf_params - r_type = r_type or inf_return - - # Fallback to default - m_name = m_name or default_method_name - # Determine serializers if request_serializer and response_deserializer: req_ser = request_serializer res_deser = response_deserializer else: req_ser, res_deser = DubboTransportService.create_serialization_functions( - codec or "json", # fallback to json - parameter_types=p_types, - return_type=r_type, + codec or "json", + parameter_types=params_types, + return_type=return_type, ) # Create MethodDescriptor descriptor = MethodDescriptor( - method_name=m_name, + method_name=method_name, arg_serialization=(req_ser, None), return_serialization=(None, res_deser), rpc_type=rpc_type, @@ -170,28 +124,64 @@ def _create_rpc_callable( return self._callable(descriptor) - def unary(self, **kwargs) -> RpcCallable: + def unary( + self, + method_name: str, + params_types: List[Type], + return_type: Type, + **kwargs + ) -> RpcCallable: return self._create_rpc_callable( - rpc_type=RpcTypes.UNARY.value, default_method_name="unary", **kwargs + rpc_type=RpcTypes.UNARY.value, + method_name=method_name, + params_types=params_types, + return_type=return_type, + **kwargs, ) - def client_stream(self, **kwargs) -> RpcCallable: + def client_stream( + self, + method_name: str, + params_types: List[Type], + return_type: Type, + **kwargs + ) -> RpcCallable: return self._create_rpc_callable( rpc_type=RpcTypes.CLIENT_STREAM.value, - default_method_name="client_stream", + method_name=method_name, + 
params_types=params_types, + return_type=return_type, **kwargs, ) - def server_stream(self, **kwargs) -> RpcCallable: + def server_stream( + self, + method_name: str, + params_types: List[Type], + return_type: Type, + **kwargs + ) -> RpcCallable: return self._create_rpc_callable( rpc_type=RpcTypes.SERVER_STREAM.value, - default_method_name="server_stream", + method_name=method_name, + params_types=params_types, + return_type=return_type, **kwargs, ) - def bi_stream(self, **kwargs) -> RpcCallable: + def bi_stream( + self, + method_name: str, + params_types: List[Type], + return_type: Type, + **kwargs + ) -> RpcCallable: return self._create_rpc_callable( - rpc_type=RpcTypes.BI_STREAM.value, default_method_name="bi_stream", **kwargs + rpc_type=RpcTypes.BI_STREAM.value, + method_name=method_name, + params_types=params_types, + return_type=return_type, + **kwargs, ) def _callable(self, method_descriptor: MethodDescriptor) -> RpcCallable: @@ -203,4 +193,4 @@ def _callable(self, method_descriptor: MethodDescriptor) -> RpcCallable: method_descriptor.get_method_name() ) url.attributes[common_constants.METHOD_DESCRIPTOR_KEY] = method_descriptor - return self._callable_factory.get_callable(self._invoker, url) + return self._callable_factory.get_callable(self._invoker, url) \ No newline at end of file diff --git a/src/dubbo/codec/dubbo_codec.py b/src/dubbo/codec/dubbo_codec.py index 8758684..a154de7 100644 --- a/src/dubbo/codec/dubbo_codec.py +++ b/src/dubbo/codec/dubbo_codec.py @@ -14,12 +14,13 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-from typing import Any, Type, Optional, Callable, List, Dict +from typing import Any, Type, Optional, Callable, List, Dict, Tuple from dataclasses import dataclass import inspect +import logging + +logger = logging.getLogger(__name__) -from dubbo.classes import CodecHelper -from dubbo.codec.json_codec import JsonTransportCodec, JsonTransportEncoder, JsonTransportDecoder @dataclass class ParameterDescriptor: @@ -40,88 +41,120 @@ class MethodDescriptor: documentation: Optional[str] = None -class DubboTransportService: - """Enhanced Dubbo transport service with robust type handling""" +class DubboSerializationService: + """Dubbo serialization service with robust type handling""" @staticmethod def create_transport_codec(transport_type: str = 'json', parameter_types: List[Type] = None, return_type: Type = None, **codec_options): """Create transport codec with enhanced parameter structure""" - if transport_type == 'json': - return JsonTransportCodec( - parameter_types=parameter_types, - return_type=return_type, - **codec_options - ) - else: + + try: from dubbo.extension.extension_loader import ExtensionLoader - Codec = CodecHelper.get_class() - codec_class = ExtensionLoader().get_extension(Codec, transport_type) + from dubbo.classes import CodecHelper + + codec_class = ExtensionLoader().get_extension(CodecHelper.get_class(), transport_type) return codec_class( - parameter_types=parameter_types, + parameter_types=parameter_types or [], return_type=return_type, **codec_options ) + except ImportError as e: + logger.error(f"Failed to import required modules: {e}") + raise + except Exception as e: + logger.error(f"Failed to create transport codec: {e}") + raise @staticmethod def create_encoder_decoder_pair(transport_type: str, parameter_types: List[Type] = None, - return_type: Type = None, **codec_options) -> tuple[any,any]: + return_type: Type = None, **codec_options) -> Tuple[Any, Any]: """Create separate encoder and decoder instances""" - if transport_type == 'json': - 
parameter_encoder = JsonTransportEncoder(parameter_types=parameter_types, **codec_options) - return_decoder = JsonTransportDecoder(target_type=return_type, **codec_options) - return parameter_encoder, return_decoder - else: - from dubbo.extension.extension_loader import ExtensionLoader - Codec = CodecHelper.get_class() - codec_class = ExtensionLoader().get_extension(Codec, transport_type) - - codec_instance = codec_class( + try: + codec_instance = DubboSerializationService.create_transport_codec( + transport_type=transport_type, parameter_types=parameter_types, return_type=return_type, **codec_options ) - return codec_instance.get_encoder(), codec_instance.get_decoder() + + encoder = codec_instance.get_encoder() + decoder = codec_instance.get_decoder() + + if encoder is None or decoder is None: + raise ValueError(f"Codec for transport type '{transport_type}' returned None encoder/decoder") + + return encoder, decoder + + except Exception as e: + logger.error(f"Failed to create encoder/decoder pair: {e}") + raise @staticmethod def create_serialization_functions(transport_type: str, parameter_types: List[Type] = None, - return_type: Type = None, **codec_options) -> tuple[Callable, Callable]: + return_type: Type = None, **codec_options) -> Tuple[Callable, Callable]: """Create serializer and deserializer functions for RPC (backward compatibility)""" - parameter_encoder, return_decoder = DubboTransportService.create_encoder_decoder_pair( - transport_type=transport_type, - parameter_types=parameter_types, - return_type=return_type, - **codec_options - ) - - def serialize_method_parameters(*args) -> bytes: - return parameter_encoder.encode(args) + try: + parameter_encoder, return_decoder = DubboSerializationService.create_encoder_decoder_pair( + transport_type=transport_type, + parameter_types=parameter_types, + return_type=return_type, + **codec_options + ) - def deserialize_method_return(data: bytes): - return return_decoder.decode(data) - - return 
serialize_method_parameters, deserialize_method_return + def serialize_method_parameters(*args) -> bytes: + try: + return parameter_encoder.encode(args) + except Exception as e: + logger.error(f"Failed to serialize parameters: {e}") + raise + + def deserialize_method_return(data: bytes): + if not isinstance(data, bytes): + raise TypeError(f"Expected bytes, got {type(data)}") + try: + return return_decoder.decode(data) + except Exception as e: + logger.error(f"Failed to deserialize return value: {e}") + raise + + return serialize_method_parameters, deserialize_method_return + + except Exception as e: + logger.error(f"Failed to create serialization functions: {e}") + raise @staticmethod - def create_method_descriptor(func: Callable, method_name: str = None, + def create_method_descriptor(func: Callable, method_name: Optional[str] = None, parameter_types: List[Type] = None, return_type: Type = None, interface: Callable = None) -> MethodDescriptor: """Create a method descriptor from function and configuration""" - name = method_name or (interface.__name__ if interface else func.__name__) - sig = inspect.signature(interface if interface else func) + if not callable(func): + raise TypeError("func must be callable") + + # Use interface signature if provided, otherwise use func signature + target_function = interface if interface else func + name = method_name or target_function.__name__ + + try: + sig = inspect.signature(target_function) + except ValueError as e: + logger.error(f"Cannot inspect signature of {target_function}: {e}") + raise parameters = [] resolved_parameter_types = parameter_types or [] + param_index = 0 - for i, (param_name, param) in enumerate(sig.parameters.items()): + for param_name, param in sig.parameters.items(): + # Skip 'self' parameter for methods if param_name == 'self': continue - param_index = i - 1 if 'self' in sig.parameters else i - + # Get parameter type from provided types, annotation, or default to Any if param_index < 
len(resolved_parameter_types): param_type = resolved_parameter_types[param_index] elif param.annotation != inspect.Parameter.empty: @@ -138,7 +171,10 @@ def create_method_descriptor(func: Callable, method_name: str = None, is_required=is_required, default_value=default_value )) + + param_index += 1 + # Resolve return type if return_type: resolved_return_type = return_type elif sig.return_annotation != inspect.Signature.empty: @@ -156,5 +192,5 @@ def create_method_descriptor(func: Callable, method_name: str = None, name=name, parameters=parameters, return_parameter=return_parameter, - documentation=func.__doc__ + documentation=inspect.getdoc(target_function) ) \ No newline at end of file From 598703cf8db97ad11254f483ff62d295989f914f Mon Sep 17 00:00:00 2001 From: Aditya Yadav <166515021+aditya0yadav@users.noreply.github.com> Date: Wed, 3 Sep 2025 18:47:54 +0000 Subject: [PATCH 16/40] formating the code and remake the json serializer plugable architecture --- src/dubbo/classes.py | 3 +- src/dubbo/client.py | 50 +- src/dubbo/codec/__init__.py | 2 +- src/dubbo/codec/dubbo_codec.py | 85 ++-- src/dubbo/codec/json_codec/__init__.py | 2 +- .../codec/json_codec/json_codec_handler.py | 459 +++++++++++------- .../codec/json_codec/json_transport_base.py | 72 +++ .../json_codec/json_transport_plugins.py | 211 ++++++++ src/dubbo/codec/json_codec/json_type.py | 274 ----------- src/dubbo/codec/protobuf_codec/__init__.py | 4 +- src/dubbo/extension/registries.py | 4 +- src/dubbo/proxy/handlers.py | 3 +- src/dubbo/server.py | 2 +- tests/json/json_test.py | 92 ++-- tests/json/json_type_test.py | 15 +- tests/protobuf/generated/__init__.py | 2 +- 16 files changed, 687 insertions(+), 593 deletions(-) create mode 100644 src/dubbo/codec/json_codec/json_transport_base.py create mode 100644 src/dubbo/codec/json_codec/json_transport_plugins.py delete mode 100644 src/dubbo/codec/json_codec/json_type.py diff --git a/src/dubbo/classes.py b/src/dubbo/classes.py index d6ba2ab..b3076ea 100644 --- 
a/src/dubbo/classes.py +++ b/src/dubbo/classes.py @@ -16,7 +16,7 @@ import abc import threading -from typing import Any, Callable, Optional, Union,Type +from typing import Any, Callable, Optional, Union, Type from abc import ABC, abstractmethod from dubbo.types import DeserializingFunction, RpcType, RpcTypes, SerializingFunction @@ -259,6 +259,7 @@ def encode(self, data: Any) -> bytes: def decode(self, data: bytes) -> Any: pass + class CodecHelper: @staticmethod def get_class(): diff --git a/src/dubbo/client.py b/src/dubbo/client.py index 2862228..6d1734f 100644 --- a/src/dubbo/client.py +++ b/src/dubbo/client.py @@ -63,17 +63,11 @@ def _initialize(self): return # get the protocol - protocol = extensionLoader.get_extension( - Protocol, self._reference.protocol - )() + protocol = extensionLoader.get_extension(Protocol, self._reference.protocol)() registry_config = self._dubbo.registry_config - self._protocol = ( - RegistryProtocol(registry_config, protocol) - if registry_config - else protocol - ) + self._protocol = RegistryProtocol(registry_config, protocol) if registry_config else protocol # build url reference_url = self._reference.to_url() @@ -109,7 +103,7 @@ def _create_rpc_callable( res_deser = response_deserializer else: req_ser, res_deser = DubboTransportService.create_serialization_functions( - codec or "json", + codec or "json", parameter_types=params_types, return_type=return_type, ) @@ -124,13 +118,7 @@ def _create_rpc_callable( return self._callable(descriptor) - def unary( - self, - method_name: str, - params_types: List[Type], - return_type: Type, - **kwargs - ) -> RpcCallable: + def unary(self, method_name: str, params_types: List[Type], return_type: Type, **kwargs) -> RpcCallable: return self._create_rpc_callable( rpc_type=RpcTypes.UNARY.value, method_name=method_name, @@ -139,13 +127,7 @@ def unary( **kwargs, ) - def client_stream( - self, - method_name: str, - params_types: List[Type], - return_type: Type, - **kwargs - ) -> RpcCallable: + def 
client_stream(self, method_name: str, params_types: List[Type], return_type: Type, **kwargs) -> RpcCallable: return self._create_rpc_callable( rpc_type=RpcTypes.CLIENT_STREAM.value, method_name=method_name, @@ -154,13 +136,7 @@ def client_stream( **kwargs, ) - def server_stream( - self, - method_name: str, - params_types: List[Type], - return_type: Type, - **kwargs - ) -> RpcCallable: + def server_stream(self, method_name: str, params_types: List[Type], return_type: Type, **kwargs) -> RpcCallable: return self._create_rpc_callable( rpc_type=RpcTypes.SERVER_STREAM.value, method_name=method_name, @@ -169,13 +145,7 @@ def server_stream( **kwargs, ) - def bi_stream( - self, - method_name: str, - params_types: List[Type], - return_type: Type, - **kwargs - ) -> RpcCallable: + def bi_stream(self, method_name: str, params_types: List[Type], return_type: Type, **kwargs) -> RpcCallable: return self._create_rpc_callable( rpc_type=RpcTypes.BI_STREAM.value, method_name=method_name, @@ -189,8 +159,6 @@ def _callable(self, method_descriptor: MethodDescriptor) -> RpcCallable: Generate a proxy for the given method. 
""" url = self._invoker.get_url().copy() - url.parameters[common_constants.METHOD_KEY] = ( - method_descriptor.get_method_name() - ) + url.parameters[common_constants.METHOD_KEY] = method_descriptor.get_method_name() url.attributes[common_constants.METHOD_DESCRIPTOR_KEY] = method_descriptor - return self._callable_factory.get_callable(self._invoker, url) \ No newline at end of file + return self._callable_factory.get_callable(self._invoker, url) diff --git a/src/dubbo/codec/__init__.py b/src/dubbo/codec/__init__.py index dfd1b56..72e6f3b 100644 --- a/src/dubbo/codec/__init__.py +++ b/src/dubbo/codec/__init__.py @@ -16,4 +16,4 @@ from .dubbo_codec import DubboTransportService -__all__ = ['DubboTransportService'] \ No newline at end of file +__all__ = ["DubboTransportService"] diff --git a/src/dubbo/codec/dubbo_codec.py b/src/dubbo/codec/dubbo_codec.py index a154de7..3309c32 100644 --- a/src/dubbo/codec/dubbo_codec.py +++ b/src/dubbo/codec/dubbo_codec.py @@ -25,6 +25,7 @@ @dataclass class ParameterDescriptor: """Detailed information about a method parameter""" + name: str annotation: Any is_required: bool = True @@ -34,6 +35,7 @@ class ParameterDescriptor: @dataclass class MethodDescriptor: """Complete method descriptor with all necessary information""" + function: Callable name: str parameters: List[ParameterDescriptor] @@ -45,20 +47,17 @@ class DubboSerializationService: """Dubbo serialization service with robust type handling""" @staticmethod - def create_transport_codec(transport_type: str = 'json', parameter_types: List[Type] = None, - return_type: Type = None, **codec_options): + def create_transport_codec( + transport_type: str = "json", parameter_types: List[Type] = None, return_type: Type = None, **codec_options + ): """Create transport codec with enhanced parameter structure""" - + try: from dubbo.extension.extension_loader import ExtensionLoader from dubbo.classes import CodecHelper - + codec_class = ExtensionLoader().get_extension(CodecHelper.get_class(), 
transport_type) - return codec_class( - parameter_types=parameter_types or [], - return_type=return_type, - **codec_options - ) + return codec_class(parameter_types=parameter_types or [], return_type=return_type, **codec_options) except ImportError as e: logger.error(f"Failed to import required modules: {e}") raise @@ -67,41 +66,37 @@ def create_transport_codec(transport_type: str = 'json', parameter_types: List[T raise @staticmethod - def create_encoder_decoder_pair(transport_type: str, parameter_types: List[Type] = None, - return_type: Type = None, **codec_options) -> Tuple[Any, Any]: + def create_encoder_decoder_pair( + transport_type: str, parameter_types: List[Type] = None, return_type: Type = None, **codec_options + ) -> Tuple[Any, Any]: """Create separate encoder and decoder instances""" try: codec_instance = DubboSerializationService.create_transport_codec( - transport_type=transport_type, - parameter_types=parameter_types, - return_type=return_type, - **codec_options + transport_type=transport_type, parameter_types=parameter_types, return_type=return_type, **codec_options ) - + encoder = codec_instance.get_encoder() decoder = codec_instance.get_decoder() - + if encoder is None or decoder is None: raise ValueError(f"Codec for transport type '{transport_type}' returned None encoder/decoder") - + return encoder, decoder - + except Exception as e: logger.error(f"Failed to create encoder/decoder pair: {e}") raise @staticmethod - def create_serialization_functions(transport_type: str, parameter_types: List[Type] = None, - return_type: Type = None, **codec_options) -> Tuple[Callable, Callable]: + def create_serialization_functions( + transport_type: str, parameter_types: List[Type] = None, return_type: Type = None, **codec_options + ) -> Tuple[Callable, Callable]: """Create serializer and deserializer functions for RPC (backward compatibility)""" try: parameter_encoder, return_decoder = DubboSerializationService.create_encoder_decoder_pair( - 
transport_type=transport_type, - parameter_types=parameter_types, - return_type=return_type, - **codec_options + transport_type=transport_type, parameter_types=parameter_types, return_type=return_type, **codec_options ) def serialize_method_parameters(*args) -> bytes: @@ -119,17 +114,21 @@ def deserialize_method_return(data: bytes): except Exception as e: logger.error(f"Failed to deserialize return value: {e}") raise - + return serialize_method_parameters, deserialize_method_return - + except Exception as e: logger.error(f"Failed to create serialization functions: {e}") raise @staticmethod - def create_method_descriptor(func: Callable, method_name: Optional[str] = None, - parameter_types: List[Type] = None, return_type: Type = None, - interface: Callable = None) -> MethodDescriptor: + def create_method_descriptor( + func: Callable, + method_name: Optional[str] = None, + parameter_types: List[Type] = None, + return_type: Type = None, + interface: Callable = None, + ) -> MethodDescriptor: """Create a method descriptor from function and configuration""" if not callable(func): @@ -138,7 +137,7 @@ def create_method_descriptor(func: Callable, method_name: Optional[str] = None, # Use interface signature if provided, otherwise use func signature target_function = interface if interface else func name = method_name or target_function.__name__ - + try: sig = inspect.signature(target_function) except ValueError as e: @@ -151,7 +150,7 @@ def create_method_descriptor(func: Callable, method_name: Optional[str] = None, for param_name, param in sig.parameters.items(): # Skip 'self' parameter for methods - if param_name == 'self': + if param_name == "self": continue # Get parameter type from provided types, annotation, or default to Any @@ -165,13 +164,12 @@ def create_method_descriptor(func: Callable, method_name: Optional[str] = None, is_required = param.default == inspect.Parameter.empty default_value = param.default if not is_required else None - 
parameters.append(ParameterDescriptor( - name=param_name, - annotation=param_type, - is_required=is_required, - default_value=default_value - )) - + parameters.append( + ParameterDescriptor( + name=param_name, annotation=param_type, is_required=is_required, default_value=default_value + ) + ) + param_index += 1 # Resolve return type @@ -182,15 +180,12 @@ def create_method_descriptor(func: Callable, method_name: Optional[str] = None, else: resolved_return_type = Any - return_parameter = ParameterDescriptor( - name="return_value", - annotation=resolved_return_type - ) + return_parameter = ParameterDescriptor(name="return_value", annotation=resolved_return_type) return MethodDescriptor( function=func, name=name, parameters=parameters, return_parameter=return_parameter, - documentation=inspect.getdoc(target_function) - ) \ No newline at end of file + documentation=inspect.getdoc(target_function), + ) diff --git a/src/dubbo/codec/json_codec/__init__.py b/src/dubbo/codec/json_codec/__init__.py index f66e01d..a11163f 100644 --- a/src/dubbo/codec/json_codec/__init__.py +++ b/src/dubbo/codec/json_codec/__init__.py @@ -14,6 +14,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -from .json_codec_handler import JsonTransportCodec,JsonTransportDecoder,JsonTransportEncoder +from .json_codec_handler import JsonTransportCodec, JsonTransportEncoder, JsonTransportDecoder __all__ = ["JsonTransportCodec", "JsonTransportDecoder", "JsonTransportEncoder"] diff --git a/src/dubbo/codec/json_codec/json_codec_handler.py b/src/dubbo/codec/json_codec/json_codec_handler.py index b7533c0..5e1d97f 100644 --- a/src/dubbo/codec/json_codec/json_codec_handler.py +++ b/src/dubbo/codec/json_codec/json_codec_handler.py @@ -14,302 +14,372 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-from typing import Any, Type, List, Union, Dict, TypeVar, Protocol -from datetime import datetime, date, time -from decimal import Decimal -from pathlib import Path -from uuid import UUID -import json - -from .json_type import ( - TypeProviderFactory, SerializationState, - SerializationException, DeserializationException +from typing import Any, Type, List, Union, Optional +from .json_transport_base import SimpleRegistry, SerializationException, DeserializationException +from .json_transport_plugins import ( + StandardJsonPlugin, + OrJsonPlugin, + UJsonPlugin, + DateTimeHandler, + DecimalHandler, + CollectionHandler, + SimpleTypeHandler, + PydanticHandler, + DataclassHandler, + EnumHandler, ) -try: - import orjson - HAS_ORJSON = True -except ImportError: - HAS_ORJSON = False - -try: - import ujson - HAS_UJSON = True -except ImportError: - HAS_UJSON = False - -try: - from pydantic import BaseModel, create_model - HAS_PYDANTIC = True -except ImportError: - HAS_PYDANTIC = False - - -class EncodingFunction(Protocol): - def __call__(self, obj: Any) -> bytes: ... - - -class DecodingFunction(Protocol): - def __call__(self, data: bytes) -> Any: ... 
- - -ModelT = TypeVar('ModelT', bound=BaseModel) - - -class CustomJSONEncoder(json.JSONEncoder): - def default(self, obj): - if isinstance(obj, datetime): - return { - "__datetime__": obj.isoformat(), - "__timezone__": str(obj.tzinfo) if obj.tzinfo else None - } - elif isinstance(obj, date): - return {"__date__": obj.isoformat()} - elif isinstance(obj, time): - return {"__time__": obj.isoformat()} - elif isinstance(obj, Decimal): - return {"__decimal__": str(obj)} - elif isinstance(obj, (set, frozenset)): - return { - "__frozenset__" if isinstance(obj, frozenset) else "__set__": list(obj) - } - elif isinstance(obj, UUID): - return {"__uuid__": str(obj)} - elif isinstance(obj, Path): - return {"__path__": str(obj)} - else: - return {"__fallback_string__": str(obj), "__original_type__": type(obj).__name__} - class JsonTransportEncoder: - def __init__(self, parameter_types: List[Type] = None, maximum_depth: int = 100, - strict_validation: bool = True, **kwargs): + """JSON Transport Encoder with plugin architecture""" + + def __init__( + self, parameter_types: List[Type] = None, maximum_depth: int = 100, strict_validation: bool = True, **kwargs + ): self.parameter_types = parameter_types or [] self.maximum_depth = maximum_depth self.strict_validation = strict_validation - self.type_registry = TypeProviderFactory.create_default_registry() - self.custom_encoder = CustomJSONEncoder(ensure_ascii=False, separators=(',', ':')) - self.single_parameter_mode = len(self.parameter_types) == 1 - self.multiple_parameter_mode = len(self.parameter_types) > 1 - if self.multiple_parameter_mode and HAS_PYDANTIC: - self.parameter_wrapper_model = self._create_parameter_wrapper_model() - - def _create_parameter_wrapper_model(self) -> Type[BaseModel]: - model_fields = {} - for i, param_type in enumerate(self.parameter_types): - model_fields[f"parameter_{i}"] = (param_type, ...) 
- return create_model('MethodParametersWrapper', **model_fields) - - def register_type_provider(self, provider) -> None: - self.type_registry.register_provider(provider) + self.registry = SimpleRegistry() + self.json_plugins = [] + + # Setup plugins + self._register_default_type_plugins() + self._setup_json_serializer_plugins() + + def _register_default_type_plugins(self): + """Register default type handler plugins""" + default_plugins = [ + DateTimeHandler(), + DecimalHandler(), + CollectionHandler(), + SimpleTypeHandler(), + DataclassHandler(), + EnumHandler(), + ] + + # Add Pydantic plugin if available + pydantic_plugin = PydanticHandler() + if pydantic_plugin.available: + default_plugins.append(pydantic_plugin) + + for plugin in default_plugins: + self.registry.register_plugin(plugin) + + def _setup_json_serializer_plugins(self): + """Setup JSON serializer plugins in priority order""" + # Try orjson first (fastest), then ujson, finally standard json + orjson_plugin = OrJsonPlugin() + if orjson_plugin.available: + self.json_plugins.append(orjson_plugin) + + ujson_plugin = UJsonPlugin() + if ujson_plugin.available: + self.json_plugins.append(ujson_plugin) + + # Always have standard json as fallback + self.json_plugins.append(StandardJsonPlugin()) + + def register_type_provider(self, provider): + """Register custom type provider for backward compatibility""" + self.registry.register_plugin(provider) def encode(self, arguments: tuple) -> bytes: + """Encode arguments with flexible parameter handling""" try: if not arguments: return self._serialize_to_json_bytes([]) - if self.single_parameter_mode: + # Handle single parameter case + if len(self.parameter_types) == 1: parameter = arguments[0] - serialized_param = self._serialize_with_state(parameter) - if HAS_PYDANTIC and isinstance(parameter, BaseModel): - if hasattr(parameter, 'model_dump'): - return self._serialize_to_json_bytes(parameter.model_dump()) - return self._serialize_to_json_bytes(parameter.dict()) - elif 
isinstance(parameter, dict): - return self._serialize_to_json_bytes(serialized_param) - else: - return self._serialize_to_json_bytes(serialized_param) - - elif self.multiple_parameter_mode and HAS_PYDANTIC: - wrapper_data = {f"parameter_{i}": arg for i, arg in enumerate(arguments)} - wrapper_instance = self.parameter_wrapper_model(**wrapper_data) - return self._serialize_to_json_bytes(wrapper_instance.model_dump()) + serialized_param = self._serialize_object(parameter) + return self._serialize_to_json_bytes(serialized_param) + + # Handle multiple parameters + elif len(self.parameter_types) > 1: + # Try Pydantic wrapper for strong typing + pydantic_handler = self._get_pydantic_handler() + if pydantic_handler and pydantic_handler.available: + wrapper_data = {f"param_{i}": arg for i, arg in enumerate(arguments)} + wrapper_model = pydantic_handler.create_parameter_model(self.parameter_types) + if wrapper_model: + try: + wrapper_instance = wrapper_model(**wrapper_data) + return self._serialize_to_json_bytes(pydantic_handler.serialize_to_dict(wrapper_instance)) + except Exception: + pass # Fall back to standard handling + + # Standard multi-parameter handling + serialized_args = [self._serialize_object(arg) for arg in arguments] + return self._serialize_to_json_bytes(serialized_args) else: - serialized_args = [self._serialize_with_state(arg) for arg in arguments] + # No type constraints - serialize as list + serialized_args = [self._serialize_object(arg) for arg in arguments] return self._serialize_to_json_bytes(serialized_args) except Exception as e: raise SerializationException(f"Encoding failed: {e}") from e - def _serialize_with_state(self, obj: Any) -> Any: - state = SerializationState(maximum_depth=self.maximum_depth) - return self._serialize_recursively(obj, state) + def _get_pydantic_handler(self) -> Optional[PydanticHandler]: + """Get Pydantic handler from registered plugins""" + for plugin in self.registry.plugins: + if isinstance(plugin, PydanticHandler): + 
return plugin + return None + + def _serialize_object(self, obj: Any, depth: int = 0) -> Any: + """Serialize single object using registry with depth protection""" + if depth > self.maximum_depth: + raise SerializationException(f"Maximum depth {self.maximum_depth} exceeded") - def _serialize_recursively(self, obj: Any, state: SerializationState) -> Any: if obj is None or isinstance(obj, (bool, int, float, str)): return obj + if isinstance(obj, (list, tuple)): - state.validate_circular_reference(obj) - new_state = state.create_child_state(obj) - return [self._serialize_recursively(item, new_state) for item in obj] + return [self._serialize_object(item, depth + 1) for item in obj] + elif isinstance(obj, dict): - state.validate_circular_reference(obj) - new_state = state.create_child_state(obj) result = {} for key, value in obj.items(): if not isinstance(key, str): if self.strict_validation: raise SerializationException(f"Dictionary key must be string, got {type(key).__name__}") key = str(key) - result[key] = self._serialize_recursively(value, new_state) + result[key] = self._serialize_object(value, depth + 1) return result - provider = self.type_registry.find_provider_for_object(obj) - if provider: + # Use registry to find handler + handler = self.registry.get_handler(obj) + if handler: try: - serialized = provider.serialize_to_dict(obj, state) - return self._serialize_recursively(serialized, state) + return handler(obj) except Exception as e: if self.strict_validation: - raise SerializationException(f"Provider failed for {type(obj).__name__}: {e}") from e - return {"__serialization_error__": str(e), "__original_type__": type(obj).__name__} - else: - if self.strict_validation: - raise SerializationException(f"No provider for type {type(obj).__name__}") - return {"__fallback_string__": str(obj), "__original_type__": type(obj).__name__} + raise SerializationException(f"Handler failed for {type(obj).__name__}: {e}") from e + return {"__serialization_error__": str(e), 
"__type__": type(obj).__name__} + + # Fallback for unknown types + if self.strict_validation: + raise SerializationException(f"No handler for type {type(obj).__name__}") + return {"__fallback__": str(obj), "__type__": type(obj).__name__} def _serialize_to_json_bytes(self, obj: Any) -> bytes: - if HAS_ORJSON: + """Use the first available JSON plugin to serialize""" + last_error = None + for plugin in self.json_plugins: try: - return orjson.dumps(obj, default=self._orjson_default_handler) - except TypeError: - pass - if HAS_UJSON: - try: - return ujson.dumps(obj, ensure_ascii=False, default=self._ujson_default_handler).encode('utf-8') - except (TypeError, ValueError): - pass - return self.custom_encoder.encode(obj).encode('utf-8') - - def _orjson_default_handler(self, obj): - if isinstance(obj, datetime): - return { - "__datetime__": obj.isoformat(), - "__timezone__": str(obj.tzinfo) if obj.tzinfo else None - } - elif isinstance(obj, date): - return {"__date__": obj.isoformat()} - elif isinstance(obj, time): - return {"__time__": obj.isoformat()} - elif isinstance(obj, Decimal): - return {"__decimal__": str(obj)} - elif isinstance(obj, (set, frozenset)): - return { - "__frozenset__" if isinstance(obj, frozenset) else "__set__": list(obj) - } - elif isinstance(obj, UUID): - return {"__uuid__": str(obj)} - elif isinstance(obj, Path): - return {"__path__": str(obj)} - else: - return {"__fallback_string__": str(obj), "__original_type__": type(obj).__name__} + return plugin.encode(obj) + except Exception as e: + last_error = e + continue - def _ujson_default_handler(self, obj): - return self._orjson_default_handler(obj) + raise SerializationException(f"All JSON plugins failed. 
Last error: {last_error}") class JsonTransportDecoder: + """JSON Transport Decoder with plugin architecture""" + def __init__(self, target_type: Union[Type, List[Type]] = None, **kwargs): self.target_type = target_type + self.json_plugins = [] + self._setup_json_deserializer_plugins() + + # Handle multiple parameter types if isinstance(target_type, list): self.multiple_parameter_mode = len(target_type) > 1 self.parameter_types = target_type - if self.multiple_parameter_mode and HAS_PYDANTIC: - self.parameter_wrapper_model = self._create_parameter_wrapper_model() + if self.multiple_parameter_mode: + pydantic_handler = PydanticHandler() + if pydantic_handler.available: + self.parameter_wrapper_model = pydantic_handler.create_parameter_model(target_type) else: self.multiple_parameter_mode = False self.parameter_types = [target_type] if target_type else [] - def _create_parameter_wrapper_model(self) -> Type[BaseModel]: - model_fields = {} - for i, param_type in enumerate(self.parameter_types): - model_fields[f"parameter_{i}"] = (param_type, ...) 
- return create_model('MethodParametersWrapper', **model_fields) + def _setup_json_deserializer_plugins(self): + """Setup JSON deserializer plugins in priority order""" + orjson_plugin = OrJsonPlugin() + if orjson_plugin.available: + self.json_plugins.append(orjson_plugin) + + ujson_plugin = UJsonPlugin() + if ujson_plugin.available: + self.json_plugins.append(ujson_plugin) + + self.json_plugins.append(StandardJsonPlugin()) def decode(self, data: bytes) -> Any: + """Decode JSON bytes back to objects""" try: if not data: return None + json_data = self._deserialize_from_json_bytes(data) reconstructed_data = self._reconstruct_objects(json_data) + if not self.target_type: return reconstructed_data + if isinstance(self.target_type, list): - if self.multiple_parameter_mode and HAS_PYDANTIC: - wrapper_instance = self.parameter_wrapper_model(**reconstructed_data) - return tuple(getattr(wrapper_instance, f"parameter_{i}") for i in range(len(self.parameter_types))) - else: - return self._decode_to_target_type(reconstructed_data, self.parameter_types[0]) + if self.multiple_parameter_mode and hasattr(self, "parameter_wrapper_model"): + try: + wrapper_instance = self.parameter_wrapper_model(**reconstructed_data) + return tuple(getattr(wrapper_instance, f"param_{i}") for i in range(len(self.parameter_types))) + except Exception: + pass + return self._decode_to_target_type(reconstructed_data, self.parameter_types[0]) else: return self._decode_to_target_type(reconstructed_data, self.target_type) + except Exception as e: raise DeserializationException(f"Decoding failed: {e}") from e def _deserialize_from_json_bytes(self, data: bytes) -> Any: - if HAS_ORJSON: - try: - return orjson.loads(data) - except orjson.JSONDecodeError: - pass - if HAS_UJSON: + """Use the first available JSON plugin to deserialize""" + last_error = None + for plugin in self.json_plugins: try: - return ujson.loads(data.decode('utf-8')) - except (ujson.JSONDecodeError, UnicodeDecodeError): - pass - return 
json.loads(data.decode('utf-8')) + return plugin.decode(data) + except Exception as e: + last_error = e + continue + + raise DeserializationException(f"All JSON plugins failed. Last error: {last_error}") def _decode_to_target_type(self, json_data: Any, target_type: Type) -> Any: + """Convert JSON data to target type with proper Pydantic handling""" + # Check if target type is a Pydantic model + try: + from pydantic import BaseModel + + if isinstance(target_type, type) and issubclass(target_type, BaseModel): + # If json_data is already a Pydantic model instance, return as-is + if isinstance(json_data, target_type): + return json_data + # If json_data is a dict, try to construct the Pydantic model + elif isinstance(json_data, dict): + return target_type(**json_data) + except ImportError: + pass + if target_type in (str, int, float, bool, list, dict): return target_type(json_data) + return json_data def _reconstruct_objects(self, data: Any) -> Any: + """Reconstruct special objects from their serialized form""" if not isinstance(data, dict): if isinstance(data, list): return [self._reconstruct_objects(item) for item in data] return data + + # Handle special serialized objects if "__datetime__" in data: + from datetime import datetime + return datetime.fromisoformat(data["__datetime__"]) elif "__date__" in data: + from datetime import date + return date.fromisoformat(data["__date__"]) elif "__time__" in data: + from datetime import time + return time.fromisoformat(data["__time__"]) elif "__decimal__" in data: + from decimal import Decimal + return Decimal(data["__decimal__"]) elif "__set__" in data: return set(self._reconstruct_objects(item) for item in data["__set__"]) elif "__frozenset__" in data: return frozenset(self._reconstruct_objects(item) for item in data["__frozenset__"]) elif "__uuid__" in data: + from uuid import UUID + return UUID(data["__uuid__"]) elif "__path__" in data: + from pathlib import Path + return Path(data["__path__"]) - elif "__dataclass__" in 
data or "__pydantic_model__" in data: - return data + elif "__pydantic_model__" in data and "__model_data__" in data: + # Properly reconstruct Pydantic models + return self._reconstruct_pydantic_model(data) + elif "__dataclass__" in data: + module_name, class_name = data["__dataclass__"].rsplit(".", 1) + import importlib + + module = importlib.import_module(module_name) + cls = getattr(module, class_name) + fields = self._reconstruct_objects(data["fields"]) + return cls(**fields) + + elif "__enum__" in data: + module_name, class_name = data["__enum__"].rsplit(".", 1) + import importlib + + module = importlib.import_module(module_name) + cls = getattr(module, class_name) + return cls(data["value"]) else: return {key: self._reconstruct_objects(value) for key, value in data.items()} + def _reconstruct_pydantic_model(self, data: dict) -> Any: + """Reconstruct a Pydantic model from serialized data""" + try: + model_path = data["__pydantic_model__"] + model_data = data["__model_data__"] + + # Try to import and reconstruct the model + module_name, class_name = model_path.rsplit(".", 1) + + # Import the module + import importlib + + module = importlib.import_module(module_name) + model_class = getattr(module, class_name) + + # Recursively reconstruct nested objects in model_data + reconstructed_data = self._reconstruct_objects(model_data) + + # Create the model instance + return model_class(**reconstructed_data) + + except Exception as e: + # If reconstruction fails, return the model data as dict + # This allows the target type conversion to handle it + return self._reconstruct_objects(data.get("__model_data__", {})) + class JsonTransportCodec: - def __init__(self, parameter_types: List[Type] = None, return_type: Type = None, - maximum_depth: int = 100, strict_validation: bool = True, **kwargs): + """Combined JSON transport codec - maintains backward compatibility""" + + def __init__( + self, + parameter_types: List[Type] = None, + return_type: Type = None, + 
maximum_depth: int = 100, + strict_validation: bool = True, + **kwargs, + ): self.parameter_types = parameter_types or [] self.return_type = return_type self.maximum_depth = maximum_depth self.strict_validation = strict_validation + self._encoder = JsonTransportEncoder( - parameter_types=parameter_types, - maximum_depth=maximum_depth, - strict_validation=strict_validation, - **kwargs + parameter_types=parameter_types, maximum_depth=maximum_depth, strict_validation=strict_validation, **kwargs ) self._decoder = JsonTransportDecoder(target_type=return_type, **kwargs) def encode_parameters(self, *arguments) -> bytes: + """Encode parameters - supports both positional and keyword args""" return self._encoder.encode(arguments) def decode_return_value(self, data: bytes) -> Any: + """Decode return value""" return self._decoder.decode(data) def get_encoder(self) -> JsonTransportEncoder: @@ -319,4 +389,29 @@ def get_decoder(self) -> JsonTransportDecoder: return self._decoder def register_type_provider(self, provider) -> None: + """Register custom type provider""" self._encoder.register_type_provider(provider) + + +class JsonTransportService: + """Service for extension loader integration""" + + def __init__(self, parameter_types: List[Type] = None, return_type: Type = None, **kwargs): + self.codec = JsonTransportCodec(parameter_types=parameter_types, return_type=return_type, **kwargs) + + def encode_parameters(self, *args, **kwargs) -> bytes: + """Enhanced to handle both positional and keyword arguments""" + if kwargs: + # Convert keyword args to positional for now + # Could be enhanced to preserve keyword information + combined_args = args + tuple(kwargs.values()) + return self.codec.encode_parameters(*combined_args) + return self.codec.encode_parameters(*args) + + def decode_return_value(self, data: bytes) -> Any: + """Decode return value""" + return self.codec.decode_return_value(data) + + def register_custom_plugin(self, plugin): + """Allow extension loader to register 
custom plugins""" + self.codec.register_type_provider(plugin) diff --git a/src/dubbo/codec/json_codec/json_transport_base.py b/src/dubbo/codec/json_codec/json_transport_base.py new file mode 100644 index 0000000..d1ead52 --- /dev/null +++ b/src/dubbo/codec/json_codec/json_transport_base.py @@ -0,0 +1,72 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Any, List, Dict, Protocol, Callable + + +class JsonSerializerPlugin(Protocol): + """Protocol for JSON serialization plugins""" + + def encode(self, obj: Any) -> bytes: ... + def decode(self, data: bytes) -> Any: ... + def can_handle(self, obj: Any) -> bool: ... + + +class TypeHandlerPlugin(Protocol): + """Protocol for type-specific serialization""" + + def can_serialize_type(self, obj: Any, obj_type: type) -> bool: ... + def serialize_to_dict(self, obj: Any) -> Any: ... 
+ + +class SimpleRegistry: + """Simplified registry using dict instead of complex TypeProviderRegistry""" + + def __init__(self): + # Simple dict mapping: type -> handler function + self.type_handlers: Dict[type, Callable] = {} + self.plugins: List[TypeHandlerPlugin] = [] + + def register_type_handler(self, obj_type: type, handler: Callable): + """Register a simple type handler function""" + self.type_handlers[obj_type] = handler + + def register_plugin(self, plugin: TypeHandlerPlugin): + """Register a plugin""" + self.plugins.append(plugin) + + def get_handler(self, obj: Any) -> Callable: + """Get handler for object - check dict first, then plugins""" + obj_type = type(obj) + if obj_type in self.type_handlers: + return self.type_handlers[obj_type] + + for plugin in self.plugins: + if plugin.can_serialize_type(obj, obj_type): + return plugin.serialize_to_dict + return None + + +class SerializationException(Exception): + """Exception raised during serialization""" + + pass + + +class DeserializationException(Exception): + """Exception raised during deserialization""" + + pass diff --git a/src/dubbo/codec/json_codec/json_transport_plugins.py b/src/dubbo/codec/json_codec/json_transport_plugins.py new file mode 100644 index 0000000..4c6d936 --- /dev/null +++ b/src/dubbo/codec/json_codec/json_transport_plugins.py @@ -0,0 +1,211 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json +from typing import Any, Type, List, Union, Dict +from datetime import datetime, date, time +from decimal import Decimal +from pathlib import Path +from uuid import UUID +from enum import Enum +from dataclasses import is_dataclass, asdict +from typing import Any, Dict, List + + +class StandardJsonPlugin: + """Standard library JSON plugin""" + + def encode(self, obj: Any) -> bytes: + return json.dumps(obj, ensure_ascii=False, separators=(",", ":")).encode("utf-8") + + def decode(self, data: bytes) -> Any: + return json.loads(data.decode("utf-8")) + + def can_handle(self, obj: Any) -> bool: + return True + + +class OrJsonPlugin: + """orjson plugin - separate from standard json""" + + def __init__(self): + try: + import orjson + + self.orjson = orjson + self.available = True + except ImportError: + self.available = False + + def encode(self, obj: Any) -> bytes: + if not self.available: + raise ImportError("orjson not available") + return self.orjson.dumps(obj, default=self._default_handler) + + def decode(self, data: bytes) -> Any: + if not self.available: + raise ImportError("orjson not available") + return self.orjson.loads(data) + + def can_handle(self, obj: Any) -> bool: + return self.available + + def _default_handler(self, obj): + """Handle types that orjson doesn't support natively""" + if isinstance(obj, datetime): + return {"__datetime__": obj.isoformat(), "__timezone__": str(obj.tzinfo) if obj.tzinfo else None} + elif isinstance(obj, (date, time)): + return {"__date__": obj.isoformat()} if isinstance(obj, date) else {"__time__": 
obj.isoformat()} + elif isinstance(obj, Decimal): + return {"__decimal__": str(obj)} + elif isinstance(obj, (set, frozenset)): + return {"__frozenset__" if isinstance(obj, frozenset) else "__set__": list(obj)} + elif isinstance(obj, UUID): + return {"__uuid__": str(obj)} + elif isinstance(obj, Path): + return {"__path__": str(obj)} + return {"__fallback__": str(obj), "__type__": type(obj).__name__} + + +class UJsonPlugin: + """ujson plugin - separate from others""" + + def __init__(self): + try: + import ujson + + self.ujson = ujson + self.available = True + except ImportError: + self.available = False + + def encode(self, obj: Any) -> bytes: + if not self.available: + raise ImportError("ujson not available") + return self.ujson.dumps(obj, ensure_ascii=False, default=self._default_handler).encode("utf-8") + + def decode(self, data: bytes) -> Any: + if not self.available: + raise ImportError("ujson not available") + return self.ujson.loads(data.decode("utf-8")) + + def can_handle(self, obj: Any) -> bool: + return self.available + + def _default_handler(self, obj): + """Same as orjson handler""" + return OrJsonPlugin()._default_handler(obj) + + +# Type Handler Plugins (properly inherit protocol) +class DateTimeHandler: + """DateTime handler - implements TypeHandlerPlugin protocol""" + + def can_serialize_type(self, obj: Any, obj_type: type) -> bool: + return obj_type in (datetime, date, time) + + def serialize_to_dict(self, obj: Union[datetime, date, time]) -> Dict[str, str]: + if isinstance(obj, datetime): + return {"__datetime__": obj.isoformat(), "__timezone__": str(obj.tzinfo) if obj.tzinfo else None} + elif isinstance(obj, date): + return {"__date__": obj.isoformat()} + else: + return {"__time__": obj.isoformat()} + + +class DecimalHandler: + def can_serialize_type(self, obj: Any, obj_type: type) -> bool: + return obj_type is Decimal + + def serialize_to_dict(self, obj: Decimal) -> Dict[str, str]: + return {"__decimal__": str(obj)} + + +class CollectionHandler: 
+ def can_serialize_type(self, obj: Any, obj_type: type) -> bool: + return obj_type in (set, frozenset) + + def serialize_to_dict(self, obj: Union[set, frozenset]) -> Dict[str, List]: + return {"__frozenset__" if isinstance(obj, frozenset) else "__set__": list(obj)} + + +class EnumHandler: + """Handles serialization of Enum types""" + + def can_serialize_type(self, obj: Any, obj_type: type) -> bool: + return isinstance(obj, Enum) + + def serialize_to_dict(self, obj: Enum) -> Dict[str, Any]: + return {"__enum__": f"{obj.__class__.__module__}.{obj.__class__.__qualname__}", "value": obj.value} + + +class DataclassHandler: + """Handles serialization of dataclass types""" + + def can_serialize_type(self, obj: Any, obj_type: type) -> bool: + return is_dataclass(obj) + + def serialize_to_dict(self, obj: Any) -> Dict[str, Any]: + return {"__dataclass__": f"{obj.__class__.__module__}.{obj.__class__.__qualname__}", "fields": asdict(obj)} + + +class SimpleTypeHandler: + def can_serialize_type(self, obj: Any, obj_type: type) -> bool: + return obj_type in (UUID, Path) or isinstance(obj, Path) + + def serialize_to_dict(self, obj: Union[UUID, Path]) -> Dict[str, str]: + if isinstance(obj, UUID): + return {"__uuid__": str(obj)} + elif isinstance(obj, Path): + return {"__path__": str(obj)} + + +class PydanticHandler: + """Separate Pydantic plugin with enhanced features""" + + def __init__(self): + try: + from pydantic import BaseModel, create_model + + self.BaseModel = BaseModel + self.create_model = create_model + self.available = True + except ImportError: + self.available = False + + def can_serialize_type(self, obj: Any, obj_type: type) -> bool: + return self.available and isinstance(obj, self.BaseModel) + + def serialize_to_dict(self, obj) -> Dict[str, Any]: + if hasattr(obj, "model_dump"): + return { + "__pydantic_model__": f"{obj.__class__.__module__}.{obj.__class__.__qualname__}", + "__model_data__": obj.model_dump(), + } + return { + "__pydantic_model__": 
f"{obj.__class__.__module__}.{obj.__class__.__qualname__}", + "__model_data__": obj.dict(), + } + + def create_parameter_model(self, parameter_types: List[Type]): + """Enhanced parameter handling for both positional and keyword args""" + if not self.available: + return None + + model_fields = {} + for i, param_type in enumerate(parameter_types): + model_fields[f"param_{i}"] = (param_type, ...) + return self.create_model("ParametersModel", **model_fields) diff --git a/src/dubbo/codec/json_codec/json_type.py b/src/dubbo/codec/json_codec/json_type.py deleted file mode 100644 index 3a6d840..0000000 --- a/src/dubbo/codec/json_codec/json_type.py +++ /dev/null @@ -1,274 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one or more -# contributor license agreements. See the NOTICE file distributed with -# this work for additional information regarding copyright ownership. -# The ASF licenses this file to You under the Apache License, Version 2.0 -# (the "License"); you may not use this file except in compliance with -# the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -from abc import ABC, abstractmethod -from typing import ( - Any, - Type, - Optional, - List, - Dict, - Set, - Protocol, - runtime_checkable, - Union, -) -from dataclasses import dataclass, fields, is_dataclass, asdict -from datetime import datetime, date, time -from decimal import Decimal -from collections import namedtuple -from pathlib import Path -from uuid import UUID -from enum import Enum -import weakref - -try: - from pydantic import BaseModel - - HAS_PYDANTIC = True -except ImportError: - HAS_PYDANTIC = False - - -class SerializationException(Exception): - """Exception raised during serialization""" - pass - -class DeserializationException(Exception): - """Exception raised during deserialization""" - pass - -class CircularReferenceException(SerializationException): - """Exception raised when circular references are detected""" - pass - -@dataclass(frozen=True) -class SerializationState: - _visited_objects: Set[int] = None - maximum_depth: int = 100 - current_depth: int = 0 - - def __post_init__(self): - if self._visited_objects is None: - object.__setattr__(self, "_visited_objects", set()) - - def validate_circular_reference(self, obj: Any) -> None: - object_id = id(obj) - if object_id in self._visited_objects: - raise CircularReferenceException( - f"Circular reference detected for {type(obj).__name__}" - ) - if self.current_depth >= self.maximum_depth: - raise SerializationException( - f"Maximum serialization depth ({self.maximum_depth}) exceeded" - ) - - def create_child_state(self, obj: Any) -> "SerializationState": - new_visited = self._visited_objects.copy() - new_visited.add(id(obj)) - return SerializationState( - _visited_objects=new_visited, - maximum_depth=self.maximum_depth, - current_depth=self.current_depth + 1, - ) - - -@runtime_checkable -class TypeSerializationProvider(Protocol): - def can_serialize_type(self, obj: Any, obj_type: type) -> bool: ... - - def serialize_to_dict(self, obj: Any, state: SerializationState) -> Any: ... 
- - -class TypeProviderRegistry: - def __init__(self): - self._type_cache: Dict[type, Optional[TypeSerializationProvider]] = {} - self._providers: List[TypeSerializationProvider] = [] - self._weak_cache = weakref.WeakKeyDictionary() - - def register_provider(self, provider: TypeSerializationProvider) -> None: - self._providers.append(provider) - self._type_cache.clear() - self._weak_cache.clear() - - def find_provider_for_object(self, obj: Any) -> Optional[TypeSerializationProvider]: - obj_type = type(obj) - if obj_type in self._type_cache: - return self._type_cache[obj_type] - provider = None - for p in self._providers: - if p.can_serialize_type(obj, obj_type): - provider = p - break - self._type_cache[obj_type] = provider - return provider - - -class DateTimeSerializationProvider: - def can_serialize_type(self, obj: Any, obj_type: type) -> bool: - return obj_type in (datetime, date, time) - - def serialize_to_dict( - self, obj: Union[datetime, date, time], state: SerializationState - ) -> Dict[str, str]: - if isinstance(obj, datetime): - return { - "__datetime__": obj.isoformat(), - "__timezone__": str(obj.tzinfo) if obj.tzinfo else None, - } - elif isinstance(obj, date): - return {"__date__": obj.isoformat()} - else: - return {"__time__": obj.isoformat()} - - -class DecimalSerializationProvider: - def can_serialize_type(self, obj: Any, obj_type: type) -> bool: - return obj_type is Decimal - - def serialize_to_dict( - self, obj: Decimal, state: SerializationState - ) -> Dict[str, str]: - return {"__decimal__": str(obj)} - - -class CollectionSerializationProvider: - def can_serialize_type(self, obj: Any, obj_type: type) -> bool: - return obj_type in (set, frozenset) - - def serialize_to_dict( - self, obj: Union[set, frozenset], state: SerializationState - ) -> Dict[str, Any]: - safe_items = [] - for item in obj: - if isinstance(item, (str, int, float, bool, type(None))): - safe_items.append(item) - else: - raise SerializationException( - f"Cannot serialize 
{type(item).__name__} in collection. " - f"Collections can only contain JSON-safe types (str, int, float, bool, None)" - ) - return { - "__frozenset__" if isinstance(obj, frozenset) else "__set__": safe_items - } - - -class DataclassSerializationProvider: - def can_serialize_type(self, obj: Any, obj_type: type) -> bool: - return is_dataclass(obj) and not isinstance(obj, type) - - def serialize_to_dict(self, obj: Any, state: SerializationState) -> Dict[str, Any]: - state.validate_circular_reference(obj) - try: - field_data = asdict(obj) - return { - "__dataclass__": f"{obj.__class__.__module__}.{obj.__class__.__qualname__}", - "__field_data__": field_data, - } - except (TypeError, RecursionError): - field_data = {} - for field in fields(obj): - try: - field_data[field.name] = getattr(obj, field.name) - except Exception as e: - raise SerializationException( - f"Cannot serialize field '{field.name}': {e}" - ) - return { - "__dataclass__": f"{obj.__class__.__module__}.{obj.__class__.__qualname__}", - "__field_data__": field_data, - } - - -class NamedTupleSerializationProvider: - def can_serialize_type(self, obj: Any, obj_type: type) -> bool: - return ( - hasattr(obj_type, "_fields") - and hasattr(obj, "_asdict") - and callable(obj._asdict) - ) - - def serialize_to_dict(self, obj: Any, state: SerializationState) -> Dict[str, Any]: - return { - "__namedtuple__": f"{obj.__class__.__module__}.{obj.__class__.__qualname__}", - "__tuple_data__": obj._asdict(), - } - - -class PydanticModelSerializationProvider: - def can_serialize_type(self, obj: Any, obj_type: type) -> bool: - return HAS_PYDANTIC and isinstance(obj, BaseModel) - - def serialize_to_dict( - self, obj: BaseModel, state: SerializationState - ) -> Dict[str, Any]: - state.validate_circular_reference(obj) - if hasattr(obj, "model_dump"): - model_data = obj.model_dump() - else: - model_data = obj.dict() - return { - "__pydantic_model__": f"{obj.__class__.__module__}.{obj.__class__.__qualname__}", - "__model_data__": 
model_data, - } - - -class SimpleTypeSerializationProvider: - def can_serialize_type(self, obj: Any, obj_type: type) -> bool: - return obj_type is UUID or isinstance(obj, (Path, Enum)) - - def serialize_to_dict( - self, obj: Union[UUID, Path, Enum], state: SerializationState - ) -> Dict[str, str]: - if isinstance(obj, UUID): - return {"__uuid__": str(obj)} - elif isinstance(obj, Path): - return {"__path__": str(obj)} - else: - return { - "__enum__": f"{obj.__class__.__module__}.{obj.__class__.__qualname__}", - "__enum_value__": obj.value, - } - - -class TypeProviderFactory: - @staticmethod - def create_default_registry() -> TypeProviderRegistry: - registry = TypeProviderRegistry() - default_providers = [ - DateTimeSerializationProvider(), - DecimalSerializationProvider(), - CollectionSerializationProvider(), - DataclassSerializationProvider(), - NamedTupleSerializationProvider(), - PydanticModelSerializationProvider(), - SimpleTypeSerializationProvider(), - ] - for provider in default_providers: - registry.register_provider(provider) - return registry - - @staticmethod - def create_minimal_registry() -> TypeProviderRegistry: - registry = TypeProviderRegistry() - essential_providers = [ - DateTimeSerializationProvider(), - DecimalSerializationProvider(), - SimpleTypeSerializationProvider(), - ] - for provider in essential_providers: - registry.register_provider(provider) - return registry diff --git a/src/dubbo/codec/protobuf_codec/__init__.py b/src/dubbo/codec/protobuf_codec/__init__.py index 4ddd635..194815c 100644 --- a/src/dubbo/codec/protobuf_codec/__init__.py +++ b/src/dubbo/codec/protobuf_codec/__init__.py @@ -16,6 +16,4 @@ from .protobuf_codec_handler import ProtobufTransportCodec, ProtobufTransportEncoder, ProtobufTransportDecoder -__all__ = [ - "ProtobufTransportCodec", "ProtobufTransportEncoder", "ProtobufTransportDecoder" -] \ No newline at end of file +__all__ = ["ProtobufTransportCodec", "ProtobufTransportEncoder", "ProtobufTransportDecoder"] diff 
--git a/src/dubbo/extension/registries.py b/src/dubbo/extension/registries.py index d188d9c..ee52650 100644 --- a/src/dubbo/extension/registries.py +++ b/src/dubbo/extension/registries.py @@ -127,7 +127,7 @@ class ExtendedRegistry: "compressorRegistry", "decompressorRegistry", "transporterRegistry", - "codecRegistry", + "codecRegistry", ] # RegistryFactory registry @@ -190,4 +190,4 @@ class ExtendedRegistry: "json": "dubbo.codec.json_codec.JsonTransportCodec", "protobuf": "dubbo.codec.protobuf_codec.ProtobufTransportCodec", }, -) \ No newline at end of file +) diff --git a/src/dubbo/proxy/handlers.py b/src/dubbo/proxy/handlers.py index aa5004f..7c72516 100644 --- a/src/dubbo/proxy/handlers.py +++ b/src/dubbo/proxy/handlers.py @@ -31,6 +31,7 @@ class RpcMethodConfigurationError(Exception): """ Raised when RPC method is configured incorrectly. """ + pass @@ -142,7 +143,7 @@ def _create_method_descriptor( method_name=method_name or method.__name__, arg_serialization=(None, request_deserializer), return_serialization=(response_serializer, None), - rpc_type=rpc_type + rpc_type=rpc_type, ) @classmethod diff --git a/src/dubbo/server.py b/src/dubbo/server.py index 52a3507..b7c4dee 100644 --- a/src/dubbo/server.py +++ b/src/dubbo/server.py @@ -81,4 +81,4 @@ def start(self): self._protocol.export(self._url) - self._exported = True \ No newline at end of file + self._exported = True diff --git a/tests/json/json_test.py b/tests/json/json_test.py index 0bbbc04..9fe588f 100644 --- a/tests/json/json_test.py +++ b/tests/json/json_test.py @@ -15,54 +15,76 @@ # limitations under the License. 
import pytest -from datetime import datetime +from pathlib import Path +from uuid import UUID from decimal import Decimal -from uuid import uuid4 +from datetime import datetime, date, time +from dataclasses import dataclass +from enum import Enum +from pydantic import BaseModel +from dubbo.codec.json_codec import JsonTransportCodec -from dubbo.codec.json_codec.json_codec_handler import JsonTransportCodec -def test_json_single_parameter_roundtrip(): - codec = JsonTransportCodec(parameter_types=[int], return_type=int) +# Optional dataclass and enum examples +@dataclass +class SampleDataClass: + field1: int + field2: str - # Encode a single int - encoded = codec.encode_parameters(42) - assert isinstance(encoded, bytes) - # Decode back - decoded = codec.decode_return_value(encoded) - assert decoded == 42 +class Color(Enum): + RED = "red" + GREEN = "green" -def test_json_multiple_parameters_roundtrip(): - codec = JsonTransportCodec(parameter_types=[str, int], return_type=str) +class SamplePydanticModel(BaseModel): + name: str + value: int - # Encode multiple args - encoded = codec.encode_parameters("hello", 123) - assert isinstance(encoded, bytes) - # Decode return (simulate server returning str) - return_encoded = codec.get_encoder().encode(("world",)) - decoded = codec.decode_return_value(return_encoded) - assert decoded == "world" +# List of test cases: (input_value, expected_type_after_decoding) +test_cases = [ + ("simple string", str), + (12345, int), + (12.34, float), + (True, bool), + (datetime(2025, 8, 27, 13, 0, 0), datetime), + (date(2025, 8, 27), date), + (time(13, 0, 0), time), + (Decimal("123.45"), Decimal), + (set([1, 2, 3]), set), + (frozenset(["a", "b"]), frozenset), + (UUID("12345678-1234-5678-1234-567812345678"), UUID), + (Path("/tmp/file.txt"), Path), + (Color.RED, Color), + (SamplePydanticModel(name="test", value=42), SamplePydanticModel), +] -def test_json_complex_types(): - codec = JsonTransportCodec(parameter_types=[dict], return_type=dict) 
+@pytest.mark.parametrize("value,expected_type", test_cases) +def test_json_codec_roundtrip(value, expected_type): + codec = JsonTransportCodec(parameter_types=[type(value)], return_type=type(value)) - obj = { - "name": "Alice", - "when": datetime(2025, 8, 27, 12, 30), - "price": Decimal("19.99"), - "ids": {uuid4(), uuid4()} - } - - encoded = codec.encode_parameters(obj) + # Encode + encoded = codec.encode_parameters(value) assert isinstance(encoded, bytes) + # Decode decoded = codec.decode_return_value(encoded) - assert isinstance(decoded, dict) - assert decoded["name"] == "Alice" - assert isinstance(decoded["price"], Decimal) - assert isinstance(decoded["when"], datetime) - assert isinstance(decoded["ids"], set) + # For pydantic models, compare dict representation + if hasattr(value, "dict") and callable(value.dict): + assert decoded.dict() == value.dict() + # For dataclass, compare asdict + elif hasattr(value, "__dataclass_fields__"): + from dataclasses import asdict + + assert asdict(decoded) == asdict(value) + # For sets/frozensets, compare as sets + elif isinstance(value, (set, frozenset)): + assert decoded == value + # For enum + elif isinstance(value, Enum): + assert decoded.value == value.value + else: + assert decoded == value diff --git a/tests/json/json_type_test.py b/tests/json/json_type_test.py index bb7300a..3f146d0 100644 --- a/tests/json/json_type_test.py +++ b/tests/json/json_type_test.py @@ -25,20 +25,24 @@ from dubbo.codec.json_codec.json_codec_handler import JsonTransportCodec + # Optional dataclass and enum examples @dataclass class SampleDataClass: field1: int field2: str + class Color(Enum): RED = "red" GREEN = "green" + class SamplePydanticModel(BaseModel): name: str value: int + # List of test cases: (input_value, expected_type_after_decoding) test_cases = [ ("simple string", str), @@ -55,26 +59,28 @@ class SamplePydanticModel(BaseModel): (Path("/tmp/file.txt"), Path), (SampleDataClass(1, "abc"), SampleDataClass), (Color.RED, Color), - 
(SamplePydanticModel(name="test", value=42), SamplePydanticModel) + (SamplePydanticModel(name="test", value=42), SamplePydanticModel), ] + @pytest.mark.parametrize("value,expected_type", test_cases) def test_json_codec_roundtrip(value, expected_type): codec = JsonTransportCodec(parameter_types=[type(value)], return_type=type(value)) - + # Encode encoded = codec.encode_parameters(value) assert isinstance(encoded, bytes) - + # Decode decoded = codec.decode_return_value(encoded) - + # For pydantic models, compare dict representation if hasattr(value, "dict") and callable(value.dict): assert decoded.dict() == value.dict() # For dataclass, compare asdict elif hasattr(value, "__dataclass_fields__"): from dataclasses import asdict + assert asdict(decoded) == asdict(value) # For sets/frozensets, compare as sets elif isinstance(value, (set, frozenset)): @@ -84,4 +90,3 @@ def test_json_codec_roundtrip(value, expected_type): assert decoded.value == value.value else: assert decoded == value - diff --git a/tests/protobuf/generated/__init__.py b/tests/protobuf/generated/__init__.py index 4f1421a..bcba37a 100644 --- a/tests/protobuf/generated/__init__.py +++ b/tests/protobuf/generated/__init__.py @@ -12,4 +12,4 @@ # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and -# limitations under the License. \ No newline at end of file +# limitations under the License. 
From 65de7d687a198841e33f1155cce6c7a2bae87766 Mon Sep 17 00:00:00 2001 From: Aditya Yadav <166515021+aditya0yadav@users.noreply.github.com> Date: Thu, 4 Sep 2025 04:20:08 +0000 Subject: [PATCH 17/40] fixed all error issue by the reviewer --- .../codec/json_codec/json_codec_handler.py | 174 ++++--- ...ort_plugins.py => json_transport_codec.py} | 41 +- src/dubbo/codec/protobuf_codec/__init__.py | 8 +- .../protobuf_codec/protobuf_codec_handler.py | 458 +++++++++++------- src/dubbo/protocol/triple/protocol.py | 4 +- tests/protobuf/protobuf_test.py | 24 +- 6 files changed, 452 insertions(+), 257 deletions(-) rename src/dubbo/codec/json_codec/{json_transport_plugins.py => json_transport_codec.py} (83%) diff --git a/src/dubbo/codec/json_codec/json_codec_handler.py b/src/dubbo/codec/json_codec/json_codec_handler.py index 5e1d97f..7cc0d3e 100644 --- a/src/dubbo/codec/json_codec/json_codec_handler.py +++ b/src/dubbo/codec/json_codec/json_codec_handler.py @@ -16,7 +16,7 @@ from typing import Any, Type, List, Union, Optional from .json_transport_base import SimpleRegistry, SerializationException, DeserializationException -from .json_transport_plugins import ( +from .json_transport_codec import ( StandardJsonPlugin, OrJsonPlugin, UJsonPlugin, @@ -31,7 +31,7 @@ class JsonTransportEncoder: - """JSON Transport Encoder with plugin architecture""" + """JSON Transport Encoder""" def __init__( self, parameter_types: List[Type] = None, maximum_depth: int = 100, strict_validation: bool = True, **kwargs @@ -83,20 +83,20 @@ def register_type_provider(self, provider): """Register custom type provider for backward compatibility""" self.registry.register_plugin(provider) - def encode(self, arguments: tuple) -> bytes: + def encode(self, arguments: tuple, parameter_type: list = None) -> bytes: """Encode arguments with flexible parameter handling""" try: if not arguments: return self._serialize_to_json_bytes([]) # Handle single parameter case - if len(self.parameter_types) == 1: + if 
len(parameter_type) == 1: parameter = arguments[0] serialized_param = self._serialize_object(parameter) return self._serialize_to_json_bytes(serialized_param) # Handle multiple parameters - elif len(self.parameter_types) > 1: + elif len(parameter_type) > 1: # Try Pydantic wrapper for strong typing pydantic_handler = self._get_pydantic_handler() if pydantic_handler and pydantic_handler.available: @@ -105,7 +105,8 @@ def encode(self, arguments: tuple) -> bytes: if wrapper_model: try: wrapper_instance = wrapper_model(**wrapper_data) - return self._serialize_to_json_bytes(pydantic_handler.serialize_to_dict(wrapper_instance)) + serialized_wrapper = self._serialize_object(wrapper_instance) + return self._serialize_to_json_bytes(serialized_wrapper) except Exception: pass # Fall back to standard handling @@ -114,9 +115,14 @@ def encode(self, arguments: tuple) -> bytes: return self._serialize_to_json_bytes(serialized_args) else: - # No type constraints - serialize as list - serialized_args = [self._serialize_object(arg) for arg in arguments] - return self._serialize_to_json_bytes(serialized_args) + # No type constraints - serialize as single object if only one argument + if len(arguments) == 1: + serialized_obj = self._serialize_object(arguments[0]) + return self._serialize_to_json_bytes(serialized_obj) + else: + # Multiple arguments - serialize as list + serialized_args = [self._serialize_object(arg) for arg in arguments] + return self._serialize_to_json_bytes(serialized_args) except Exception as e: raise SerializationException(f"Encoding failed: {e}") from e @@ -153,7 +159,9 @@ def _serialize_object(self, obj: Any, depth: int = 0) -> Any: handler = self.registry.get_handler(obj) if handler: try: - return handler(obj) + serialized = handler(obj) + # Recursively serialize the result from the handler + return self._serialize_object(serialized, depth + 1) except Exception as e: if self.strict_validation: raise SerializationException(f"Handler failed for {type(obj).__name__}: 
{e}") from e @@ -167,9 +175,11 @@ def _serialize_object(self, obj: Any, depth: int = 0) -> Any: def _serialize_to_json_bytes(self, obj: Any) -> bytes: """Use the first available JSON plugin to serialize""" last_error = None - for plugin in self.json_plugins: + + for i, plugin in enumerate(self.json_plugins): try: - return plugin.encode(obj) + result = plugin.encode(obj) + return result except Exception as e: last_error = e continue @@ -178,7 +188,7 @@ def _serialize_to_json_bytes(self, obj: Any) -> bytes: class JsonTransportDecoder: - """JSON Transport Decoder with plugin architecture""" + """JSON Transport Decoder""" def __init__(self, target_type: Union[Type, List[Type]] = None, **kwargs): self.target_type = target_type @@ -218,6 +228,35 @@ def decode(self, data: bytes) -> Any: json_data = self._deserialize_from_json_bytes(data) reconstructed_data = self._reconstruct_objects(json_data) + # CRITICAL FIX: If reconstructed_data is a list with a single item + # and we expect a single target type, extract it + if ( + isinstance(reconstructed_data, list) + and len(reconstructed_data) == 1 + and self.target_type + and not isinstance(self.target_type, list) + ): + single_item = reconstructed_data[0] + + # Check if the single item is already our target type + if isinstance(single_item, self.target_type): + return single_item + # Otherwise continue with normal processing using the extracted item + reconstructed_data = single_item + + # Also handle the case where we have a list target type but receive a list with single target + elif ( + isinstance(reconstructed_data, list) + and len(reconstructed_data) == 1 + and isinstance(self.target_type, list) + and len(self.target_type) == 1 + ): + single_item = reconstructed_data[0] + target_type = self.target_type[0] + + if isinstance(single_item, target_type): + return single_item + if not self.target_type: return reconstructed_data @@ -225,12 +264,23 @@ def decode(self, data: bytes) -> Any: if self.multiple_parameter_mode and 
hasattr(self, "parameter_wrapper_model"): try: wrapper_instance = self.parameter_wrapper_model(**reconstructed_data) - return tuple(getattr(wrapper_instance, f"param_{i}") for i in range(len(self.parameter_types))) - except Exception: + result = tuple( + getattr(wrapper_instance, f"param_{i}") for i in range(len(self.parameter_types)) + ) + return result + except Exception as e: pass - return self._decode_to_target_type(reconstructed_data, self.parameter_types[0]) + + # For single target type in list, decode to that type + if len(self.parameter_types) > 0: + target_type = self.parameter_types[0] + result = self._decode_to_target_type(reconstructed_data, target_type) + return result + else: + return reconstructed_data else: - return self._decode_to_target_type(reconstructed_data, self.target_type) + result = self._decode_to_target_type(reconstructed_data, self.target_type) + return result except Exception as e: raise DeserializationException(f"Decoding failed: {e}") from e @@ -249,20 +299,36 @@ def _deserialize_from_json_bytes(self, data: bytes) -> Any: def _decode_to_target_type(self, json_data: Any, target_type: Type) -> Any: """Convert JSON data to target type with proper Pydantic handling""" + + # If we already have the right type, return it immediately + if isinstance(json_data, target_type): + return json_data + # Check if target type is a Pydantic model try: from pydantic import BaseModel if isinstance(target_type, type) and issubclass(target_type, BaseModel): - # If json_data is already a Pydantic model instance, return as-is + # If json_data is already a Pydantic model instance of the target type if isinstance(json_data, target_type): return json_data + # If json_data is a dict, try to construct the Pydantic model elif isinstance(json_data, dict): return target_type(**json_data) + + # If json_data is a list with one element, extract it + elif isinstance(json_data, list) and len(json_data) == 1: + return self._decode_to_target_type(json_data[0], target_type) 
+ + # If json_data is a list of dicts, try the first one + elif isinstance(json_data, list) and len(json_data) > 0 and isinstance(json_data[0], dict): + return self._decode_to_target_type(json_data[0], target_type) + except ImportError: pass + # Handle basic types if target_type in (str, int, float, bool, list, dict): return target_type(json_data) @@ -270,43 +336,54 @@ def _decode_to_target_type(self, json_data: Any, target_type: Type) -> Any: def _reconstruct_objects(self, data: Any) -> Any: """Reconstruct special objects from their serialized form""" + if not isinstance(data, dict): if isinstance(data, list): - return [self._reconstruct_objects(item) for item in data] + result = [self._reconstruct_objects(item) for item in data] + return result return data # Handle special serialized objects if "__datetime__" in data: from datetime import datetime - return datetime.fromisoformat(data["__datetime__"]) + dt = datetime.fromisoformat(data["__datetime__"]) + return dt elif "__date__" in data: from datetime import date - return date.fromisoformat(data["__date__"]) + d = date.fromisoformat(data["__date__"]) + return d elif "__time__" in data: from datetime import time - return time.fromisoformat(data["__time__"]) + t = time.fromisoformat(data["__time__"]) + return t elif "__decimal__" in data: from decimal import Decimal - return Decimal(data["__decimal__"]) + dec = Decimal(data["__decimal__"]) + return dec elif "__set__" in data: - return set(self._reconstruct_objects(item) for item in data["__set__"]) + s = set(self._reconstruct_objects(item) for item in data["__set__"]) + return s elif "__frozenset__" in data: - return frozenset(self._reconstruct_objects(item) for item in data["__frozenset__"]) + fs = frozenset(self._reconstruct_objects(item) for item in data["__frozenset__"]) + return fs elif "__uuid__" in data: from uuid import UUID - return UUID(data["__uuid__"]) + u = UUID(data["__uuid__"]) + return u elif "__path__" in data: from pathlib import Path - return 
Path(data["__path__"]) + p = Path(data["__path__"]) + return p elif "__pydantic_model__" in data and "__model_data__" in data: # Properly reconstruct Pydantic models - return self._reconstruct_pydantic_model(data) + result = self._reconstruct_pydantic_model(data) + return result elif "__dataclass__" in data: module_name, class_name = data["__dataclass__"].rsplit(".", 1) import importlib @@ -314,17 +391,19 @@ def _reconstruct_objects(self, data: Any) -> Any: module = importlib.import_module(module_name) cls = getattr(module, class_name) fields = self._reconstruct_objects(data["fields"]) - return cls(**fields) - + result = cls(**fields) + return result elif "__enum__" in data: module_name, class_name = data["__enum__"].rsplit(".", 1) import importlib module = importlib.import_module(module_name) cls = getattr(module, class_name) - return cls(data["value"]) + result = cls(data["value"]) + return result else: - return {key: self._reconstruct_objects(value) for key, value in data.items()} + result = {key: self._reconstruct_objects(value) for key, value in data.items()} + return result def _reconstruct_pydantic_model(self, data: dict) -> Any: """Reconstruct a Pydantic model from serialized data""" @@ -345,7 +424,8 @@ def _reconstruct_pydantic_model(self, data: dict) -> Any: reconstructed_data = self._reconstruct_objects(model_data) # Create the model instance - return model_class(**reconstructed_data) + result = model_class(**reconstructed_data) + return result except Exception as e: # If reconstruction fails, return the model data as dict @@ -354,7 +434,7 @@ def _reconstruct_pydantic_model(self, data: dict) -> Any: class JsonTransportCodec: - """Combined JSON transport codec - maintains backward compatibility""" + """JSON transport codec""" def __init__( self, @@ -374,7 +454,7 @@ def __init__( ) self._decoder = JsonTransportDecoder(target_type=return_type, **kwargs) - def encode_parameters(self, *arguments) -> bytes: + def encode_parameters(self, *arguments, 
parameter_type: list = None) -> bytes: """Encode parameters - supports both positional and keyword args""" return self._encoder.encode(arguments) @@ -391,27 +471,3 @@ def get_decoder(self) -> JsonTransportDecoder: def register_type_provider(self, provider) -> None: """Register custom type provider""" self._encoder.register_type_provider(provider) - - -class JsonTransportService: - """Service for extension loader integration""" - - def __init__(self, parameter_types: List[Type] = None, return_type: Type = None, **kwargs): - self.codec = JsonTransportCodec(parameter_types=parameter_types, return_type=return_type, **kwargs) - - def encode_parameters(self, *args, **kwargs) -> bytes: - """Enhanced to handle both positional and keyword arguments""" - if kwargs: - # Convert keyword args to positional for now - # Could be enhanced to preserve keyword information - combined_args = args + tuple(kwargs.values()) - return self.codec.encode_parameters(*combined_args) - return self.codec.encode_parameters(*args) - - def decode_return_value(self, data: bytes) -> Any: - """Decode return value""" - return self.codec.decode_return_value(data) - - def register_custom_plugin(self, plugin): - """Allow extension loader to register custom plugins""" - self.codec.register_type_provider(plugin) diff --git a/src/dubbo/codec/json_codec/json_transport_plugins.py b/src/dubbo/codec/json_codec/json_transport_codec.py similarity index 83% rename from src/dubbo/codec/json_codec/json_transport_plugins.py rename to src/dubbo/codec/json_codec/json_transport_codec.py index 4c6d936..3c1d05c 100644 --- a/src/dubbo/codec/json_codec/json_transport_plugins.py +++ b/src/dubbo/codec/json_codec/json_transport_codec.py @@ -39,7 +39,7 @@ def can_handle(self, obj: Any) -> bool: class OrJsonPlugin: - """orjson plugin - separate from standard json""" + """orjson plugin independent implementation""" def __init__(self): try: @@ -64,15 +64,19 @@ def can_handle(self, obj: Any) -> bool: return self.available def 
_default_handler(self, obj): - """Handle types that orjson doesn't support natively""" + """Handle types not supported natively by orjson""" if isinstance(obj, datetime): return {"__datetime__": obj.isoformat(), "__timezone__": str(obj.tzinfo) if obj.tzinfo else None} - elif isinstance(obj, (date, time)): - return {"__date__": obj.isoformat()} if isinstance(obj, date) else {"__time__": obj.isoformat()} + elif isinstance(obj, date): + return {"__date__": obj.isoformat()} + elif isinstance(obj, time): + return {"__time__": obj.isoformat()} elif isinstance(obj, Decimal): return {"__decimal__": str(obj)} - elif isinstance(obj, (set, frozenset)): - return {"__frozenset__" if isinstance(obj, frozenset) else "__set__": list(obj)} + elif isinstance(obj, set): + return {"__set__": list(obj)} + elif isinstance(obj, frozenset): + return {"__frozenset__": list(obj)} elif isinstance(obj, UUID): return {"__uuid__": str(obj)} elif isinstance(obj, Path): @@ -81,7 +85,7 @@ def _default_handler(self, obj): class UJsonPlugin: - """ujson plugin - separate from others""" + """ujson plugin implementation""" def __init__(self): try: @@ -106,16 +110,31 @@ def can_handle(self, obj: Any) -> bool: return self.available def _default_handler(self, obj): - """Same as orjson handler""" - return OrJsonPlugin()._default_handler(obj) + """Handle types not supported natively by ujson""" + if isinstance(obj, datetime): + return {"__datetime__": obj.isoformat(), "__timezone__": str(obj.tzinfo) if obj.tzinfo else None} + elif isinstance(obj, date): + return {"__date__": obj.isoformat()} + elif isinstance(obj, time): + return {"__time__": obj.isoformat()} + elif isinstance(obj, Decimal): + return {"__decimal__": str(obj)} + elif isinstance(obj, set): + return {"__set__": list(obj)} + elif isinstance(obj, frozenset): + return {"__frozenset__": list(obj)} + elif isinstance(obj, UUID): + return {"__uuid__": str(obj)} + elif isinstance(obj, Path): + return {"__path__": str(obj)} + return {"__fallback__": 
str(obj), "__type__": type(obj).__name__} -# Type Handler Plugins (properly inherit protocol) class DateTimeHandler: """DateTime handler - implements TypeHandlerPlugin protocol""" def can_serialize_type(self, obj: Any, obj_type: type) -> bool: - return obj_type in (datetime, date, time) + return isinstance(obj, (datetime, date, time)) def serialize_to_dict(self, obj: Union[datetime, date, time]) -> Dict[str, str]: if isinstance(obj, datetime): diff --git a/src/dubbo/codec/protobuf_codec/__init__.py b/src/dubbo/codec/protobuf_codec/__init__.py index 194815c..d224a47 100644 --- a/src/dubbo/codec/protobuf_codec/__init__.py +++ b/src/dubbo/codec/protobuf_codec/__init__.py @@ -14,6 +14,10 @@ # See the License for the specific language governing permissions and # limitations under the License. -from .protobuf_codec_handler import ProtobufTransportCodec, ProtobufTransportEncoder, ProtobufTransportDecoder +from .protobuf_codec_handler import ProtobufTransportCodec, ProtobufTransportDecoder, ProtobufTransportEncoder -__all__ = ["ProtobufTransportCodec", "ProtobufTransportEncoder", "ProtobufTransportDecoder"] +__all__ = [ + "ProtobufTransportCodec", + "ProtobufTransportDecoder", + "ProtobufTransportEncoder", +] diff --git a/src/dubbo/codec/protobuf_codec/protobuf_codec_handler.py b/src/dubbo/codec/protobuf_codec/protobuf_codec_handler.py index 8c75062..b8554a0 100644 --- a/src/dubbo/codec/protobuf_codec/protobuf_codec_handler.py +++ b/src/dubbo/codec/protobuf_codec/protobuf_codec_handler.py @@ -14,29 +14,34 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-from typing import Any, Type, Protocol, Optional -from abc import ABC, abstractmethod -import json +from typing import Any, Type, Protocol, Dict, Optional from dataclasses import dataclass +import json +from abc import ABC, abstractmethod # Betterproto imports try: import betterproto + HAS_BETTERPROTO = True except ImportError: HAS_BETTERPROTO = False -try: - from pydantic import BaseModel - HAS_PYDANTIC = True -except ImportError: - HAS_PYDANTIC = False -# Reuse your existing JSON type system -from dubbo.codec.json_codec.json_type import ( - TypeProviderFactory, SerializationState, - SerializationException, DeserializationException -) +class SerializationException(Exception): + """Exception raised when encoding or serialization fails.""" + + def __init__(self, message: str, *, cause: Exception = None): + super().__init__(message) + self.cause = cause + + +class DeserializationException(Exception): + """Exception raised when decoding or deserialization fails.""" + + def __init__(self, message: str, *, cause: Exception = None): + super().__init__(message) + self.cause = cause class ProtobufEncodingFunction(Protocol): @@ -50,81 +55,255 @@ def __call__(self, data: bytes) -> Any: ... 
@dataclass class ProtobufMethodDescriptor: """Protobuf-specific method descriptor for single parameter""" + parameter_type: Type return_type: Type - protobuf_message_type: Optional[Type] = None - use_json_fallback: bool = False + protobuf_message_type: Type = None + + +# Abstract base classes for pluggable architecture +class TypeHandler(ABC): + """Abstract base class for type handlers""" + + @abstractmethod + def is_message(self, obj_type: Type) -> bool: + """Check if the type is a message type""" + pass + + @abstractmethod + def is_message_instance(self, obj: Any) -> bool: + """Check if the object is a message instance""" + pass + + @abstractmethod + def is_compatible(self, obj_type: Type) -> bool: + """Check if the type is compatible with this handler""" + pass + +class EncodingStrategy(ABC): + """Abstract base class for encoding strategies""" -class ProtobufTypeHandler: + @abstractmethod + def can_encode(self, parameter: Any, parameter_type: Type = None) -> bool: + """Check if this strategy can encode the given parameter""" + pass + + @abstractmethod + def encode(self, parameter: Any, parameter_type: Type = None) -> bytes: + """Encode the parameter to bytes""" + pass + + +class DecodingStrategy(ABC): + """Abstract base class for decoding strategies""" + + @abstractmethod + def can_decode(self, data: bytes, target_type: Type) -> bool: + """Check if this strategy can decode to the target type""" + pass + + @abstractmethod + def decode(self, data: bytes, target_type: Type) -> Any: + """Decode the bytes to the target type""" + pass + + +# Concrete implementations +class ProtobufTypeHandler(TypeHandler): """Handles type conversion between Python types and Betterproto""" - @staticmethod - def is_betterproto_message(obj_type: Type) -> bool: - """Check if type is a betterproto message class""" + def is_message(self, obj_type: Type) -> bool: if not HAS_BETTERPROTO: return False try: - return (hasattr(obj_type, '__dataclass_fields__') and - issubclass(obj_type, 
betterproto.Message)) + return hasattr(obj_type, "__dataclass_fields__") and issubclass(obj_type, betterproto.Message) except (TypeError, AttributeError): return False - @staticmethod - def is_betterproto_message_instance(obj: Any) -> bool: - """Check if object is a betterproto message instance""" + def is_message_instance(self, obj: Any) -> bool: if not HAS_BETTERPROTO: return False - try: - return isinstance(obj, betterproto.Message) - except: - return False + return isinstance(obj, betterproto.Message) + def is_compatible(self, obj_type: Type) -> bool: + return obj_type in (str, int, float, bool, bytes) or self.is_message(obj_type) + + # Static methods for backward compatibility @staticmethod - def is_protobuf_compatible(obj_type: Type) -> bool: - """Check if type can be handled by protobuf""" - return (obj_type in (str, int, float, bool, bytes) or - ProtobufTypeHandler.is_betterproto_message(obj_type)) + def is_betterproto_message(obj_type: Type) -> bool: + handler = ProtobufTypeHandler() + return handler.is_message(obj_type) + + @staticmethod + def is_betterproto_message_instance(obj: Any) -> bool: + handler = ProtobufTypeHandler() + return handler.is_message_instance(obj) @staticmethod - def needs_json_fallback(parameter_type: Type) -> bool: - """Check if we need JSON fallback for this type""" - return not ProtobufTypeHandler.is_protobuf_compatible(parameter_type) + def is_protobuf_compatible(obj_type: Type) -> bool: + handler = ProtobufTypeHandler() + return handler.is_compatible(obj_type) -class ProtobufTransportEncoder: - """Protobuf encoder for single parameters using betterproto""" +class MessageEncodingStrategy(EncodingStrategy): + """Encoding strategy for protobuf messages""" + + def __init__(self, type_handler: TypeHandler): + self.type_handler = type_handler + + def can_encode(self, parameter: Any, parameter_type: Type = None) -> bool: + return self.type_handler.is_message_instance(parameter) or ( + parameter_type and 
self.type_handler.is_message(parameter_type) + ) + + def encode(self, parameter: Any, parameter_type: Type = None) -> bytes: + if self.type_handler.is_message_instance(parameter): + return bytes(parameter) + + if parameter_type and self.type_handler.is_message(parameter_type): + if isinstance(parameter, parameter_type): + return bytes(parameter) + elif isinstance(parameter, dict): + try: + message = parameter_type().from_dict(parameter) + return bytes(message) + except Exception as e: + raise SerializationException(f"Cannot convert dict to {parameter_type}: {e}") + else: + raise SerializationException(f"Cannot convert {type(parameter)} to {parameter_type}") + + raise SerializationException(f"Cannot encode {type(parameter)} as protobuf message") + + +class PrimitiveEncodingStrategy(EncodingStrategy): + """Encoding strategy for primitive types""" + + def can_encode(self, parameter: Any, parameter_type: Type = None) -> bool: + return isinstance(parameter, (str, int, float, bool, bytes)) + + def encode(self, parameter: Any, parameter_type: Type = None) -> bytes: + try: + json_str = json.dumps({"value": parameter, "type": type(parameter).__name__}) + return json_str.encode("utf-8") + except Exception as e: + raise SerializationException(f"Failed to encode primitive {parameter}: {e}") + + +class MessageDecodingStrategy(DecodingStrategy): + """Decoding strategy for protobuf messages""" + + def __init__(self, type_handler: TypeHandler): + self.type_handler = type_handler + + def can_decode(self, data: bytes, target_type: Type) -> bool: + return self.type_handler.is_message(target_type) + + def decode(self, data: bytes, target_type: Type) -> Any: + try: + return target_type().parse(data) + except Exception as e: + raise DeserializationException(f"Failed to parse betterproto message: {e}") + + +class PrimitiveDecodingStrategy(DecodingStrategy): + """Decoding strategy for primitive types""" + + def can_decode(self, data: bytes, target_type: Type) -> bool: + return target_type 
in (str, int, float, bool, bytes) + + def decode(self, data: bytes, target_type: Type) -> Any: + try: + json_str = data.decode("utf-8") + parsed = json.loads(json_str) + value = parsed.get("value") + + if target_type == str: + return str(value) + elif target_type == int: + return int(value) + elif target_type == float: + return float(value) + elif target_type == bool: + return bool(value) + elif target_type == bytes: + return bytes(value) if isinstance(value, (list, bytes)) else str(value).encode() + else: + return value + + except Exception as e: + raise DeserializationException(f"Failed to decode primitive: {e}") - def __init__(self, parameter_type: Type = None, **kwargs): + +class StrategyRegistry: + """Registry for managing encoding/decoding strategies""" + + def __init__(self): + self.encoding_strategies: list[EncodingStrategy] = [] + self.decoding_strategies: list[DecodingStrategy] = [] + + def register_encoding_strategy(self, strategy: EncodingStrategy) -> None: + """Register an encoding strategy""" + self.encoding_strategies.append(strategy) + + def register_decoding_strategy(self, strategy: DecodingStrategy) -> None: + """Register a decoding strategy""" + self.decoding_strategies.append(strategy) + + def find_encoding_strategy(self, parameter: Any, parameter_type: Type = None) -> Optional[EncodingStrategy]: + """Find the first strategy that can encode the parameter""" + for strategy in self.encoding_strategies: + if strategy.can_encode(parameter, parameter_type): + return strategy + return None + + def find_decoding_strategy(self, data: bytes, target_type: Type) -> Optional[DecodingStrategy]: + """Find the first strategy that can decode to the target type""" + for strategy in self.decoding_strategies: + if strategy.can_decode(data, target_type): + return strategy + return None + + +class ProtobufTransportEncoder: + """Protobuf encoder for single parameters using pluggable strategies""" + + def __init__( + self, + parameter_type: Type = None, + 
type_handler: TypeHandler = None, + strategy_registry: StrategyRegistry = None, + **kwargs, + ): if not HAS_BETTERPROTO: raise ImportError("betterproto library is required for ProtobufTransportEncoder") - self.parameter_type = parameter_type - - self.descriptor = ProtobufMethodDescriptor( - parameter_type=parameter_type, - return_type=None, - use_json_fallback=ProtobufTypeHandler.needs_json_fallback(parameter_type) if parameter_type else False - ) + self.descriptor = ProtobufMethodDescriptor(parameter_type=parameter_type, return_type=None) + + self.type_handler = type_handler or ProtobufTypeHandler() + self.strategy_registry = strategy_registry or self._create_default_registry() - if self.descriptor.use_json_fallback: - from dubbo.codec.json_codec.json_codec_handler import JsonTransportEncoder - self.json_fallback_encoder = JsonTransportEncoder([parameter_type], **kwargs) + def _create_default_registry(self) -> StrategyRegistry: + """Create default strategy registry with standard strategies""" + registry = StrategyRegistry() + registry.register_encoding_strategy(MessageEncodingStrategy(self.type_handler)) + registry.register_encoding_strategy(PrimitiveEncodingStrategy()) + return registry - def encode(self, parameter: Any) -> bytes: - """Encode single parameter to bytes""" + def encode(self, parameter: Any, parameter_type: Type) -> bytes: try: if parameter is None: - return b'' + return b"" - # Handle case where parameter is a tuple (common in RPC calls) if isinstance(parameter, tuple): if len(parameter) == 0: - return b'' + return b"" elif len(parameter) == 1: return self._encode_single_parameter(parameter[0]) else: - raise SerializationException(f"Multiple parameters not supported. Got tuple with {len(parameter)} elements, expected 1.") + raise SerializationException( + f"Multiple parameters not supported. Got tuple with {len(parameter)} elements, expected 1." 
+ ) return self._encode_single_parameter(parameter) @@ -132,70 +311,51 @@ def encode(self, parameter: Any) -> bytes: raise SerializationException(f"Protobuf encoding failed: {e}") from e def _encode_single_parameter(self, parameter: Any) -> bytes: - """Encode a single parameter using betterproto""" - # If it's already a betterproto message instance, serialize it - if ProtobufTypeHandler.is_betterproto_message_instance(parameter): - return bytes(parameter) - - # If we have type info and it's a betterproto message type - if self.parameter_type and ProtobufTypeHandler.is_betterproto_message(self.parameter_type): - if isinstance(parameter, self.parameter_type): - return bytes(parameter) - elif isinstance(parameter, dict): - # Convert dict to betterproto message - try: - message = self.parameter_type().from_dict(parameter) - return bytes(message) - except Exception as e: - raise SerializationException(f"Cannot convert dict to {self.parameter_type}: {e}") - else: - raise SerializationException(f"Cannot convert {type(parameter)} to {self.parameter_type}") - - # Handle primitive types by wrapping in a simple message - if isinstance(parameter, (str, int, float, bool, bytes)): - return self._encode_primitive(parameter) + strategy = self.strategy_registry.find_encoding_strategy(parameter, self.parameter_type) + if strategy: + return strategy.encode(parameter, self.parameter_type) - # Use JSON fallback if configured - if self.descriptor.use_json_fallback: - json_data = self.json_fallback_encoder.encode((parameter,)) - return json_data - - raise SerializationException(f"Cannot encode {type(parameter)} as protobuf") + raise SerializationException(f"No encoding strategy found for {type(parameter)}") + # Backward compatibility method def _encode_primitive(self, value: Any) -> bytes: - """Encode primitive values by wrapping them in a simple structure""" - # For primitives, we'll use JSON encoding wrapped in bytes - # This is a simplified approach - in a real implementation you 
might - # want to define a wrapper protobuf message for primitives - try: - json_str = json.dumps({"value": value, "type": type(value).__name__}) - return json_str.encode('utf-8') - except Exception as e: - raise SerializationException(f"Failed to encode primitive {value}: {e}") + strategy = PrimitiveEncodingStrategy() + return strategy.encode(value) class ProtobufTransportDecoder: - """Protobuf decoder for single parameters using betterproto""" - - def __init__(self, target_type: Type = None, **kwargs): + """Protobuf decoder for single parameters using pluggable strategies""" + + def __init__( + self, + target_type: Type = None, + type_handler: TypeHandler = None, + strategy_registry: StrategyRegistry = None, + **kwargs, + ): if not HAS_BETTERPROTO: raise ImportError("betterproto library is required for ProtobufTransportDecoder") self.target_type = target_type - self.use_json_fallback = ProtobufTypeHandler.needs_json_fallback(target_type) if target_type else False - if self.use_json_fallback: - from dubbo.codec.json_codec.json_codec_handler import JsonTransportDecoder - self.json_fallback_decoder = JsonTransportDecoder(target_type, **kwargs) + # Use provided components or create defaults + self.type_handler = type_handler or ProtobufTypeHandler() + self.strategy_registry = strategy_registry or self._create_default_registry() + + def _create_default_registry(self) -> StrategyRegistry: + """Create default strategy registry with standard strategies""" + registry = StrategyRegistry() + registry.register_decoding_strategy(MessageDecodingStrategy(self.type_handler)) + registry.register_decoding_strategy(PrimitiveDecodingStrategy()) + return registry def decode(self, data: bytes) -> Any: - """Decode bytes to single parameter""" try: if not data: return None if not self.target_type: - return self._decode_without_type_info(data) + raise DeserializationException("No target_type specified for decoding") return self._decode_single_parameter(data, self.target_type) @@ -203,94 
+363,63 @@ def decode(self, data: bytes) -> Any: raise DeserializationException(f"Protobuf decoding failed: {e}") from e def _decode_single_parameter(self, data: bytes, target_type: Type) -> Any: - """Decode single parameter using betterproto""" - if ProtobufTypeHandler.is_betterproto_message(target_type): - try: - # Use betterproto's parsing - message_instance = target_type().parse(data) - return message_instance - except Exception as e: - if self.use_json_fallback: - return self.json_fallback_decoder.decode(data) - raise DeserializationException(f"Failed to parse betterproto message: {e}") - - # Handle primitives - elif target_type in (str, int, float, bool, bytes): - return self._decode_primitive(data, target_type) - - # Use JSON fallback - elif self.use_json_fallback: - return self.json_fallback_decoder.decode(data) + strategy = self.strategy_registry.find_decoding_strategy(data, target_type) + if strategy: + return strategy.decode(data, target_type) - else: - raise DeserializationException(f"Cannot decode to {target_type} from protobuf") + raise DeserializationException(f"No decoding strategy found for {target_type}") + # Backward compatibility method def _decode_primitive(self, data: bytes, target_type: Type) -> Any: - """Decode primitive values from their wrapped format""" - try: - json_str = data.decode('utf-8') - parsed = json.loads(json_str) - value = parsed.get("value") - - # Convert to target type if needed - if target_type == str: - return str(value) - elif target_type == int: - return int(value) - elif target_type == float: - return float(value) - elif target_type == bool: - return bool(value) - elif target_type == bytes: - return bytes(value) if isinstance(value, (list, bytes)) else str(value).encode() - else: - return value - - except Exception as e: - raise DeserializationException(f"Failed to decode primitive: {e}") - - def _decode_without_type_info(self, data: bytes) -> Any: - """Decode without type information - try JSON first""" - try: - return 
json.loads(data.decode('utf-8')) - except: - return data + strategy = PrimitiveDecodingStrategy() + return strategy.decode(data, target_type) class ProtobufTransportCodec: - """Main protobuf codec class for single parameters using betterproto""" - - def __init__(self, parameter_type: Type = None, return_type: Type = None, **kwargs): + """Main protobuf codec class for single parameters""" + + def __init__( + self, + parameter_type: Type = None, + return_type: Type = None, + type_handler: TypeHandler = None, + encoder_registry: StrategyRegistry = None, + decoder_registry: StrategyRegistry = None, + **kwargs, + ): if not HAS_BETTERPROTO: raise ImportError("betterproto library is required for ProtobufTransportCodec") - self.parameter_type = parameter_type - self.return_type = return_type + # Allow sharing registries between encoder and decoder, or use separate ones + shared_registry = encoder_registry or decoder_registry self._encoder = ProtobufTransportEncoder( parameter_type=parameter_type, - **kwargs + type_handler=type_handler, + strategy_registry=encoder_registry or shared_registry, + **kwargs, ) self._decoder = ProtobufTransportDecoder( target_type=return_type, - **kwargs + type_handler=type_handler, + strategy_registry=decoder_registry or shared_registry, + **kwargs, ) def encode_parameter(self, argument: Any) -> bytes: - """Encode single parameter""" return self._encoder.encode(argument) - + def encode_parameters(self, arguments: tuple) -> bytes: - """Legacy method to handle tuple of arguments (for backward compatibility)""" if not arguments: - return b'' + return b"" if len(arguments) == 1: return self._encoder.encode(arguments[0]) else: - raise SerializationException(f"Multiple parameters not supported. Got {len(arguments)} arguments, expected 1.") + raise SerializationException( + f"Multiple parameters not supported. Got {len(arguments)} arguments, expected 1." 
+ ) def decode_return_value(self, data: bytes) -> Any: - """Decode return value""" return self._decoder.decode(data) def get_encoder(self) -> ProtobufTransportEncoder: @@ -298,8 +427,3 @@ def get_encoder(self) -> ProtobufTransportEncoder: def get_decoder(self) -> ProtobufTransportDecoder: return self._decoder - - -def create_protobuf_codec(**kwargs) -> ProtobufTransportCodec: - """Factory function to create protobuf codec""" - return ProtobufTransportCodec(**kwargs) \ No newline at end of file diff --git a/src/dubbo/protocol/triple/protocol.py b/src/dubbo/protocol/triple/protocol.py index f4aa4d3..f5333cd 100644 --- a/src/dubbo/protocol/triple/protocol.py +++ b/src/dubbo/protocol/triple/protocol.py @@ -30,9 +30,7 @@ from dubbo.remoting import Server, Transporter from dubbo.remoting.aio import constants as aio_constants from dubbo.remoting.aio.http2.protocol import Http2ClientProtocol, Http2ServerProtocol -from dubbo.remoting.aio.http2.stream_handler import ( - StreamClientMultiplexHandler -) +from dubbo.remoting.aio.http2.stream_handler import StreamClientMultiplexHandler from dubbo.url import URL _LOGGER = loggerFactory.get_logger() diff --git a/tests/protobuf/protobuf_test.py b/tests/protobuf/protobuf_test.py index a65492d..95e65bb 100644 --- a/tests/protobuf/protobuf_test.py +++ b/tests/protobuf/protobuf_test.py @@ -15,15 +15,14 @@ # limitations under the License. 
import pytest -from dubbo.codec.protobuf_codec import ProtobufTransportCodec +from dubbo.codec.protobuf_codec import ProtobufTransportCodec, ProtobufTransportDecoder + +print(type(ProtobufTransportCodec)) from generated.protobuf_test import GreeterReply, GreeterRequest def test_protobuf_roundtrip_message(): - codec = ProtobufTransportCodec( - parameter_type=GreeterRequest, - return_type=GreeterReply - ) + codec = ProtobufTransportCodec(parameter_type=GreeterRequest, return_type=GreeterReply) # Create a request req = GreeterRequest(name="Alice") @@ -43,26 +42,21 @@ def test_protobuf_roundtrip_message(): def test_protobuf_from_dict(): - codec = ProtobufTransportCodec( - parameter_type=GreeterRequest, - return_type=GreeterReply - ) + codec = ProtobufTransportCodec(parameter_type=GreeterRequest, return_type=GreeterReply) # Dict instead of message instance encoded = codec.encode_parameter({"name": "Bob"}) assert isinstance(encoded, bytes) - # Decode back to message - req = codec._decoder.decode(encoded) # simulate server echo + # To decode back to the parameter type, we need a decoder configured for GreeterRequest + param_decoder = ProtobufTransportDecoder(target_type=GreeterRequest) + req = param_decoder.decode(encoded) assert isinstance(req, GreeterRequest) assert req.name == "Bob" def test_protobuf_primitive_fallback(): - codec = ProtobufTransportCodec( - parameter_type=str, - return_type=str - ) + codec = ProtobufTransportCodec(parameter_type=str, return_type=str) encoded = codec.encode_parameter("simple string") assert isinstance(encoded, bytes) From fbb9d81ca388dff0aef52d028727e0f41f81b842 Mon Sep 17 00:00:00 2001 From: Aditya Yadav <166515021+aditya0yadav@users.noreply.github.com> Date: Thu, 4 Sep 2025 04:40:25 +0000 Subject: [PATCH 18/40] changing the type for organisation requirement.. 
--- samples/llm/chat_pb2.py | 12 +- samples/llm/main.py | 2 +- src/dubbo/classes.py | 5 +- src/dubbo/client.py | 16 +- src/dubbo/codec/dubbo_codec.py | 36 ++-- src/dubbo/codec/json_codec/__init__.py | 2 +- .../codec/json_codec/json_codec_handler.py | 157 +++++++----------- .../codec/json_codec/json_transport_base.py | 7 +- .../codec/json_codec/json_transport_codec.py | 36 ++-- .../protobuf_codec/protobuf_codec_handler.py | 80 ++++----- src/dubbo/extension/registries.py | 10 +- src/dubbo/proxy/handlers.py | 79 ++------- src/dubbo/remoting/aio/http2/protocol.py | 2 +- tests/json/json_test.py | 12 +- tests/json/json_type_test.py | 11 +- tests/protobuf/protobuf_test.py | 6 +- 16 files changed, 195 insertions(+), 278 deletions(-) diff --git a/samples/llm/chat_pb2.py b/samples/llm/chat_pb2.py index de9488e..90de97f 100644 --- a/samples/llm/chat_pb2.py +++ b/samples/llm/chat_pb2.py @@ -1,13 +1,15 @@ -# -*- coding: utf-8 -*- # Generated by the protocol buffer compiler. DO NOT EDIT! # source: chat.proto # Protobuf Python Version: 4.25.3 """Generated protocol buffer code.""" -from google.protobuf import descriptor as _descriptor -from google.protobuf import descriptor_pool as _descriptor_pool -from google.protobuf import symbol_database as _symbol_database +from google.protobuf import ( + descriptor as _descriptor, + descriptor_pool as _descriptor_pool, + symbol_database as _symbol_database, +) from google.protobuf.internal import builder as _builder + # @@protoc_insertion_point(imports) _sym_db = _symbol_database.Default() @@ -20,7 +22,7 @@ _globals = globals() _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) _builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, "chat_pb2", _globals) -if _descriptor._USE_C_DESCRIPTORS == False: +if not _descriptor._USE_C_DESCRIPTORS: _globals["DESCRIPTOR"]._options = None _globals["DESCRIPTOR"]._serialized_options = b"B\tChatProtoP\001" _globals["_CHATREQUEST"]._serialized_start = 48 diff --git a/samples/llm/main.py 
b/samples/llm/main.py index e315716..97baaa6 100644 --- a/samples/llm/main.py +++ b/samples/llm/main.py @@ -15,12 +15,12 @@ # limitations under the License. from time import sleep +import chat_pb2 from lmdeploy import GenerationConfig, TurbomindEngineConfig, pipeline from dubbo import Dubbo from dubbo.configs import RegistryConfig, ServiceConfig from dubbo.proxy.handlers import RpcMethodHandler, RpcServiceHandler -import chat_pb2 # the path of a model. It could be one of the following options: # 1. A local directory path of a turbomind model diff --git a/src/dubbo/classes.py b/src/dubbo/classes.py index b3076ea..14728da 100644 --- a/src/dubbo/classes.py +++ b/src/dubbo/classes.py @@ -16,8 +16,9 @@ import abc import threading -from typing import Any, Callable, Optional, Union, Type from abc import ABC, abstractmethod +from typing import Any, Callable, Optional, Union + from dubbo.types import DeserializingFunction, RpcType, RpcTypes, SerializingFunction __all__ = [ @@ -248,7 +249,7 @@ class ReadWriteStream(ReadStream, WriteStream, abc.ABC): class Codec(ABC): - def __init__(self, model_type: Optional[Type[Any]] = None, **kwargs): + def __init__(self, model_type: Optional[type[Any]] = None, **kwargs): self.model_type = model_type @abstractmethod diff --git a/src/dubbo/client.py b/src/dubbo/client.py index 6d1734f..e0159ce 100644 --- a/src/dubbo/client.py +++ b/src/dubbo/client.py @@ -15,10 +15,11 @@ # limitations under the License. 
import threading -from typing import Optional, List, Type +from typing import Optional from dubbo.bootstrap import Dubbo from dubbo.classes import MethodDescriptor +from dubbo.codec import DubboTransportService from dubbo.configs import ReferenceConfig from dubbo.constants import common_constants from dubbo.extension import extensionLoader @@ -32,7 +33,6 @@ SerializingFunction, ) from dubbo.url import URL -from dubbo.codec import DubboTransportService __all__ = ["Client"] @@ -88,8 +88,8 @@ def _create_rpc_callable( self, rpc_type: str, method_name: str, - params_types: List[Type], - return_type: Type, + params_types: list[type], + return_type: type, codec: Optional[str] = None, request_serializer: Optional[SerializingFunction] = None, response_deserializer: Optional[DeserializingFunction] = None, @@ -118,7 +118,7 @@ def _create_rpc_callable( return self._callable(descriptor) - def unary(self, method_name: str, params_types: List[Type], return_type: Type, **kwargs) -> RpcCallable: + def unary(self, method_name: str, params_types: list[type], return_type: type, **kwargs) -> RpcCallable: return self._create_rpc_callable( rpc_type=RpcTypes.UNARY.value, method_name=method_name, @@ -127,7 +127,7 @@ def unary(self, method_name: str, params_types: List[Type], return_type: Type, * **kwargs, ) - def client_stream(self, method_name: str, params_types: List[Type], return_type: Type, **kwargs) -> RpcCallable: + def client_stream(self, method_name: str, params_types: list[type], return_type: type, **kwargs) -> RpcCallable: return self._create_rpc_callable( rpc_type=RpcTypes.CLIENT_STREAM.value, method_name=method_name, @@ -136,7 +136,7 @@ def client_stream(self, method_name: str, params_types: List[Type], return_type: **kwargs, ) - def server_stream(self, method_name: str, params_types: List[Type], return_type: Type, **kwargs) -> RpcCallable: + def server_stream(self, method_name: str, params_types: list[type], return_type: type, **kwargs) -> RpcCallable: return 
self._create_rpc_callable( rpc_type=RpcTypes.SERVER_STREAM.value, method_name=method_name, @@ -145,7 +145,7 @@ def server_stream(self, method_name: str, params_types: List[Type], return_type: **kwargs, ) - def bi_stream(self, method_name: str, params_types: List[Type], return_type: Type, **kwargs) -> RpcCallable: + def bi_stream(self, method_name: str, params_types: list[type], return_type: type, **kwargs) -> RpcCallable: return self._create_rpc_callable( rpc_type=RpcTypes.BI_STREAM.value, method_name=method_name, diff --git a/src/dubbo/codec/dubbo_codec.py b/src/dubbo/codec/dubbo_codec.py index 3309c32..41d602b 100644 --- a/src/dubbo/codec/dubbo_codec.py +++ b/src/dubbo/codec/dubbo_codec.py @@ -14,10 +14,10 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Type, Optional, Callable, List, Dict, Tuple -from dataclasses import dataclass import inspect import logging +from dataclasses import dataclass +from typing import Any, Callable, Optional logger = logging.getLogger(__name__) @@ -38,7 +38,7 @@ class MethodDescriptor: function: Callable name: str - parameters: List[ParameterDescriptor] + parameters: list[ParameterDescriptor] return_parameter: ParameterDescriptor documentation: Optional[str] = None @@ -48,27 +48,27 @@ class DubboSerializationService: @staticmethod def create_transport_codec( - transport_type: str = "json", parameter_types: List[Type] = None, return_type: Type = None, **codec_options + transport_type: str = "json", parameter_types: list[type] = None, return_type: type = None, **codec_options ): """Create transport codec with enhanced parameter structure""" try: - from dubbo.extension.extension_loader import ExtensionLoader from dubbo.classes import CodecHelper + from dubbo.extension.extension_loader import ExtensionLoader codec_class = ExtensionLoader().get_extension(CodecHelper.get_class(), transport_type) return codec_class(parameter_types=parameter_types or [], 
return_type=return_type, **codec_options) except ImportError as e: - logger.error(f"Failed to import required modules: {e}") + logger.error("Failed to import required modules: %s", e) raise except Exception as e: - logger.error(f"Failed to create transport codec: {e}") + logger.error("Failed to create transport codec: %s", e) raise @staticmethod def create_encoder_decoder_pair( - transport_type: str, parameter_types: List[Type] = None, return_type: Type = None, **codec_options - ) -> Tuple[Any, Any]: + transport_type: str, parameter_types: list[type] = None, return_type: type = None, **codec_options + ) -> tuple[Any, Any]: """Create separate encoder and decoder instances""" try: @@ -85,13 +85,13 @@ def create_encoder_decoder_pair( return encoder, decoder except Exception as e: - logger.error(f"Failed to create encoder/decoder pair: {e}") + logger.error("Failed to create encoder/decoder pair: %s", e) raise @staticmethod def create_serialization_functions( - transport_type: str, parameter_types: List[Type] = None, return_type: Type = None, **codec_options - ) -> Tuple[Callable, Callable]: + transport_type: str, parameter_types: list[type] = None, return_type: type = None, **codec_options + ) -> tuple[Callable, Callable]: """Create serializer and deserializer functions for RPC (backward compatibility)""" try: @@ -103,7 +103,7 @@ def serialize_method_parameters(*args) -> bytes: try: return parameter_encoder.encode(args) except Exception as e: - logger.error(f"Failed to serialize parameters: {e}") + logger.error("Failed to serialize parameters: %s", e) raise def deserialize_method_return(data: bytes): @@ -112,21 +112,21 @@ def deserialize_method_return(data: bytes): try: return return_decoder.decode(data) except Exception as e: - logger.error(f"Failed to deserialize return value: {e}") + logger.error("Failed to deserialize return value: %s", e) raise return serialize_method_parameters, deserialize_method_return except Exception as e: - logger.error(f"Failed to create 
serialization functions: {e}") + logger.error("Failed to create serialization functions: %s", e) raise @staticmethod def create_method_descriptor( func: Callable, method_name: Optional[str] = None, - parameter_types: List[Type] = None, - return_type: Type = None, + parameter_types: list[type] = None, + return_type: type = None, interface: Callable = None, ) -> MethodDescriptor: """Create a method descriptor from function and configuration""" @@ -141,7 +141,7 @@ def create_method_descriptor( try: sig = inspect.signature(target_function) except ValueError as e: - logger.error(f"Cannot inspect signature of {target_function}: {e}") + logger.error("Cannot inspect signature of %s: %s", target_function, e) raise parameters = [] diff --git a/src/dubbo/codec/json_codec/__init__.py b/src/dubbo/codec/json_codec/__init__.py index a11163f..ea68a77 100644 --- a/src/dubbo/codec/json_codec/__init__.py +++ b/src/dubbo/codec/json_codec/__init__.py @@ -14,6 +14,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -from .json_codec_handler import JsonTransportCodec, JsonTransportEncoder, JsonTransportDecoder +from .json_codec_handler import JsonTransportCodec, JsonTransportDecoder, JsonTransportEncoder __all__ = ["JsonTransportCodec", "JsonTransportDecoder", "JsonTransportEncoder"] diff --git a/src/dubbo/codec/json_codec/json_codec_handler.py b/src/dubbo/codec/json_codec/json_codec_handler.py index 7cc0d3e..c2f0252 100644 --- a/src/dubbo/codec/json_codec/json_codec_handler.py +++ b/src/dubbo/codec/json_codec/json_codec_handler.py @@ -14,19 +14,20 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-from typing import Any, Type, List, Union, Optional -from .json_transport_base import SimpleRegistry, SerializationException, DeserializationException +from typing import Any, Optional + +from .json_transport_base import DeserializationException, SerializationException, SimpleRegistry from .json_transport_codec import ( - StandardJsonPlugin, - OrJsonPlugin, - UJsonPlugin, - DateTimeHandler, - DecimalHandler, CollectionHandler, - SimpleTypeHandler, - PydanticHandler, DataclassHandler, + DateTimeHandler, + DecimalHandler, EnumHandler, + OrJsonPlugin, + PydanticHandler, + SimpleTypeHandler, + StandardJsonPlugin, + UJsonPlugin, ) @@ -34,13 +35,17 @@ class JsonTransportEncoder: """JSON Transport Encoder""" def __init__( - self, parameter_types: List[Type] = None, maximum_depth: int = 100, strict_validation: bool = True, **kwargs + self, + parameter_types: list[type] | None = None, + maximum_depth: int = 100, + strict_validation: bool = True, + **kwargs, ): self.parameter_types = parameter_types or [] self.maximum_depth = maximum_depth self.strict_validation = strict_validation self.registry = SimpleRegistry() - self.json_plugins = [] + self.json_plugins: list[Any] = [] # Setup plugins self._register_default_type_plugins() @@ -83,20 +88,20 @@ def register_type_provider(self, provider): """Register custom type provider for backward compatibility""" self.registry.register_plugin(provider) - def encode(self, arguments: tuple, parameter_type: list = None) -> bytes: + def encode(self, arguments: tuple, parameter_type: list[type] | None = None) -> bytes: """Encode arguments with flexible parameter handling""" try: if not arguments: return self._serialize_to_json_bytes([]) # Handle single parameter case - if len(parameter_type) == 1: + if parameter_type and len(parameter_type) == 1: parameter = arguments[0] serialized_param = self._serialize_object(parameter) return self._serialize_to_json_bytes(serialized_param) # Handle multiple parameters - elif len(parameter_type) > 1: + 
elif parameter_type and len(parameter_type) > 1: # Try Pydantic wrapper for strong typing pydantic_handler = self._get_pydantic_handler() if pydantic_handler and pydantic_handler.available: @@ -139,9 +144,11 @@ def _serialize_object(self, obj: Any, depth: int = 0) -> Any: if depth > self.maximum_depth: raise SerializationException(f"Maximum depth {self.maximum_depth} exceeded") + # Handle primitives if obj is None or isinstance(obj, (bool, int, float, str)): return obj + # Handle collections if isinstance(obj, (list, tuple)): return [self._serialize_object(item, depth + 1) for item in obj] @@ -176,10 +183,9 @@ def _serialize_to_json_bytes(self, obj: Any) -> bytes: """Use the first available JSON plugin to serialize""" last_error = None - for i, plugin in enumerate(self.json_plugins): + for plugin in self.json_plugins: try: - result = plugin.encode(obj) - return result + return plugin.encode(obj) except Exception as e: last_error = e continue @@ -190,9 +196,9 @@ def _serialize_to_json_bytes(self, obj: Any) -> bytes: class JsonTransportDecoder: """JSON Transport Decoder""" - def __init__(self, target_type: Union[Type, List[Type]] = None, **kwargs): + def __init__(self, target_type: type | list[type] | None = None, **kwargs): self.target_type = target_type - self.json_plugins = [] + self.json_plugins: list[Any] = [] self._setup_json_deserializer_plugins() # Handle multiple parameter types @@ -228,8 +234,7 @@ def decode(self, data: bytes) -> Any: json_data = self._deserialize_from_json_bytes(data) reconstructed_data = self._reconstruct_objects(json_data) - # CRITICAL FIX: If reconstructed_data is a list with a single item - # and we expect a single target type, extract it + # Handle single-item list unpacking if target type expects one value if ( isinstance(reconstructed_data, list) and len(reconstructed_data) == 1 @@ -237,14 +242,11 @@ def decode(self, data: bytes) -> Any: and not isinstance(self.target_type, list) ): single_item = reconstructed_data[0] - - # Check if 
the single item is already our target type if isinstance(single_item, self.target_type): return single_item - # Otherwise continue with normal processing using the extracted item reconstructed_data = single_item - # Also handle the case where we have a list target type but receive a list with single target + # Handle [single] target_type inside a list elif ( isinstance(reconstructed_data, list) and len(reconstructed_data) == 1 @@ -253,7 +255,6 @@ def decode(self, data: bytes) -> Any: ): single_item = reconstructed_data[0] target_type = self.target_type[0] - if isinstance(single_item, target_type): return single_item @@ -264,23 +265,16 @@ def decode(self, data: bytes) -> Any: if self.multiple_parameter_mode and hasattr(self, "parameter_wrapper_model"): try: wrapper_instance = self.parameter_wrapper_model(**reconstructed_data) - result = tuple( - getattr(wrapper_instance, f"param_{i}") for i in range(len(self.parameter_types)) - ) - return result - except Exception as e: + return tuple(getattr(wrapper_instance, f"param_{i}") for i in range(len(self.parameter_types))) + except Exception: pass - # For single target type in list, decode to that type - if len(self.parameter_types) > 0: - target_type = self.parameter_types[0] - result = self._decode_to_target_type(reconstructed_data, target_type) - return result - else: - return reconstructed_data + # Decode to first type if available + if self.parameter_types: + return self._decode_to_target_type(reconstructed_data, self.parameter_types[0]) + return reconstructed_data else: - result = self._decode_to_target_type(reconstructed_data, self.target_type) - return result + return self._decode_to_target_type(reconstructed_data, self.target_type) except Exception as e: raise DeserializationException(f"Decoding failed: {e}") from e @@ -297,38 +291,30 @@ def _deserialize_from_json_bytes(self, data: bytes) -> Any: raise DeserializationException(f"All JSON plugins failed. 
Last error: {last_error}") - def _decode_to_target_type(self, json_data: Any, target_type: Type) -> Any: + def _decode_to_target_type(self, json_data: Any, target_type: type) -> Any: """Convert JSON data to target type with proper Pydantic handling""" - # If we already have the right type, return it immediately if isinstance(json_data, target_type): return json_data - # Check if target type is a Pydantic model + # Special handling for Pydantic models try: from pydantic import BaseModel if isinstance(target_type, type) and issubclass(target_type, BaseModel): - # If json_data is already a Pydantic model instance of the target type if isinstance(json_data, target_type): return json_data - - # If json_data is a dict, try to construct the Pydantic model elif isinstance(json_data, dict): return target_type(**json_data) - - # If json_data is a list with one element, extract it elif isinstance(json_data, list) and len(json_data) == 1: return self._decode_to_target_type(json_data[0], target_type) - - # If json_data is a list of dicts, try the first one - elif isinstance(json_data, list) and len(json_data) > 0 and isinstance(json_data[0], dict): + elif isinstance(json_data, list) and isinstance(json_data[0], dict): return self._decode_to_target_type(json_data[0], target_type) except ImportError: pass - # Handle basic types + # Handle built-in simple types if target_type in (str, int, float, bool, list, dict): return target_type(json_data) @@ -339,51 +325,40 @@ def _reconstruct_objects(self, data: Any) -> Any: if not isinstance(data, dict): if isinstance(data, list): - result = [self._reconstruct_objects(item) for item in data] - return result + return [self._reconstruct_objects(item) for item in data] return data # Handle special serialized objects if "__datetime__" in data: from datetime import datetime - dt = datetime.fromisoformat(data["__datetime__"]) - return dt + return datetime.fromisoformat(data["__datetime__"]) elif "__date__" in data: from datetime import date - d 
= date.fromisoformat(data["__date__"]) - return d + return date.fromisoformat(data["__date__"]) elif "__time__" in data: from datetime import time - t = time.fromisoformat(data["__time__"]) - return t + return time.fromisoformat(data["__time__"]) elif "__decimal__" in data: from decimal import Decimal - dec = Decimal(data["__decimal__"]) - return dec + return Decimal(data["__decimal__"]) elif "__set__" in data: - s = set(self._reconstruct_objects(item) for item in data["__set__"]) - return s + return set(self._reconstruct_objects(item) for item in data["__set__"]) elif "__frozenset__" in data: - fs = frozenset(self._reconstruct_objects(item) for item in data["__frozenset__"]) - return fs + return frozenset(self._reconstruct_objects(item) for item in data["__frozenset__"]) elif "__uuid__" in data: from uuid import UUID - u = UUID(data["__uuid__"]) - return u + return UUID(data["__uuid__"]) elif "__path__" in data: from pathlib import Path - p = Path(data["__path__"]) - return p + return Path(data["__path__"]) elif "__pydantic_model__" in data and "__model_data__" in data: - # Properly reconstruct Pydantic models - result = self._reconstruct_pydantic_model(data) - return result + return self._reconstruct_pydantic_model(data) elif "__dataclass__" in data: module_name, class_name = data["__dataclass__"].rsplit(".", 1) import importlib @@ -391,45 +366,32 @@ def _reconstruct_objects(self, data: Any) -> Any: module = importlib.import_module(module_name) cls = getattr(module, class_name) fields = self._reconstruct_objects(data["fields"]) - result = cls(**fields) - return result + return cls(**fields) elif "__enum__" in data: module_name, class_name = data["__enum__"].rsplit(".", 1) import importlib module = importlib.import_module(module_name) cls = getattr(module, class_name) - result = cls(data["value"]) - return result + return cls(data["value"]) else: - result = {key: self._reconstruct_objects(value) for key, value in data.items()} - return result + return {key: 
self._reconstruct_objects(value) for key, value in data.items()} def _reconstruct_pydantic_model(self, data: dict) -> Any: """Reconstruct a Pydantic model from serialized data""" try: model_path = data["__pydantic_model__"] model_data = data["__model_data__"] - - # Try to import and reconstruct the model module_name, class_name = model_path.rsplit(".", 1) - # Import the module import importlib module = importlib.import_module(module_name) model_class = getattr(module, class_name) - # Recursively reconstruct nested objects in model_data reconstructed_data = self._reconstruct_objects(model_data) - - # Create the model instance - result = model_class(**reconstructed_data) - return result - - except Exception as e: - # If reconstruction fails, return the model data as dict - # This allows the target type conversion to handle it + return model_class(**reconstructed_data) + except Exception: return self._reconstruct_objects(data.get("__model_data__", {})) @@ -438,8 +400,8 @@ class JsonTransportCodec: def __init__( self, - parameter_types: List[Type] = None, - return_type: Type = None, + parameter_types: list[type] | None = None, + return_type: type | None = None, maximum_depth: int = 100, strict_validation: bool = True, **kwargs, @@ -450,13 +412,16 @@ def __init__( self.strict_validation = strict_validation self._encoder = JsonTransportEncoder( - parameter_types=parameter_types, maximum_depth=maximum_depth, strict_validation=strict_validation, **kwargs + parameter_types=parameter_types, + maximum_depth=maximum_depth, + strict_validation=strict_validation, + **kwargs, ) self._decoder = JsonTransportDecoder(target_type=return_type, **kwargs) - def encode_parameters(self, *arguments, parameter_type: list = None) -> bytes: + def encode_parameters(self, *arguments, parameter_type: list[type] | None = None) -> bytes: """Encode parameters - supports both positional and keyword args""" - return self._encoder.encode(arguments) + return self._encoder.encode(arguments, 
parameter_type=parameter_type) def decode_return_value(self, data: bytes) -> Any: """Decode return value""" diff --git a/src/dubbo/codec/json_codec/json_transport_base.py b/src/dubbo/codec/json_codec/json_transport_base.py index d1ead52..165eb6c 100644 --- a/src/dubbo/codec/json_codec/json_transport_base.py +++ b/src/dubbo/codec/json_codec/json_transport_base.py @@ -1,4 +1,3 @@ -# # Licensed to the Apache Software Foundation (ASF) under one or more # contributor license agreements. See the NOTICE file distributed with # this work for additional information regarding copyright ownership. @@ -14,7 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, List, Dict, Protocol, Callable +from typing import Any, Callable, Protocol class JsonSerializerPlugin(Protocol): @@ -37,8 +36,8 @@ class SimpleRegistry: def __init__(self): # Simple dict mapping: type -> handler function - self.type_handlers: Dict[type, Callable] = {} - self.plugins: List[TypeHandlerPlugin] = [] + self.type_handlers: dict[type, Callable] = {} + self.plugins: list[TypeHandlerPlugin] = [] def register_type_handler(self, obj_type: type, handler: Callable): """Register a simple type handler function""" diff --git a/src/dubbo/codec/json_codec/json_transport_codec.py b/src/dubbo/codec/json_codec/json_transport_codec.py index 3c1d05c..2c2999e 100644 --- a/src/dubbo/codec/json_codec/json_transport_codec.py +++ b/src/dubbo/codec/json_codec/json_transport_codec.py @@ -8,21 +8,21 @@ # # http://www.apache.org/licenses/LICENSE-2.0 # -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
+# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. import json -from typing import Any, Type, List, Union, Dict -from datetime import datetime, date, time +from dataclasses import asdict, is_dataclass +from datetime import date, datetime, time from decimal import Decimal +from enum import Enum from pathlib import Path +from typing import Any, Union from uuid import UUID -from enum import Enum -from dataclasses import is_dataclass, asdict -from typing import Any, Dict, List class StandardJsonPlugin: @@ -136,7 +136,7 @@ class DateTimeHandler: def can_serialize_type(self, obj: Any, obj_type: type) -> bool: return isinstance(obj, (datetime, date, time)) - def serialize_to_dict(self, obj: Union[datetime, date, time]) -> Dict[str, str]: + def serialize_to_dict(self, obj: Union[datetime, date, time]) -> dict[str, str]: if isinstance(obj, datetime): return {"__datetime__": obj.isoformat(), "__timezone__": str(obj.tzinfo) if obj.tzinfo else None} elif isinstance(obj, date): @@ -149,7 +149,7 @@ class DecimalHandler: def can_serialize_type(self, obj: Any, obj_type: type) -> bool: return obj_type is Decimal - def serialize_to_dict(self, obj: Decimal) -> Dict[str, str]: + def serialize_to_dict(self, obj: Decimal) -> dict[str, str]: return {"__decimal__": str(obj)} @@ -157,7 +157,7 @@ class CollectionHandler: def can_serialize_type(self, obj: Any, obj_type: type) -> bool: return obj_type in (set, frozenset) - def serialize_to_dict(self, obj: Union[set, frozenset]) -> Dict[str, List]: + def serialize_to_dict(self, obj: Union[set, frozenset]) -> dict[str, list]: return {"__frozenset__" if isinstance(obj, frozenset) else "__set__": list(obj)} @@ -167,7 +167,7 @@ class EnumHandler: def can_serialize_type(self, obj: 
Any, obj_type: type) -> bool: return isinstance(obj, Enum) - def serialize_to_dict(self, obj: Enum) -> Dict[str, Any]: + def serialize_to_dict(self, obj: Enum) -> dict[str, Any]: return {"__enum__": f"{obj.__class__.__module__}.{obj.__class__.__qualname__}", "value": obj.value} @@ -177,7 +177,7 @@ class DataclassHandler: def can_serialize_type(self, obj: Any, obj_type: type) -> bool: return is_dataclass(obj) - def serialize_to_dict(self, obj: Any) -> Dict[str, Any]: + def serialize_to_dict(self, obj: Any) -> dict[str, Any]: return {"__dataclass__": f"{obj.__class__.__module__}.{obj.__class__.__qualname__}", "fields": asdict(obj)} @@ -185,7 +185,7 @@ class SimpleTypeHandler: def can_serialize_type(self, obj: Any, obj_type: type) -> bool: return obj_type in (UUID, Path) or isinstance(obj, Path) - def serialize_to_dict(self, obj: Union[UUID, Path]) -> Dict[str, str]: + def serialize_to_dict(self, obj: Union[UUID, Path]) -> dict[str, str]: if isinstance(obj, UUID): return {"__uuid__": str(obj)} elif isinstance(obj, Path): @@ -208,7 +208,7 @@ def __init__(self): def can_serialize_type(self, obj: Any, obj_type: type) -> bool: return self.available and isinstance(obj, self.BaseModel) - def serialize_to_dict(self, obj) -> Dict[str, Any]: + def serialize_to_dict(self, obj) -> dict[str, Any]: if hasattr(obj, "model_dump"): return { "__pydantic_model__": f"{obj.__class__.__module__}.{obj.__class__.__qualname__}", @@ -219,7 +219,7 @@ def serialize_to_dict(self, obj) -> Dict[str, Any]: "__model_data__": obj.dict(), } - def create_parameter_model(self, parameter_types: List[Type]): + def create_parameter_model(self, parameter_types: list[type]): """Enhanced parameter handling for both positional and keyword args""" if not self.available: return None diff --git a/src/dubbo/codec/protobuf_codec/protobuf_codec_handler.py b/src/dubbo/codec/protobuf_codec/protobuf_codec_handler.py index b8554a0..d1c4774 100644 --- a/src/dubbo/codec/protobuf_codec/protobuf_codec_handler.py +++ 
b/src/dubbo/codec/protobuf_codec/protobuf_codec_handler.py @@ -8,16 +8,16 @@ # # http://www.apache.org/licenses/LICENSE-2.0 # -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" +# BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Type, Protocol, Dict, Optional -from dataclasses import dataclass import json from abc import ABC, abstractmethod +from dataclasses import dataclass +from typing import Any, Optional, Protocol # Betterproto imports try: @@ -56,9 +56,9 @@ def __call__(self, data: bytes) -> Any: ... class ProtobufMethodDescriptor: """Protobuf-specific method descriptor for single parameter""" - parameter_type: Type - return_type: Type - protobuf_message_type: Type = None + parameter_type: type + return_type: type + protobuf_message_type: type | None = None # Abstract base classes for pluggable architecture @@ -66,7 +66,7 @@ class TypeHandler(ABC): """Abstract base class for type handlers""" @abstractmethod - def is_message(self, obj_type: Type) -> bool: + def is_message(self, obj_type: type) -> bool: """Check if the type is a message type""" pass @@ -76,7 +76,7 @@ def is_message_instance(self, obj: Any) -> bool: pass @abstractmethod - def is_compatible(self, obj_type: Type) -> bool: + def is_compatible(self, obj_type: type) -> bool: """Check if the type is compatible with this handler""" pass @@ -85,12 +85,12 @@ class EncodingStrategy(ABC): """Abstract base class for encoding strategies""" @abstractmethod - def can_encode(self, parameter: Any, parameter_type: Type = None) -> bool: + def can_encode(self, 
parameter: Any, parameter_type: type | None = None) -> bool: """Check if this strategy can encode the given parameter""" pass @abstractmethod - def encode(self, parameter: Any, parameter_type: Type = None) -> bytes: + def encode(self, parameter: Any, parameter_type: type | None = None) -> bytes: """Encode the parameter to bytes""" pass @@ -99,12 +99,12 @@ class DecodingStrategy(ABC): """Abstract base class for decoding strategies""" @abstractmethod - def can_decode(self, data: bytes, target_type: Type) -> bool: + def can_decode(self, data: bytes, target_type: type) -> bool: """Check if this strategy can decode to the target type""" pass @abstractmethod - def decode(self, data: bytes, target_type: Type) -> Any: + def decode(self, data: bytes, target_type: type) -> Any: """Decode the bytes to the target type""" pass @@ -113,7 +113,7 @@ def decode(self, data: bytes, target_type: Type) -> Any: class ProtobufTypeHandler(TypeHandler): """Handles type conversion between Python types and Betterproto""" - def is_message(self, obj_type: Type) -> bool: + def is_message(self, obj_type: type) -> bool: if not HAS_BETTERPROTO: return False try: @@ -126,12 +126,12 @@ def is_message_instance(self, obj: Any) -> bool: return False return isinstance(obj, betterproto.Message) - def is_compatible(self, obj_type: Type) -> bool: + def is_compatible(self, obj_type: type) -> bool: return obj_type in (str, int, float, bool, bytes) or self.is_message(obj_type) # Static methods for backward compatibility @staticmethod - def is_betterproto_message(obj_type: Type) -> bool: + def is_betterproto_message(obj_type: type) -> bool: handler = ProtobufTypeHandler() return handler.is_message(obj_type) @@ -141,7 +141,7 @@ def is_betterproto_message_instance(obj: Any) -> bool: return handler.is_message_instance(obj) @staticmethod - def is_protobuf_compatible(obj_type: Type) -> bool: + def is_protobuf_compatible(obj_type: type) -> bool: handler = ProtobufTypeHandler() return handler.is_compatible(obj_type) 
@@ -152,12 +152,12 @@ class MessageEncodingStrategy(EncodingStrategy): def __init__(self, type_handler: TypeHandler): self.type_handler = type_handler - def can_encode(self, parameter: Any, parameter_type: Type = None) -> bool: + def can_encode(self, parameter: Any, parameter_type: type | None = None) -> bool: return self.type_handler.is_message_instance(parameter) or ( parameter_type and self.type_handler.is_message(parameter_type) ) - def encode(self, parameter: Any, parameter_type: Type = None) -> bytes: + def encode(self, parameter: Any, parameter_type: type | None = None) -> bytes: if self.type_handler.is_message_instance(parameter): return bytes(parameter) @@ -179,10 +179,10 @@ def encode(self, parameter: Any, parameter_type: Type = None) -> bytes: class PrimitiveEncodingStrategy(EncodingStrategy): """Encoding strategy for primitive types""" - def can_encode(self, parameter: Any, parameter_type: Type = None) -> bool: + def can_encode(self, parameter: Any, parameter_type: type | None = None) -> bool: return isinstance(parameter, (str, int, float, bool, bytes)) - def encode(self, parameter: Any, parameter_type: Type = None) -> bytes: + def encode(self, parameter: Any, parameter_type: type | None = None) -> bytes: try: json_str = json.dumps({"value": parameter, "type": type(parameter).__name__}) return json_str.encode("utf-8") @@ -196,10 +196,10 @@ class MessageDecodingStrategy(DecodingStrategy): def __init__(self, type_handler: TypeHandler): self.type_handler = type_handler - def can_decode(self, data: bytes, target_type: Type) -> bool: + def can_decode(self, data: bytes, target_type: type) -> bool: return self.type_handler.is_message(target_type) - def decode(self, data: bytes, target_type: Type) -> Any: + def decode(self, data: bytes, target_type: type) -> Any: try: return target_type().parse(data) except Exception as e: @@ -209,24 +209,24 @@ def decode(self, data: bytes, target_type: Type) -> Any: class PrimitiveDecodingStrategy(DecodingStrategy): 
"""Decoding strategy for primitive types""" - def can_decode(self, data: bytes, target_type: Type) -> bool: + def can_decode(self, data: bytes, target_type: type) -> bool: return target_type in (str, int, float, bool, bytes) - def decode(self, data: bytes, target_type: Type) -> Any: + def decode(self, data: bytes, target_type: type) -> Any: try: json_str = data.decode("utf-8") parsed = json.loads(json_str) value = parsed.get("value") - if target_type == str: + if target_type is str: return str(value) - elif target_type == int: + elif target_type is int: return int(value) - elif target_type == float: + elif target_type is float: return float(value) - elif target_type == bool: + elif target_type is bool: return bool(value) - elif target_type == bytes: + elif target_type is bytes: return bytes(value) if isinstance(value, (list, bytes)) else str(value).encode() else: return value @@ -250,14 +250,14 @@ def register_decoding_strategy(self, strategy: DecodingStrategy) -> None: """Register a decoding strategy""" self.decoding_strategies.append(strategy) - def find_encoding_strategy(self, parameter: Any, parameter_type: Type = None) -> Optional[EncodingStrategy]: + def find_encoding_strategy(self, parameter: Any, parameter_type: type | None = None) -> Optional[EncodingStrategy]: """Find the first strategy that can encode the parameter""" for strategy in self.encoding_strategies: if strategy.can_encode(parameter, parameter_type): return strategy return None - def find_decoding_strategy(self, data: bytes, target_type: Type) -> Optional[DecodingStrategy]: + def find_decoding_strategy(self, data: bytes, target_type: type) -> Optional[DecodingStrategy]: """Find the first strategy that can decode to the target type""" for strategy in self.decoding_strategies: if strategy.can_decode(data, target_type): @@ -270,7 +270,7 @@ class ProtobufTransportEncoder: def __init__( self, - parameter_type: Type = None, + parameter_type: type = None, type_handler: TypeHandler = None, 
strategy_registry: StrategyRegistry = None, **kwargs, @@ -290,7 +290,7 @@ def _create_default_registry(self) -> StrategyRegistry: registry.register_encoding_strategy(PrimitiveEncodingStrategy()) return registry - def encode(self, parameter: Any, parameter_type: Type) -> bytes: + def encode(self, parameter: Any, parameter_type: type) -> bytes: try: if parameter is None: return b"" @@ -328,7 +328,7 @@ class ProtobufTransportDecoder: def __init__( self, - target_type: Type = None, + target_type: type = None, type_handler: TypeHandler = None, strategy_registry: StrategyRegistry = None, **kwargs, @@ -362,7 +362,7 @@ def decode(self, data: bytes) -> Any: except Exception as e: raise DeserializationException(f"Protobuf decoding failed: {e}") from e - def _decode_single_parameter(self, data: bytes, target_type: Type) -> Any: + def _decode_single_parameter(self, data: bytes, target_type: type) -> Any: strategy = self.strategy_registry.find_decoding_strategy(data, target_type) if strategy: return strategy.decode(data, target_type) @@ -370,7 +370,7 @@ def _decode_single_parameter(self, data: bytes, target_type: Type) -> Any: raise DeserializationException(f"No decoding strategy found for {target_type}") # Backward compatibility method - def _decode_primitive(self, data: bytes, target_type: Type) -> Any: + def _decode_primitive(self, data: bytes, target_type: type) -> Any: strategy = PrimitiveDecodingStrategy() return strategy.decode(data, target_type) @@ -380,8 +380,8 @@ class ProtobufTransportCodec: def __init__( self, - parameter_type: Type = None, - return_type: Type = None, + parameter_type: type = None, + return_type: type = None, type_handler: TypeHandler = None, encoder_registry: StrategyRegistry = None, decoder_registry: StrategyRegistry = None, diff --git a/src/dubbo/extension/registries.py b/src/dubbo/extension/registries.py index ee52650..bed9d0c 100644 --- a/src/dubbo/extension/registries.py +++ b/src/dubbo/extension/registries.py @@ -15,22 +15,18 @@ # limitations 
under the License. import importlib -from typing import Any from dataclasses import dataclass +from typing import Any -from dubbo.classes import SingletonBase -from dubbo.extension import registries as registries_module +from dubbo.classes import Codec, SingletonBase # Import all the required interface classes -from dataclasses import dataclass -from typing import Any - from dubbo.cluster import LoadBalance from dubbo.compression import Compressor, Decompressor +from dubbo.extension import registries as registries_module from dubbo.protocol import Protocol from dubbo.registry import RegistryFactory from dubbo.remoting import Transporter -from dubbo.classes import Codec class ExtensionError(Exception): diff --git a/src/dubbo/proxy/handlers.py b/src/dubbo/proxy/handlers.py index 7c72516..ce6f879 100644 --- a/src/dubbo/proxy/handlers.py +++ b/src/dubbo/proxy/handlers.py @@ -15,7 +15,8 @@ # limitations under the License. import inspect -from typing import Callable, Optional, List, Type, Any, get_type_hints +from typing import Any, Callable, Optional, get_type_hints + from dubbo.classes import MethodDescriptor from dubbo.codec import DubboTransportService from dubbo.types import ( @@ -76,7 +77,7 @@ def _infer_types_from_method(cls, method: Callable) -> tuple: :param method: the method to analyze :type method: Callable :return: tuple of method name, parameter types, return type - :rtype: Tuple[str, List[Type], Type] + :rtype: Tuple[str, list[type], type] """ try: type_hints = get_type_hints(method) @@ -105,8 +106,8 @@ def _create_method_descriptor( cls, method: Callable, method_name: str, - params_types: List[Type], - return_type: Type, + params_types: list[type], + return_type: type, rpc_type: str, codec: Optional[str] = None, param_encoder: Optional[DeserializingFunction] = None, @@ -151,8 +152,8 @@ def unary( cls, method: Callable, method_name: Optional[str] = None, - params_types: Optional[List[Type]] = None, - return_type: Optional[Type] = None, + params_types: 
Optional[list[type]] = None, + return_type: Optional[type] = None, codec: Optional[str] = None, request_deserializer: Optional[DeserializingFunction] = None, response_serializer: Optional[SerializingFunction] = None, @@ -160,15 +161,6 @@ def unary( ) -> "RpcMethodHandler": """ Register a unary RPC method handler - :param method: the callable function - :param method_name: RPC method name - :param params_types: input types - :param return_type: output type - :param codec: serialization codec - :param request_deserializer: custom deserializer - :param response_serializer: custom serializer - :return: RpcMethodHandler instance - :rtype: RpcMethodHandler """ inferred_name, inferred_param_types, inferred_return_type = cls._infer_types_from_method(method) resolved_method_name = method_name or inferred_name @@ -195,8 +187,8 @@ def client_stream( cls, method: Callable, method_name: Optional[str] = None, - params_types: Optional[List[Type]] = None, - return_type: Optional[Type] = None, + params_types: Optional[list[type]] = None, + return_type: Optional[type] = None, codec: Optional[str] = None, request_deserializer: Optional[DeserializingFunction] = None, response_serializer: Optional[SerializingFunction] = None, @@ -204,15 +196,6 @@ def client_stream( ) -> "RpcMethodHandler": """ Register a client-streaming RPC method handler - :param method: the callable function - :param method_name: RPC method name - :param params_types: input types - :param return_type: output type - :param codec: serialization codec - :param request_deserializer: custom deserializer - :param response_serializer: custom serializer - :return: RpcMethodHandler instance - :rtype: RpcMethodHandler """ inferred_name, inferred_param_types, inferred_return_type = cls._infer_types_from_method(method) resolved_method_name = method_name or inferred_name @@ -239,8 +222,8 @@ def server_stream( cls, method: Callable, method_name: Optional[str] = None, - params_types: Optional[List[Type]] = None, - return_type: 
Optional[Type] = None, + params_types: Optional[list[type]] = None, + return_type: Optional[type] = None, codec: Optional[str] = None, request_deserializer: Optional[DeserializingFunction] = None, response_serializer: Optional[SerializingFunction] = None, @@ -248,15 +231,6 @@ def server_stream( ) -> "RpcMethodHandler": """ Register a server-streaming RPC method handler - :param method: the callable function - :param method_name: RPC method name - :param params_types: input types - :param return_type: output type - :param codec: serialization codec - :param request_deserializer: custom deserializer - :param response_serializer: custom serializer - :return: RpcMethodHandler instance - :rtype: RpcMethodHandler """ inferred_name, inferred_param_types, inferred_return_type = cls._infer_types_from_method(method) resolved_method_name = method_name or inferred_name @@ -283,8 +257,8 @@ def bi_stream( cls, method: Callable, method_name: Optional[str] = None, - params_types: Optional[List[Type]] = None, - return_type: Optional[Type] = None, + params_types: Optional[list[type]] = None, + return_type: Optional[type] = None, codec: Optional[str] = None, request_deserializer: Optional[DeserializingFunction] = None, response_serializer: Optional[SerializingFunction] = None, @@ -292,15 +266,6 @@ def bi_stream( ) -> "RpcMethodHandler": """ Register a bidirectional streaming RPC method handler - :param method: the callable function - :param method_name: RPC method name - :param params_types: input types - :param return_type: output type - :param codec: serialization codec - :param request_deserializer: custom deserializer - :param response_serializer: custom serializer - :return: RpcMethodHandler instance - :rtype: RpcMethodHandler """ inferred_name, inferred_param_types, inferred_return_type = cls._infer_types_from_method(method) resolved_method_name = method_name or inferred_name @@ -330,13 +295,9 @@ class RpcServiceHandler: __slots__ = ["_service_name", "_method_handlers"] - def 
__init__(self, service_name: str, method_handlers: List[RpcMethodHandler]): + def __init__(self, service_name: str, method_handlers: list[RpcMethodHandler]): """ Initialize the RpcServiceHandler - :param service_name: the name of the service. - :type service_name: str - :param method_handlers: list of RpcMethodHandler instances - :type method_handlers: List[RpcMethodHandler] """ self._service_name = service_name self._method_handlers: dict[str, RpcMethodHandler] = {} @@ -347,18 +308,10 @@ def __init__(self, service_name: str, method_handlers: List[RpcMethodHandler]): @property def service_name(self) -> str: - """ - Get the service name - :return: the service name - :rtype: str - """ + """Get the service name""" return self._service_name @property def method_handlers(self) -> dict[str, RpcMethodHandler]: - """ - Get the registered RPC method handlers - :return: mapping of method names to handlers - :rtype: Dict[str, RpcMethodHandler] - """ + """Get the registered RPC method handlers""" return self._method_handlers diff --git a/src/dubbo/remoting/aio/http2/protocol.py b/src/dubbo/remoting/aio/http2/protocol.py index a762d0a..c99876d 100644 --- a/src/dubbo/remoting/aio/http2/protocol.py +++ b/src/dubbo/remoting/aio/http2/protocol.py @@ -26,7 +26,6 @@ from dubbo.loggers import loggerFactory from dubbo.remoting.aio import ConnectionStateListener, EmptyConnectionStateListener, constants as h2_constants from dubbo.remoting.aio.exceptions import ProtocolError -from dubbo.remoting.aio.http2.stream_handler import StreamServerMultiplexHandler from dubbo.remoting.aio.http2.controllers import RemoteFlowController from dubbo.remoting.aio.http2.frames import ( DataFrame, @@ -39,6 +38,7 @@ ) from dubbo.remoting.aio.http2.registries import Http2FrameType from dubbo.remoting.aio.http2.stream import Http2Stream +from dubbo.remoting.aio.http2.stream_handler import StreamServerMultiplexHandler from dubbo.remoting.aio.http2.utils import Http2EventUtils from dubbo.url import URL from 
dubbo.utils import EventHelper, FutureHelper diff --git a/tests/json/json_test.py b/tests/json/json_test.py index 9fe588f..5bc7b9d 100644 --- a/tests/json/json_test.py +++ b/tests/json/json_test.py @@ -14,14 +14,16 @@ # See the License for the specific language governing permissions and # limitations under the License. -import pytest -from pathlib import Path -from uuid import UUID -from decimal import Decimal -from datetime import datetime, date, time from dataclasses import dataclass +from datetime import date, datetime, time +from decimal import Decimal from enum import Enum +from pathlib import Path +from uuid import UUID + +import pytest from pydantic import BaseModel + from dubbo.codec.json_codec import JsonTransportCodec diff --git a/tests/json/json_type_test.py b/tests/json/json_type_test.py index 3f146d0..8aeedf4 100644 --- a/tests/json/json_type_test.py +++ b/tests/json/json_type_test.py @@ -14,13 +14,14 @@ # See the License for the specific language governing permissions and # limitations under the License. -import pytest -from pathlib import Path -from uuid import UUID -from decimal import Decimal -from datetime import datetime, date, time from dataclasses import dataclass +from datetime import date, datetime, time +from decimal import Decimal from enum import Enum +from pathlib import Path +from uuid import UUID + +import pytest from pydantic import BaseModel from dubbo.codec.json_codec.json_codec_handler import JsonTransportCodec diff --git a/tests/protobuf/protobuf_test.py b/tests/protobuf/protobuf_test.py index 95e65bb..6c331c5 100644 --- a/tests/protobuf/protobuf_test.py +++ b/tests/protobuf/protobuf_test.py @@ -14,12 +14,10 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-import pytest -from dubbo.codec.protobuf_codec import ProtobufTransportCodec, ProtobufTransportDecoder - -print(type(ProtobufTransportCodec)) from generated.protobuf_test import GreeterReply, GreeterRequest +from dubbo.codec.protobuf_codec import ProtobufTransportCodec, ProtobufTransportDecoder + def test_protobuf_roundtrip_message(): codec = ProtobufTransportCodec(parameter_type=GreeterRequest, return_type=GreeterReply) From 4f914b577d8bab9acd8b0d3ab5c77788fcee6ef8 Mon Sep 17 00:00:00 2001 From: Aditya Yadav <166515021+aditya0yadav@users.noreply.github.com> Date: Thu, 4 Sep 2025 11:36:21 +0000 Subject: [PATCH 19/40] fixing the minor bugs --- src/dubbo/client.py | 20 +-- src/dubbo/codec/__init__.py | 4 +- src/dubbo/codec/dubbo_codec.py | 52 +++++--- .../codec/json_codec/json_transport_base.py | 10 +- .../codec/json_codec/json_transport_codec.py | 15 ++- .../protobuf_codec/protobuf_codec_handler.py | 120 +++++++----------- src/dubbo/proxy/handlers.py | 4 +- 7 files changed, 111 insertions(+), 114 deletions(-) diff --git a/src/dubbo/client.py b/src/dubbo/client.py index e0159ce..229de59 100644 --- a/src/dubbo/client.py +++ b/src/dubbo/client.py @@ -15,11 +15,10 @@ # limitations under the License. 
import threading -from typing import Optional +from typing import Optional, List, Type from dubbo.bootstrap import Dubbo from dubbo.classes import MethodDescriptor -from dubbo.codec import DubboTransportService from dubbo.configs import ReferenceConfig from dubbo.constants import common_constants from dubbo.extension import extensionLoader @@ -33,6 +32,7 @@ SerializingFunction, ) from dubbo.url import URL +from dubbo.codec import DubboSerializationService __all__ = ["Client"] @@ -88,8 +88,8 @@ def _create_rpc_callable( self, rpc_type: str, method_name: str, - params_types: list[type], - return_type: type, + params_types: List[Type], + return_type: Type, codec: Optional[str] = None, request_serializer: Optional[SerializingFunction] = None, response_deserializer: Optional[DeserializingFunction] = None, @@ -97,12 +97,13 @@ def _create_rpc_callable( """ Create RPC callable with the specified type. """ + print("2", params_types) # Determine serializers if request_serializer and response_deserializer: req_ser = request_serializer res_deser = response_deserializer else: - req_ser, res_deser = DubboTransportService.create_serialization_functions( + req_ser, res_deser = DubboSerializationService.create_serialization_functions( codec or "json", parameter_types=params_types, return_type=return_type, @@ -118,7 +119,8 @@ def _create_rpc_callable( return self._callable(descriptor) - def unary(self, method_name: str, params_types: list[type], return_type: type, **kwargs) -> RpcCallable: + def unary(self, method_name: str, params_types: List[Type], return_type: Type, **kwargs) -> RpcCallable: + print("1", params_types) return self._create_rpc_callable( rpc_type=RpcTypes.UNARY.value, method_name=method_name, @@ -127,7 +129,7 @@ def unary(self, method_name: str, params_types: list[type], return_type: type, * **kwargs, ) - def client_stream(self, method_name: str, params_types: list[type], return_type: type, **kwargs) -> RpcCallable: + def client_stream(self, method_name: str, 
params_types: List[Type], return_type: Type, **kwargs) -> RpcCallable: return self._create_rpc_callable( rpc_type=RpcTypes.CLIENT_STREAM.value, method_name=method_name, @@ -136,7 +138,7 @@ def client_stream(self, method_name: str, params_types: list[type], return_type: **kwargs, ) - def server_stream(self, method_name: str, params_types: list[type], return_type: type, **kwargs) -> RpcCallable: + def server_stream(self, method_name: str, params_types: List[Type], return_type: Type, **kwargs) -> RpcCallable: return self._create_rpc_callable( rpc_type=RpcTypes.SERVER_STREAM.value, method_name=method_name, @@ -145,7 +147,7 @@ def server_stream(self, method_name: str, params_types: list[type], return_type: **kwargs, ) - def bi_stream(self, method_name: str, params_types: list[type], return_type: type, **kwargs) -> RpcCallable: + def bi_stream(self, method_name: str, params_types: List[Type], return_type: Type, **kwargs) -> RpcCallable: return self._create_rpc_callable( rpc_type=RpcTypes.BI_STREAM.value, method_name=method_name, diff --git a/src/dubbo/codec/__init__.py b/src/dubbo/codec/__init__.py index 72e6f3b..c3061bc 100644 --- a/src/dubbo/codec/__init__.py +++ b/src/dubbo/codec/__init__.py @@ -14,6 +14,6 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-from .dubbo_codec import DubboTransportService +from .dubbo_codec import DubboSerializationService -__all__ = ["DubboTransportService"] +__all__ = ["DubboSerializationService"] diff --git a/src/dubbo/codec/dubbo_codec.py b/src/dubbo/codec/dubbo_codec.py index 41d602b..f127a16 100644 --- a/src/dubbo/codec/dubbo_codec.py +++ b/src/dubbo/codec/dubbo_codec.py @@ -24,7 +24,7 @@ @dataclass class ParameterDescriptor: - """Detailed information about a method parameter""" + """Information about a method parameter""" name: str annotation: Any @@ -34,7 +34,7 @@ class ParameterDescriptor: @dataclass class MethodDescriptor: - """Complete method descriptor with all necessary information""" + """Method descriptor with function details""" function: Callable name: str @@ -44,13 +44,16 @@ class MethodDescriptor: class DubboSerializationService: - """Dubbo serialization service with robust type handling""" + """Dubbo serialization service with type handling""" @staticmethod def create_transport_codec( - transport_type: str = "json", parameter_types: list[type] = None, return_type: type = None, **codec_options + transport_type: str = "json", + parameter_types: Optional[list[type]] = None, + return_type: Optional[type] = None, + **codec_options, ): - """Create transport codec with enhanced parameter structure""" + """Create transport codec""" try: from dubbo.classes import CodecHelper @@ -67,13 +70,19 @@ def create_transport_codec( @staticmethod def create_encoder_decoder_pair( - transport_type: str, parameter_types: list[type] = None, return_type: type = None, **codec_options + transport_type: str, + parameter_types: Optional[list[type]] = None, + return_type: Optional[type] = None, + **codec_options, ) -> tuple[Any, Any]: - """Create separate encoder and decoder instances""" + """Create encoder and decoder instances""" try: codec_instance = DubboSerializationService.create_transport_codec( - transport_type=transport_type, parameter_types=parameter_types, return_type=return_type, 
**codec_options + transport_type=transport_type, + parameter_types=parameter_types, + return_type=return_type, + **codec_options, ) encoder = codec_instance.get_encoder() @@ -90,13 +99,19 @@ def create_encoder_decoder_pair( @staticmethod def create_serialization_functions( - transport_type: str, parameter_types: list[type] = None, return_type: type = None, **codec_options + transport_type: str, + parameter_types: Optional[list[type]] = None, + return_type: Optional[type] = None, + **codec_options, ) -> tuple[Callable, Callable]: - """Create serializer and deserializer functions for RPC (backward compatibility)""" + """Create serializer and deserializer functions""" try: parameter_encoder, return_decoder = DubboSerializationService.create_encoder_decoder_pair( - transport_type=transport_type, parameter_types=parameter_types, return_type=return_type, **codec_options + transport_type=transport_type, + parameter_types=parameter_types, + return_type=return_type, + **codec_options, ) def serialize_method_parameters(*args) -> bytes: @@ -125,9 +140,9 @@ def deserialize_method_return(data: bytes): def create_method_descriptor( func: Callable, method_name: Optional[str] = None, - parameter_types: list[type] = None, - return_type: type = None, - interface: Callable = None, + parameter_types: Optional[list[type]] = None, + return_type: Optional[type] = None, + interface: Optional[Callable[..., Any]] = None, ) -> MethodDescriptor: """Create a method descriptor from function and configuration""" @@ -135,7 +150,7 @@ def create_method_descriptor( raise TypeError("func must be callable") # Use interface signature if provided, otherwise use func signature - target_function = interface if interface else func + target_function = interface if interface is not None else func name = method_name or target_function.__name__ try: @@ -166,14 +181,17 @@ def create_method_descriptor( parameters.append( ParameterDescriptor( - name=param_name, annotation=param_type, is_required=is_required, 
default_value=default_value + name=param_name, + annotation=param_type, + is_required=is_required, + default_value=default_value, ) ) param_index += 1 # Resolve return type - if return_type: + if return_type is not None: resolved_return_type = return_type elif sig.return_annotation != inspect.Signature.empty: resolved_return_type = sig.return_annotation diff --git a/src/dubbo/codec/json_codec/json_transport_base.py b/src/dubbo/codec/json_codec/json_transport_base.py index 165eb6c..a62be03 100644 --- a/src/dubbo/codec/json_codec/json_transport_base.py +++ b/src/dubbo/codec/json_codec/json_transport_base.py @@ -13,7 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Callable, Protocol +from typing import Any, Callable, Protocol, Optional class JsonSerializerPlugin(Protocol): @@ -36,18 +36,18 @@ class SimpleRegistry: def __init__(self): # Simple dict mapping: type -> handler function - self.type_handlers: dict[type, Callable] = {} + self.type_handlers: dict[type, Callable[..., Any]] = {} self.plugins: list[TypeHandlerPlugin] = [] - def register_type_handler(self, obj_type: type, handler: Callable): + def register_type_handler(self, obj_type: type, handler: Callable[..., Any]) -> None: """Register a simple type handler function""" self.type_handlers[obj_type] = handler - def register_plugin(self, plugin: TypeHandlerPlugin): + def register_plugin(self, plugin: TypeHandlerPlugin) -> None: """Register a plugin""" self.plugins.append(plugin) - def get_handler(self, obj: Any) -> Callable: + def get_handler(self, obj: Any) -> Optional[Callable[..., Any]]: """Get handler for object - check dict first, then plugins""" obj_type = type(obj) if obj_type in self.type_handlers: diff --git a/src/dubbo/codec/json_codec/json_transport_codec.py b/src/dubbo/codec/json_codec/json_transport_codec.py index 2c2999e..075efef 100644 --- a/src/dubbo/codec/json_codec/json_transport_codec.py +++ 
b/src/dubbo/codec/json_codec/json_transport_codec.py @@ -23,10 +23,11 @@ from pathlib import Path from typing import Any, Union from uuid import UUID +from typing import Optional class StandardJsonPlugin: - """Standard library JSON plugin""" + """Standard library JSON codec""" def encode(self, obj: Any) -> bytes: return json.dumps(obj, ensure_ascii=False, separators=(",", ":")).encode("utf-8") @@ -39,7 +40,7 @@ def can_handle(self, obj: Any) -> bool: class OrJsonPlugin: - """orjson plugin independent implementation""" + """orjson Codec independent implementation""" def __init__(self): try: @@ -131,12 +132,12 @@ def _default_handler(self, obj): class DateTimeHandler: - """DateTime handler - implements TypeHandlerPlugin protocol""" + """DateTime handler - implements TypeHandler Codec""" def can_serialize_type(self, obj: Any, obj_type: type) -> bool: return isinstance(obj, (datetime, date, time)) - def serialize_to_dict(self, obj: Union[datetime, date, time]) -> dict[str, str]: + def serialize_to_dict(self, obj: Union[datetime, date, time]) -> dict[str, str | None]: if isinstance(obj, datetime): return {"__datetime__": obj.isoformat(), "__timezone__": str(obj.tzinfo) if obj.tzinfo else None} elif isinstance(obj, date): @@ -193,7 +194,7 @@ def serialize_to_dict(self, obj: Union[UUID, Path]) -> dict[str, str]: class PydanticHandler: - """Separate Pydantic plugin with enhanced features""" + """Pydantic codec for handling advance serialization""" def __init__(self): try: @@ -219,9 +220,9 @@ def serialize_to_dict(self, obj) -> dict[str, Any]: "__model_data__": obj.dict(), } - def create_parameter_model(self, parameter_types: list[type]): + def create_parameter_model(self, parameter_types: Optional[list[type]] = None): """Enhanced parameter handling for both positional and keyword args""" - if not self.available: + if not self.available or parameter_types is None: return None model_fields = {} diff --git a/src/dubbo/codec/protobuf_codec/protobuf_codec_handler.py 
b/src/dubbo/codec/protobuf_codec/protobuf_codec_handler.py index d1c4774..53acd35 100644 --- a/src/dubbo/codec/protobuf_codec/protobuf_codec_handler.py +++ b/src/dubbo/codec/protobuf_codec/protobuf_codec_handler.py @@ -31,7 +31,7 @@ class SerializationException(Exception): """Exception raised when encoding or serialization fails.""" - def __init__(self, message: str, *, cause: Exception = None): + def __init__(self, message: str, *, cause: Optional[Exception] = None): super().__init__(message) self.cause = cause @@ -39,7 +39,7 @@ def __init__(self, message: str, *, cause: Exception = None): class DeserializationException(Exception): """Exception raised when decoding or deserialization fails.""" - def __init__(self, message: str, *, cause: Exception = None): + def __init__(self, message: str, *, cause: Optional[Exception] = None): super().__init__(message) self.cause = cause @@ -56,9 +56,9 @@ def __call__(self, data: bytes) -> Any: ... class ProtobufMethodDescriptor: """Protobuf-specific method descriptor for single parameter""" - parameter_type: type - return_type: type - protobuf_message_type: type | None = None + parameter_type: Optional[type] + return_type: Optional[type] + protobuf_message_type: Optional[type] = None # Abstract base classes for pluggable architecture @@ -66,47 +66,33 @@ class TypeHandler(ABC): """Abstract base class for type handlers""" @abstractmethod - def is_message(self, obj_type: type) -> bool: - """Check if the type is a message type""" - pass + def is_message(self, obj_type: type) -> bool: ... @abstractmethod - def is_message_instance(self, obj: Any) -> bool: - """Check if the object is a message instance""" - pass + def is_message_instance(self, obj: Any) -> bool: ... @abstractmethod - def is_compatible(self, obj_type: type) -> bool: - """Check if the type is compatible with this handler""" - pass + def is_compatible(self, obj_type: type) -> bool: ... 
class EncodingStrategy(ABC): """Abstract base class for encoding strategies""" @abstractmethod - def can_encode(self, parameter: Any, parameter_type: type | None = None) -> bool: - """Check if this strategy can encode the given parameter""" - pass + def can_encode(self, parameter: Any, parameter_type: Optional[type] = None) -> bool: ... @abstractmethod - def encode(self, parameter: Any, parameter_type: type | None = None) -> bytes: - """Encode the parameter to bytes""" - pass + def encode(self, parameter: Any, parameter_type: Optional[type] = None) -> bytes: ... class DecodingStrategy(ABC): """Abstract base class for decoding strategies""" @abstractmethod - def can_decode(self, data: bytes, target_type: type) -> bool: - """Check if this strategy can decode to the target type""" - pass + def can_decode(self, data: bytes, target_type: type) -> bool: ... @abstractmethod - def decode(self, data: bytes, target_type: type) -> Any: - """Decode the bytes to the target type""" - pass + def decode(self, data: bytes, target_type: type) -> Any: ... 
# Concrete implementations @@ -152,12 +138,12 @@ class MessageEncodingStrategy(EncodingStrategy): def __init__(self, type_handler: TypeHandler): self.type_handler = type_handler - def can_encode(self, parameter: Any, parameter_type: type | None = None) -> bool: + def can_encode(self, parameter: Any, parameter_type: Optional[type] = None) -> bool: return self.type_handler.is_message_instance(parameter) or ( - parameter_type and self.type_handler.is_message(parameter_type) + parameter_type is not None and self.type_handler.is_message(parameter_type) ) - def encode(self, parameter: Any, parameter_type: type | None = None) -> bytes: + def encode(self, parameter: Any, parameter_type: Optional[type] = None) -> bytes: if self.type_handler.is_message_instance(parameter): return bytes(parameter) @@ -179,10 +165,10 @@ def encode(self, parameter: Any, parameter_type: type | None = None) -> bytes: class PrimitiveEncodingStrategy(EncodingStrategy): """Encoding strategy for primitive types""" - def can_encode(self, parameter: Any, parameter_type: type | None = None) -> bool: + def can_encode(self, parameter: Any, parameter_type: Optional[type] = None) -> bool: return isinstance(parameter, (str, int, float, bool, bytes)) - def encode(self, parameter: Any, parameter_type: type | None = None) -> bytes: + def encode(self, parameter: Any, parameter_type: Optional[type] = None) -> bytes: try: json_str = json.dumps({"value": parameter, "type": type(parameter).__name__}) return json_str.encode("utf-8") @@ -227,7 +213,12 @@ def decode(self, data: bytes, target_type: type) -> Any: elif target_type is bool: return bool(value) elif target_type is bytes: - return bytes(value) if isinstance(value, (list, bytes)) else str(value).encode() + if isinstance(value, bytes): + return value + elif isinstance(value, list): + return bytes(value) + else: + return str(value).encode() else: return value @@ -243,22 +234,20 @@ def __init__(self): self.decoding_strategies: list[DecodingStrategy] = [] def 
register_encoding_strategy(self, strategy: EncodingStrategy) -> None: - """Register an encoding strategy""" self.encoding_strategies.append(strategy) def register_decoding_strategy(self, strategy: DecodingStrategy) -> None: - """Register a decoding strategy""" self.decoding_strategies.append(strategy) - def find_encoding_strategy(self, parameter: Any, parameter_type: type | None = None) -> Optional[EncodingStrategy]: - """Find the first strategy that can encode the parameter""" + def find_encoding_strategy( + self, parameter: Any, parameter_type: Optional[type] = None + ) -> Optional[EncodingStrategy]: for strategy in self.encoding_strategies: if strategy.can_encode(parameter, parameter_type): return strategy return None def find_decoding_strategy(self, data: bytes, target_type: type) -> Optional[DecodingStrategy]: - """Find the first strategy that can decode to the target type""" for strategy in self.decoding_strategies: if strategy.can_decode(data, target_type): return strategy @@ -270,9 +259,9 @@ class ProtobufTransportEncoder: def __init__( self, - parameter_type: type = None, - type_handler: TypeHandler = None, - strategy_registry: StrategyRegistry = None, + parameter_type: Optional[type] = None, + type_handler: Optional[TypeHandler] = None, + strategy_registry: Optional[StrategyRegistry] = None, **kwargs, ): if not HAS_BETTERPROTO: @@ -284,40 +273,39 @@ def __init__( self.strategy_registry = strategy_registry or self._create_default_registry() def _create_default_registry(self) -> StrategyRegistry: - """Create default strategy registry with standard strategies""" registry = StrategyRegistry() registry.register_encoding_strategy(MessageEncodingStrategy(self.type_handler)) registry.register_encoding_strategy(PrimitiveEncodingStrategy()) return registry - def encode(self, parameter: Any, parameter_type: type) -> bytes: + def encode(self, parameter: Any, parameter_type: Optional[type] = None) -> bytes: try: if parameter is None: return b"" + effective_type = 
parameter_type or self.parameter_type + if isinstance(parameter, tuple): if len(parameter) == 0: return b"" elif len(parameter) == 1: - return self._encode_single_parameter(parameter[0]) + return self._encode_single_parameter(parameter[0], effective_type) else: raise SerializationException( f"Multiple parameters not supported. Got tuple with {len(parameter)} elements, expected 1." ) - return self._encode_single_parameter(parameter) + return self._encode_single_parameter(parameter, effective_type) except Exception as e: raise SerializationException(f"Protobuf encoding failed: {e}") from e - def _encode_single_parameter(self, parameter: Any) -> bytes: - strategy = self.strategy_registry.find_encoding_strategy(parameter, self.parameter_type) + def _encode_single_parameter(self, parameter: Any, parameter_type: Optional[type]) -> bytes: + strategy = self.strategy_registry.find_encoding_strategy(parameter, parameter_type) if strategy: - return strategy.encode(parameter, self.parameter_type) - + return strategy.encode(parameter, parameter_type) raise SerializationException(f"No encoding strategy found for {type(parameter)}") - # Backward compatibility method def _encode_primitive(self, value: Any) -> bytes: strategy = PrimitiveEncodingStrategy() return strategy.encode(value) @@ -328,22 +316,19 @@ class ProtobufTransportDecoder: def __init__( self, - target_type: type = None, - type_handler: TypeHandler = None, - strategy_registry: StrategyRegistry = None, + target_type: Optional[type] = None, + type_handler: Optional[TypeHandler] = None, + strategy_registry: Optional[StrategyRegistry] = None, **kwargs, ): if not HAS_BETTERPROTO: raise ImportError("betterproto library is required for ProtobufTransportDecoder") self.target_type = target_type - - # Use provided components or create defaults self.type_handler = type_handler or ProtobufTypeHandler() self.strategy_registry = strategy_registry or self._create_default_registry() def _create_default_registry(self) -> 
StrategyRegistry: - """Create default strategy registry with standard strategies""" registry = StrategyRegistry() registry.register_decoding_strategy(MessageDecodingStrategy(self.type_handler)) registry.register_decoding_strategy(PrimitiveDecodingStrategy()) @@ -353,12 +338,9 @@ def decode(self, data: bytes) -> Any: try: if not data: return None - if not self.target_type: raise DeserializationException("No target_type specified for decoding") - return self._decode_single_parameter(data, self.target_type) - except Exception as e: raise DeserializationException(f"Protobuf decoding failed: {e}") from e @@ -366,10 +348,8 @@ def _decode_single_parameter(self, data: bytes, target_type: type) -> Any: strategy = self.strategy_registry.find_decoding_strategy(data, target_type) if strategy: return strategy.decode(data, target_type) - raise DeserializationException(f"No decoding strategy found for {target_type}") - # Backward compatibility method def _decode_primitive(self, data: bytes, target_type: type) -> Any: strategy = PrimitiveDecodingStrategy() return strategy.decode(data, target_type) @@ -380,17 +360,16 @@ class ProtobufTransportCodec: def __init__( self, - parameter_type: type = None, - return_type: type = None, - type_handler: TypeHandler = None, - encoder_registry: StrategyRegistry = None, - decoder_registry: StrategyRegistry = None, + parameter_type: Optional[type] = None, + return_type: Optional[type] = None, + type_handler: Optional[TypeHandler] = None, + encoder_registry: Optional[StrategyRegistry] = None, + decoder_registry: Optional[StrategyRegistry] = None, **kwargs, ): if not HAS_BETTERPROTO: raise ImportError("betterproto library is required for ProtobufTransportCodec") - # Allow sharing registries between encoder and decoder, or use separate ones shared_registry = encoder_registry or decoder_registry self._encoder = ProtobufTransportEncoder( @@ -407,17 +386,14 @@ def __init__( ) def encode_parameter(self, argument: Any) -> bytes: - return 
self._encoder.encode(argument) + return self._encoder.encode(argument, self._encoder.parameter_type) def encode_parameters(self, arguments: tuple) -> bytes: if not arguments: return b"" if len(arguments) == 1: - return self._encoder.encode(arguments[0]) - else: - raise SerializationException( - f"Multiple parameters not supported. Got {len(arguments)} arguments, expected 1." - ) + return self._encoder.encode(arguments[0], self._encoder.parameter_type) + raise SerializationException(f"Multiple parameters not supported. Got {len(arguments)} arguments, expected 1.") def decode_return_value(self, data: bytes) -> Any: return self._decoder.decode(data) diff --git a/src/dubbo/proxy/handlers.py b/src/dubbo/proxy/handlers.py index ce6f879..140cc31 100644 --- a/src/dubbo/proxy/handlers.py +++ b/src/dubbo/proxy/handlers.py @@ -18,7 +18,7 @@ from typing import Any, Callable, Optional, get_type_hints from dubbo.classes import MethodDescriptor -from dubbo.codec import DubboTransportService +from dubbo.codec import DubboSerializationService from dubbo.types import ( DeserializingFunction, RpcTypes, @@ -68,7 +68,7 @@ def get_codec(**kwargs) -> tuple: :return: serializer and deserializer functions :rtype: Tuple[SerializingFunction, DeserializingFunction] """ - return DubboTransportService.create_serialization_functions(**kwargs) + return DubboSerializationService.create_serialization_functions(**kwargs) @classmethod def _infer_types_from_method(cls, method: Callable) -> tuple: From e12f69e0102b0b762649095337d4c1b470851b05 Mon Sep 17 00:00:00 2001 From: Aditya Yadav <166515021+aditya0yadav@users.noreply.github.com> Date: Thu, 4 Sep 2025 11:46:07 +0000 Subject: [PATCH 20/40] remove some debug statement --- src/dubbo/client.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/dubbo/client.py b/src/dubbo/client.py index 229de59..64140df 100644 --- a/src/dubbo/client.py +++ b/src/dubbo/client.py @@ -120,7 +120,6 @@ def _create_rpc_callable( return self._callable(descriptor) def 
unary(self, method_name: str, params_types: List[Type], return_type: Type, **kwargs) -> RpcCallable: - print("1", params_types) return self._create_rpc_callable( rpc_type=RpcTypes.UNARY.value, method_name=method_name, From 3424ccba3880b9bd81d7dabf11750b2086e64097 Mon Sep 17 00:00:00 2001 From: Aditya Yadav <166515021+aditya0yadav@users.noreply.github.com> Date: Thu, 4 Sep 2025 21:16:29 +0000 Subject: [PATCH 21/40] clean the code and make it easy to understand --- src/dubbo/client.py | 3 +- src/dubbo/codec/json_codec/__init__.py | 31 +- src/dubbo/codec/json_codec/_interfaces.py | 98 ++++ .../codec/json_codec/collections_handler.py | 57 +++ .../codec/json_codec/dataclass_handler.py | 55 +++ .../codec/json_codec/datetime_handler.py | 61 +++ src/dubbo/codec/json_codec/decimal_handler.py | 55 +++ src/dubbo/codec/json_codec/enum_handler.py | 55 +++ src/dubbo/codec/json_codec/json_codec.py | 130 +++++ .../codec/json_codec/json_codec_handler.py | 467 +++++++----------- .../codec/json_codec/json_transport_base.py | 71 --- .../codec/json_codec/json_transport_codec.py | 231 --------- src/dubbo/codec/json_codec/orjson_codec.py | 104 ++++ .../codec/json_codec/pydantic_handler.py | 90 ++++ .../codec/json_codec/simple_types_handler.py | 60 +++ src/dubbo/codec/json_codec/standard_json.py | 65 +++ src/dubbo/codec/json_codec/ujson_codec.py | 104 ++++ src/dubbo/extension/registries.py | 115 ++--- 18 files changed, 1189 insertions(+), 663 deletions(-) create mode 100644 src/dubbo/codec/json_codec/_interfaces.py create mode 100644 src/dubbo/codec/json_codec/collections_handler.py create mode 100644 src/dubbo/codec/json_codec/dataclass_handler.py create mode 100644 src/dubbo/codec/json_codec/datetime_handler.py create mode 100644 src/dubbo/codec/json_codec/decimal_handler.py create mode 100644 src/dubbo/codec/json_codec/enum_handler.py create mode 100644 src/dubbo/codec/json_codec/json_codec.py delete mode 100644 src/dubbo/codec/json_codec/json_transport_base.py delete mode 100644 
src/dubbo/codec/json_codec/json_transport_codec.py create mode 100644 src/dubbo/codec/json_codec/orjson_codec.py create mode 100644 src/dubbo/codec/json_codec/pydantic_handler.py create mode 100644 src/dubbo/codec/json_codec/simple_types_handler.py create mode 100644 src/dubbo/codec/json_codec/standard_json.py create mode 100644 src/dubbo/codec/json_codec/ujson_codec.py diff --git a/src/dubbo/client.py b/src/dubbo/client.py index 64140df..4cbfb56 100644 --- a/src/dubbo/client.py +++ b/src/dubbo/client.py @@ -97,14 +97,13 @@ def _create_rpc_callable( """ Create RPC callable with the specified type. """ - print("2", params_types) # Determine serializers if request_serializer and response_deserializer: req_ser = request_serializer res_deser = response_deserializer else: req_ser, res_deser = DubboSerializationService.create_serialization_functions( - codec or "json", + codec, parameter_types=params_types, return_type=return_type, ) diff --git a/src/dubbo/codec/json_codec/__init__.py b/src/dubbo/codec/json_codec/__init__.py index ea68a77..55458d5 100644 --- a/src/dubbo/codec/json_codec/__init__.py +++ b/src/dubbo/codec/json_codec/__init__.py @@ -14,6 +14,33 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-from .json_codec_handler import JsonTransportCodec, JsonTransportDecoder, JsonTransportEncoder +from ._interfaces import JsonCodec, TypeHandler +from .standard_json import StandardJsonCodec +from .orjson_codec import OrJsonCodec +from .ujson_codec import UJsonCodec +from .datetime_handler import DateTimeHandler +from .pydantic_handler import PydanticHandler +from .collections_handler import CollectionHandler +from .decimal_handler import DecimalHandler +from .simple_types_handler import SimpleTypesHandler +from .enum_handler import EnumHandler +from .dataclass_handler import DataclassHandler +from .json_codec_handler import JsonTransportCodec +from .json_codec import JsonTransportCodecBridge -__all__ = ["JsonTransportCodec", "JsonTransportDecoder", "JsonTransportEncoder"] +__all__ = [ + "JsonCodec", + "TypeHandler", + "StandardJsonCodec", + "OrJsonCodec", + "UJsonCodec", + "DateTimeHandler", + "PydanticHandler", + "CollectionHandler", + "DecimalHandler", + "SimpleTypesHandler", + "EnumHandler", + "DataclassHandler", + "JsonTransportCodec", + "JsonTransportCodecBridge", +] diff --git a/src/dubbo/codec/json_codec/_interfaces.py b/src/dubbo/codec/json_codec/_interfaces.py new file mode 100644 index 0000000..4a3a36e --- /dev/null +++ b/src/dubbo/codec/json_codec/_interfaces.py @@ -0,0 +1,98 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +import abc +from typing import Any, Dict + +__all__ = ["JsonCodec", "TypeHandler"] + + +class JsonCodec(abc.ABC): + """ + The JSON codec interface for encoding and decoding objects to/from JSON bytes. + """ + + @abc.abstractmethod + def encode(self, obj: Any) -> bytes: + """ + Encode an object to JSON bytes. + + :param obj: The object to encode. + :type obj: Any + :return: The encoded JSON bytes. + :rtype: bytes + """ + raise NotImplementedError() + + @abc.abstractmethod + def decode(self, data: bytes) -> Any: + """ + Decode JSON bytes to an object. + + :param data: The JSON bytes to decode. + :type data: bytes + :return: The decoded object. + :rtype: Any + """ + raise NotImplementedError() + + @abc.abstractmethod + def can_handle(self, obj: Any) -> bool: + """ + Check if this codec can handle the given object. + + :param obj: The object to check. + :type obj: Any + :return: True if this codec can handle the object. + :rtype: bool + """ + raise NotImplementedError() + + +class TypeHandler(abc.ABC): + """ + Base interface for all type-specific serializers. + + Each type handler should implement: + - can_serialize_type: determine if the object can be serialized + - serialize_to_dict: return a dict representation of the object + """ + + @abc.abstractmethod + def can_serialize_type(self, obj: Any, obj_type: type) -> bool: + """ + Returns True if this handler can serialize the given object/type. + + :param obj: The object to check. + :type obj: Any + :param obj_type: The type of the object. + :type obj_type: type + :return: True if this handler can serialize the object. + :rtype: bool + """ + raise NotImplementedError() + + @abc.abstractmethod + def serialize_to_dict(self, obj: Any) -> Dict[str, Any]: + """ + Serialize the object into a dictionary representation. + + :param obj: The object to serialize. 
+ :type obj: Any + :return: The dictionary representation of the object. + :rtype: Dict[str, Any] + """ + raise NotImplementedError() diff --git a/src/dubbo/codec/json_codec/collections_handler.py b/src/dubbo/codec/json_codec/collections_handler.py new file mode 100644 index 0000000..ff18764 --- /dev/null +++ b/src/dubbo/codec/json_codec/collections_handler.py @@ -0,0 +1,57 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Any, Dict, Union + +from dubbo.codec.json_codec import TypeHandler + +__all__ = ["CollectionHandler"] + + +class CollectionHandler(TypeHandler): + """ + Type handler for set and frozenset collections. + + Serializes sets and frozensets to list format with type markers + for proper reconstruction. + """ + + def can_serialize_type(self, obj: Any, obj_type: type) -> bool: + """ + Check if this handler can serialize collection types. + + :param obj: The object to check. + :type obj: Any + :param obj_type: The type of the object. + :type obj_type: type + :return: True if object is set or frozenset. + :rtype: bool + """ + return obj_type in (set, frozenset) + + def serialize_to_dict(self, obj: Union[set, frozenset]) -> Dict[str, list]: + """ + Serialize set/frozenset to dictionary representation. 
+ + :param obj: The collection to serialize. + :type obj: Union[set, frozenset] + :return: Dictionary representation with type marker. + :rtype: Dict[str, list] + """ + if isinstance(obj, frozenset): + return {"__frozenset__": list(obj)} + else: + return {"__set__": list(obj)} diff --git a/src/dubbo/codec/json_codec/dataclass_handler.py b/src/dubbo/codec/json_codec/dataclass_handler.py new file mode 100644 index 0000000..23fd226 --- /dev/null +++ b/src/dubbo/codec/json_codec/dataclass_handler.py @@ -0,0 +1,55 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import asdict, is_dataclass +from typing import Any, Dict + +from dubbo.codec.json_codec import TypeHandler + +__all__ = ["DataclassHandler"] + + +class DataclassHandler(TypeHandler): + """ + Type handler for dataclass objects. + + Serializes dataclass instances with module path and field data + for proper reconstruction. + """ + + def can_serialize_type(self, obj: Any, obj_type: type) -> bool: + """ + Check if this handler can serialize dataclass types. + + :param obj: The object to check. + :type obj: Any + :param obj_type: The type of the object. + :type obj_type: type + :return: True if object is a dataclass instance. 
+ :rtype: bool + """ + return is_dataclass(obj) + + def serialize_to_dict(self, obj: Any) -> Dict[str, Any]: + """ + Serialize dataclass to dictionary representation. + + :param obj: The dataclass to serialize. + :type obj: Any + :return: Dictionary with class path and field data. + :rtype: Dict[str, Any] + """ + return {"__dataclass__": f"{obj.__class__.__module__}.{obj.__class__.__qualname__}", "fields": asdict(obj)} diff --git a/src/dubbo/codec/json_codec/datetime_handler.py b/src/dubbo/codec/json_codec/datetime_handler.py new file mode 100644 index 0000000..51eb416 --- /dev/null +++ b/src/dubbo/codec/json_codec/datetime_handler.py @@ -0,0 +1,61 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from datetime import datetime, date, time +from typing import Any, Dict, Union + +from dubbo.codec.json_codec import TypeHandler + +__all__ = ["DateTimeHandler"] + + +class DateTimeHandler(TypeHandler): + """ + Type handler for datetime, date, and time objects. + + Serializes datetime objects to ISO format with timezone information. + """ + + def can_serialize_type(self, obj: Any, obj_type: type) -> bool: + """ + Check if this handler can serialize datetime-related types. + + :param obj: The object to check. 
+ :type obj: Any + :param obj_type: The type of the object. + :type obj_type: type + :return: True if object is datetime, date, or time. + :rtype: bool + """ + return isinstance(obj, (datetime, date, time)) + + def serialize_to_dict(self, obj: Union[datetime, date, time]) -> Dict[str, Any]: + """ + Serialize datetime objects to dictionary representation. + + :param obj: The datetime object to serialize. + :type obj: Union[datetime, date, time] + :return: Dictionary representation with type markers. + :rtype: Dict[str, Any] + """ + if isinstance(obj, datetime): + return {"__datetime__": obj.isoformat(), "__timezone__": str(obj.tzinfo) if obj.tzinfo else None} + elif isinstance(obj, date): + return {"__date__": obj.isoformat()} + elif isinstance(obj, time): + return {"__time__": obj.isoformat()} + else: + raise ValueError(f"Unsupported datetime type: {type(obj)}") diff --git a/src/dubbo/codec/json_codec/decimal_handler.py b/src/dubbo/codec/json_codec/decimal_handler.py new file mode 100644 index 0000000..3ea87b1 --- /dev/null +++ b/src/dubbo/codec/json_codec/decimal_handler.py @@ -0,0 +1,55 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from decimal import Decimal +from typing import Any, Dict + +from dubbo.codec.json_codec import TypeHandler + +__all__ = ["DecimalHandler"] + + +class DecimalHandler(TypeHandler): + """ + Type handler for Decimal objects. + + Serializes Decimal objects to string representation + for precision preservation. + """ + + def can_serialize_type(self, obj: Any, obj_type: type) -> bool: + """ + Check if this handler can serialize Decimal types. + + :param obj: The object to check. + :type obj: Any + :param obj_type: The type of the object. + :type obj_type: type + :return: True if object is Decimal. + :rtype: bool + """ + return obj_type is Decimal + + def serialize_to_dict(self, obj: Decimal) -> Dict[str, str]: + """ + Serialize Decimal to dictionary representation. + + :param obj: The Decimal to serialize. + :type obj: Decimal + :return: Dictionary representation with string value. + :rtype: Dict[str, str] + """ + return {"__decimal__": str(obj)} diff --git a/src/dubbo/codec/json_codec/enum_handler.py b/src/dubbo/codec/json_codec/enum_handler.py new file mode 100644 index 0000000..2cd4a6b --- /dev/null +++ b/src/dubbo/codec/json_codec/enum_handler.py @@ -0,0 +1,55 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from enum import Enum +from typing import Any, Dict + +from dubbo.codec.json_codec import TypeHandler + +__all__ = ["EnumHandler"] + + +class EnumHandler(TypeHandler): + """ + Type handler for Enum objects. + + Serializes Enum instances with module path and value + for proper reconstruction. + """ + + def can_serialize_type(self, obj: Any, obj_type: type) -> bool: + """ + Check if this handler can serialize Enum types. + + :param obj: The object to check. + :type obj: Any + :param obj_type: The type of the object. + :type obj_type: type + :return: True if object is an Enum instance. + :rtype: bool + """ + return isinstance(obj, Enum) + + def serialize_to_dict(self, obj: Enum) -> Dict[str, Any]: + """ + Serialize Enum to dictionary representation. + + :param obj: The Enum to serialize. + :type obj: Enum + :return: Dictionary with enum class path and value. + :rtype: Dict[str, Any] + """ + return {"__enum__": f"{obj.__class__.__module__}.{obj.__class__.__qualname__}", "value": obj.value} diff --git a/src/dubbo/codec/json_codec/json_codec.py b/src/dubbo/codec/json_codec/json_codec.py new file mode 100644 index 0000000..f881e42 --- /dev/null +++ b/src/dubbo/codec/json_codec/json_codec.py @@ -0,0 +1,130 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Any, List, Optional, Type +from dubbo.codec.json_codec import JsonTransportCodec + +__all__ = ["JsonTransportCodecBridge", "JsonParameterEncoder", "JsonReturnDecoder"] + + +class JsonParameterEncoder: + """ + Parameter encoder wrapper for JsonTransportCodec. + """ + + def __init__(self, codec: JsonTransportCodec): + self._codec = codec + + def encode(self, arguments: tuple) -> bytes: + """ + Encode method parameters. + + :param arguments: The method arguments to encode. + :type arguments: tuple + :return: Encoded parameter bytes. + :rtype: bytes + """ + return self._codec.encode_parameters(*arguments) + + +class JsonReturnDecoder: + """ + Return value decoder wrapper for JsonTransportCodec. + """ + + def __init__(self, codec: JsonTransportCodec): + self._codec = codec + + def decode(self, data: bytes) -> Any: + """ + Decode method return value. + + :param data: The bytes to decode. + :type data: bytes + :return: Decoded return value. + :rtype: Any + """ + return self._codec.decode_return_value(data) + + +class JsonTransportCodecBridge: + """ + Bridge class that adapts JsonTransportCodec to work with DubboSerializationService. + + This maintains compatibility with the existing extension loader system while + using the clean new codec architecture internally. + """ + + def __init__( + self, + parameter_types: Optional[List[Type]] = None, + return_type: Optional[Type] = None, + maximum_depth: int = 100, + strict_validation: bool = True, + **kwargs, + ): + """ + Initialize the codec bridge. + + :param parameter_types: List of parameter types for the method. + :param return_type: Return type for the method. + :param maximum_depth: Maximum serialization depth. + :param strict_validation: Whether to use strict validation. 
+ """ + self._codec = JsonTransportCodec( + parameter_types=parameter_types, + return_type=return_type, + maximum_depth=maximum_depth, + strict_validation=strict_validation, + **kwargs, + ) + + # Create encoder and decoder instances + self._encoder = JsonParameterEncoder(self._codec) + self._decoder = JsonReturnDecoder(self._codec) + + def encoder(self) -> JsonParameterEncoder: + """ + Get the parameter encoder instance. + + :return: The parameter encoder. + :rtype: JsonParameterEncoder + """ + return self._encoder + + def decoder(self) -> JsonReturnDecoder: + """ + Get the return value decoder instance. + + :return: The return value decoder. + :rtype: JsonReturnDecoder + """ + return self._decoder + + # Direct access methods for convenience + def encode_parameters(self, *arguments) -> bytes: + """Direct parameter encoding.""" + return self._codec.encode_parameters(*arguments) + + def decode_return_value(self, data: bytes) -> Any: + """Direct return value decoding.""" + return self._codec.decode_return_value(data) + + # Properties for access to internal codec if needed + @property + def codec(self) -> JsonTransportCodec: + """Access to the underlying codec.""" + return self._codec diff --git a/src/dubbo/codec/json_codec/json_codec_handler.py b/src/dubbo/codec/json_codec/json_codec_handler.py index c2f0252..15df8bf 100644 --- a/src/dubbo/codec/json_codec/json_codec_handler.py +++ b/src/dubbo/codec/json_codec/json_codec_handler.py @@ -14,137 +14,172 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-from typing import Any, Optional - -from .json_transport_base import DeserializationException, SerializationException, SimpleRegistry -from .json_transport_codec import ( - CollectionHandler, - DataclassHandler, +from typing import Any, Type, List, Optional +from dubbo.codec.json_codec import ( + JsonCodec, + TypeHandler, + StandardJsonCodec, + OrJsonCodec, + UJsonCodec, DateTimeHandler, + PydanticHandler, + CollectionHandler, DecimalHandler, + SimpleTypesHandler, EnumHandler, - OrJsonPlugin, - PydanticHandler, - SimpleTypeHandler, - StandardJsonPlugin, - UJsonPlugin, + DataclassHandler, ) +__all__ = ["JsonTransportCodec", "SerializationException", "DeserializationException"] + + +class SerializationException(Exception): + """Exception raised during serialization""" + + pass -class JsonTransportEncoder: - """JSON Transport Encoder""" + +class DeserializationException(Exception): + """Exception raised during deserialization""" + + pass + + +class JsonTransportCodec: + """ + JSON Transport Codec + """ def __init__( self, - parameter_types: list[type] | None = None, + parameter_types: Optional[List[Type]] = None, + return_type: Optional[Type] = None, maximum_depth: int = 100, strict_validation: bool = True, **kwargs, ): self.parameter_types = parameter_types or [] + self.return_type = return_type self.maximum_depth = maximum_depth self.strict_validation = strict_validation - self.registry = SimpleRegistry() - self.json_plugins: list[Any] = [] - - # Setup plugins - self._register_default_type_plugins() - self._setup_json_serializer_plugins() - - def _register_default_type_plugins(self): - """Register default type handler plugins""" - default_plugins = [ - DateTimeHandler(), - DecimalHandler(), - CollectionHandler(), - SimpleTypeHandler(), - DataclassHandler(), - EnumHandler(), - ] - - # Add Pydantic plugin if available - pydantic_plugin = PydanticHandler() - if pydantic_plugin.available: - default_plugins.append(pydantic_plugin) - - for plugin in default_plugins: - 
self.registry.register_plugin(plugin) - - def _setup_json_serializer_plugins(self): - """Setup JSON serializer plugins in priority order""" - # Try orjson first (fastest), then ujson, finally standard json - orjson_plugin = OrJsonPlugin() - if orjson_plugin.available: - self.json_plugins.append(orjson_plugin) - - ujson_plugin = UJsonPlugin() - if ujson_plugin.available: - self.json_plugins.append(ujson_plugin) - - # Always have standard json as fallback - self.json_plugins.append(StandardJsonPlugin()) - - def register_type_provider(self, provider): - """Register custom type provider for backward compatibility""" - self.registry.register_plugin(provider) - - def encode(self, arguments: tuple, parameter_type: list[type] | None = None) -> bytes: - """Encode arguments with flexible parameter handling""" + + # Initialize codecs and handlers using the extension pattern + self._json_codecs = self._setup_json_codecs() + self._type_handlers = self._setup_type_handlers() + + def _setup_json_codecs(self) -> List[JsonCodec]: + """ + Setup JSON codecs in priority order. + Following the compression pattern: try fastest first, fallback to standard. + """ + codecs = [] + + # Try orjson first (fastest) + orjson_codec = OrJsonCodec() + if orjson_codec.can_handle(None): # Check availability + codecs.append(orjson_codec) + + # Try ujson second + ujson_codec = UJsonCodec() + if ujson_codec.can_handle(None): # Check availability + codecs.append(ujson_codec) + + # Always include standard json as fallback + codecs.append(StandardJsonCodec()) + + return codecs + + def _setup_type_handlers(self) -> List[TypeHandler]: + """ + Setup type handlers for different object types. + Similar to compression - each handler is independent and focused. 
+ """ + handlers = [] + + # Add all available handlers + handlers.append(DateTimeHandler()) + + pydantic_handler = PydanticHandler() + if pydantic_handler.available: + handlers.append(pydantic_handler) + + handlers.extend( + [ + DecimalHandler(), + CollectionHandler(), + SimpleTypesHandler(), + EnumHandler(), + DataclassHandler(), + ] + ) + + return handlers + + def encode_parameters(self, *arguments) -> bytes: + """ + Encode parameters to JSON bytes. + + :param arguments: The arguments to encode. + :return: Encoded JSON bytes. + :rtype: bytes + """ try: if not arguments: - return self._serialize_to_json_bytes([]) + return self._encode_with_codecs([]) # Handle single parameter case - if parameter_type and len(parameter_type) == 1: - parameter = arguments[0] - serialized_param = self._serialize_object(parameter) - return self._serialize_to_json_bytes(serialized_param) + if len(self.parameter_types) == 1: + serialized = self._serialize_object(arguments[0]) + return self._encode_with_codecs(serialized) # Handle multiple parameters - elif parameter_type and len(parameter_type) > 1: - # Try Pydantic wrapper for strong typing - pydantic_handler = self._get_pydantic_handler() - if pydantic_handler and pydantic_handler.available: - wrapper_data = {f"param_{i}": arg for i, arg in enumerate(arguments)} - wrapper_model = pydantic_handler.create_parameter_model(self.parameter_types) - if wrapper_model: - try: - wrapper_instance = wrapper_model(**wrapper_data) - serialized_wrapper = self._serialize_object(wrapper_instance) - return self._serialize_to_json_bytes(serialized_wrapper) - except Exception: - pass # Fall back to standard handling - - # Standard multi-parameter handling + elif len(self.parameter_types) > 1: serialized_args = [self._serialize_object(arg) for arg in arguments] - return self._serialize_to_json_bytes(serialized_args) + return self._encode_with_codecs(serialized_args) + # No type constraints else: - # No type constraints - serialize as single object if only 
one argument if len(arguments) == 1: - serialized_obj = self._serialize_object(arguments[0]) - return self._serialize_to_json_bytes(serialized_obj) + serialized = self._serialize_object(arguments[0]) + return self._encode_with_codecs(serialized) else: - # Multiple arguments - serialize as list serialized_args = [self._serialize_object(arg) for arg in arguments] - return self._serialize_to_json_bytes(serialized_args) + return self._encode_with_codecs(serialized_args) except Exception as e: - raise SerializationException(f"Encoding failed: {e}") from e + raise SerializationException(f"Parameter encoding failed: {e}") from e + + def decode_return_value(self, data: bytes) -> Any: + """ + Decode return value from JSON bytes. + + :param data: The JSON bytes to decode. + :type data: bytes + :return: Decoded return value. + :rtype: Any + """ + try: + if not data: + return None + + json_data = self._decode_with_codecs(data) + return self._reconstruct_objects(json_data) - def _get_pydantic_handler(self) -> Optional[PydanticHandler]: - """Get Pydantic handler from registered plugins""" - for plugin in self.registry.plugins: - if isinstance(plugin, PydanticHandler): - return plugin - return None + except Exception as e: + raise DeserializationException(f"Return value decoding failed: {e}") from e def _serialize_object(self, obj: Any, depth: int = 0) -> Any: - """Serialize single object using registry with depth protection""" + """ + Serialize an object using the appropriate type handler. + + :param obj: The object to serialize. + :param depth: Current serialization depth. + :return: Serialized representation. 
+ """ if depth > self.maximum_depth: raise SerializationException(f"Maximum depth {self.maximum_depth} exceeded") - # Handle primitives + # Handle simple types if obj is None or isinstance(obj, (bool, int, float, str)): return obj @@ -162,167 +197,69 @@ def _serialize_object(self, obj: Any, depth: int = 0) -> Any: result[key] = self._serialize_object(value, depth + 1) return result - # Use registry to find handler - handler = self.registry.get_handler(obj) - if handler: - try: - serialized = handler(obj) - # Recursively serialize the result from the handler - return self._serialize_object(serialized, depth + 1) - except Exception as e: - if self.strict_validation: - raise SerializationException(f"Handler failed for {type(obj).__name__}: {e}") from e - return {"__serialization_error__": str(e), "__type__": type(obj).__name__} + # Use type handlers for complex objects + obj_type = type(obj) + for handler in self._type_handlers: + if handler.can_serialize_type(obj, obj_type): + try: + serialized = handler.serialize_to_dict(obj) + return self._serialize_object(serialized, depth + 1) + except Exception as e: + if self.strict_validation: + raise SerializationException(f"Handler failed for {type(obj).__name__}: {e}") from e + return {"__serialization_error__": str(e), "__type__": type(obj).__name__} # Fallback for unknown types if self.strict_validation: raise SerializationException(f"No handler for type {type(obj).__name__}") return {"__fallback__": str(obj), "__type__": type(obj).__name__} - def _serialize_to_json_bytes(self, obj: Any) -> bytes: - """Use the first available JSON plugin to serialize""" + def _encode_with_codecs(self, obj: Any) -> bytes: + """ + Encode object using the first available JSON codec. + + :param obj: The object to encode. + :return: JSON bytes. 
+ :rtype: bytes + """ last_error = None - for plugin in self.json_plugins: + for codec in self._json_codecs: try: - return plugin.encode(obj) + return codec.encode(obj) except Exception as e: last_error = e continue - raise SerializationException(f"All JSON plugins failed. Last error: {last_error}") - - -class JsonTransportDecoder: - """JSON Transport Decoder""" - - def __init__(self, target_type: type | list[type] | None = None, **kwargs): - self.target_type = target_type - self.json_plugins: list[Any] = [] - self._setup_json_deserializer_plugins() - - # Handle multiple parameter types - if isinstance(target_type, list): - self.multiple_parameter_mode = len(target_type) > 1 - self.parameter_types = target_type - if self.multiple_parameter_mode: - pydantic_handler = PydanticHandler() - if pydantic_handler.available: - self.parameter_wrapper_model = pydantic_handler.create_parameter_model(target_type) - else: - self.multiple_parameter_mode = False - self.parameter_types = [target_type] if target_type else [] - - def _setup_json_deserializer_plugins(self): - """Setup JSON deserializer plugins in priority order""" - orjson_plugin = OrJsonPlugin() - if orjson_plugin.available: - self.json_plugins.append(orjson_plugin) + raise SerializationException(f"All JSON codecs failed. Last error: {last_error}") - ujson_plugin = UJsonPlugin() - if ujson_plugin.available: - self.json_plugins.append(ujson_plugin) + def _decode_with_codecs(self, data: bytes) -> Any: + """ + Decode JSON bytes using the first available codec. 
- self.json_plugins.append(StandardJsonPlugin()) - - def decode(self, data: bytes) -> Any: - """Decode JSON bytes back to objects""" - try: - if not data: - return None - - json_data = self._deserialize_from_json_bytes(data) - reconstructed_data = self._reconstruct_objects(json_data) - - # Handle single-item list unpacking if target type expects one value - if ( - isinstance(reconstructed_data, list) - and len(reconstructed_data) == 1 - and self.target_type - and not isinstance(self.target_type, list) - ): - single_item = reconstructed_data[0] - if isinstance(single_item, self.target_type): - return single_item - reconstructed_data = single_item - - # Handle [single] target_type inside a list - elif ( - isinstance(reconstructed_data, list) - and len(reconstructed_data) == 1 - and isinstance(self.target_type, list) - and len(self.target_type) == 1 - ): - single_item = reconstructed_data[0] - target_type = self.target_type[0] - if isinstance(single_item, target_type): - return single_item - - if not self.target_type: - return reconstructed_data - - if isinstance(self.target_type, list): - if self.multiple_parameter_mode and hasattr(self, "parameter_wrapper_model"): - try: - wrapper_instance = self.parameter_wrapper_model(**reconstructed_data) - return tuple(getattr(wrapper_instance, f"param_{i}") for i in range(len(self.parameter_types))) - except Exception: - pass - - # Decode to first type if available - if self.parameter_types: - return self._decode_to_target_type(reconstructed_data, self.parameter_types[0]) - return reconstructed_data - else: - return self._decode_to_target_type(reconstructed_data, self.target_type) - - except Exception as e: - raise DeserializationException(f"Decoding failed: {e}") from e - - def _deserialize_from_json_bytes(self, data: bytes) -> Any: - """Use the first available JSON plugin to deserialize""" + :param data: The JSON bytes to decode. + :return: Decoded object. 
+ :rtype: Any + """ last_error = None - for plugin in self.json_plugins: + + for codec in self._json_codecs: try: - return plugin.decode(data) + return codec.decode(data) except Exception as e: last_error = e continue - raise DeserializationException(f"All JSON plugins failed. Last error: {last_error}") - - def _decode_to_target_type(self, json_data: Any, target_type: type) -> Any: - """Convert JSON data to target type with proper Pydantic handling""" - - if isinstance(json_data, target_type): - return json_data - - # Special handling for Pydantic models - try: - from pydantic import BaseModel - - if isinstance(target_type, type) and issubclass(target_type, BaseModel): - if isinstance(json_data, target_type): - return json_data - elif isinstance(json_data, dict): - return target_type(**json_data) - elif isinstance(json_data, list) and len(json_data) == 1: - return self._decode_to_target_type(json_data[0], target_type) - elif isinstance(json_data, list) and isinstance(json_data[0], dict): - return self._decode_to_target_type(json_data[0], target_type) - - except ImportError: - pass - - # Handle built-in simple types - if target_type in (str, int, float, bool, list, dict): - return target_type(json_data) - - return json_data + raise DeserializationException(f"All JSON codecs failed. Last error: {last_error}") def _reconstruct_objects(self, data: Any) -> Any: - """Reconstruct special objects from their serialized form""" + """ + Reconstruct objects from their serialized form. + :param data: The data to reconstruct. + :return: Reconstructed object. 
+ :rtype: Any + """ if not isinstance(data, dict): if isinstance(data, list): return [self._reconstruct_objects(item) for item in data] @@ -360,20 +297,9 @@ def _reconstruct_objects(self, data: Any) -> Any: elif "__pydantic_model__" in data and "__model_data__" in data: return self._reconstruct_pydantic_model(data) elif "__dataclass__" in data: - module_name, class_name = data["__dataclass__"].rsplit(".", 1) - import importlib - - module = importlib.import_module(module_name) - cls = getattr(module, class_name) - fields = self._reconstruct_objects(data["fields"]) - return cls(**fields) + return self._reconstruct_dataclass(data) elif "__enum__" in data: - module_name, class_name = data["__enum__"].rsplit(".", 1) - import importlib - - module = importlib.import_module(module_name) - cls = getattr(module, class_name) - return cls(data["value"]) + return self._reconstruct_enum(data) else: return {key: self._reconstruct_objects(value) for key, value in data.items()} @@ -382,6 +308,7 @@ def _reconstruct_pydantic_model(self, data: dict) -> Any: try: model_path = data["__pydantic_model__"] model_data = data["__model_data__"] + module_name, class_name = model_path.rsplit(".", 1) import importlib @@ -394,45 +321,25 @@ def _reconstruct_pydantic_model(self, data: dict) -> Any: except Exception: return self._reconstruct_objects(data.get("__model_data__", {})) + def _reconstruct_dataclass(self, data: dict) -> Any: + """Reconstruct a dataclass from serialized data""" + module_name, class_name = data["__dataclass__"].rsplit(".", 1) -class JsonTransportCodec: - """JSON transport codec""" + import importlib - def __init__( - self, - parameter_types: list[type] | None = None, - return_type: type | None = None, - maximum_depth: int = 100, - strict_validation: bool = True, - **kwargs, - ): - self.parameter_types = parameter_types or [] - self.return_type = return_type - self.maximum_depth = maximum_depth - self.strict_validation = strict_validation - - self._encoder = 
JsonTransportEncoder( - parameter_types=parameter_types, - maximum_depth=maximum_depth, - strict_validation=strict_validation, - **kwargs, - ) - self._decoder = JsonTransportDecoder(target_type=return_type, **kwargs) + module = importlib.import_module(module_name) + cls = getattr(module, class_name) - def encode_parameters(self, *arguments, parameter_type: list[type] | None = None) -> bytes: - """Encode parameters - supports both positional and keyword args""" - return self._encoder.encode(arguments, parameter_type=parameter_type) + fields = self._reconstruct_objects(data["fields"]) + return cls(**fields) - def decode_return_value(self, data: bytes) -> Any: - """Decode return value""" - return self._decoder.decode(data) + def _reconstruct_enum(self, data: dict) -> Any: + """Reconstruct an enum from serialized data""" + module_name, class_name = data["__enum__"].rsplit(".", 1) - def get_encoder(self) -> JsonTransportEncoder: - return self._encoder + import importlib - def get_decoder(self) -> JsonTransportDecoder: - return self._decoder + module = importlib.import_module(module_name) + cls = getattr(module, class_name) - def register_type_provider(self, provider) -> None: - """Register custom type provider""" - self._encoder.register_type_provider(provider) + return cls(data["value"]) diff --git a/src/dubbo/codec/json_codec/json_transport_base.py b/src/dubbo/codec/json_codec/json_transport_base.py deleted file mode 100644 index a62be03..0000000 --- a/src/dubbo/codec/json_codec/json_transport_base.py +++ /dev/null @@ -1,71 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one or more -# contributor license agreements. See the NOTICE file distributed with -# this work for additional information regarding copyright ownership. -# The ASF licenses this file to You under the Apache License, Version 2.0 -# (the "License"); you may not use this file except in compliance with -# the License. 
You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from typing import Any, Callable, Protocol, Optional - - -class JsonSerializerPlugin(Protocol): - """Protocol for JSON serialization plugins""" - - def encode(self, obj: Any) -> bytes: ... - def decode(self, data: bytes) -> Any: ... - def can_handle(self, obj: Any) -> bool: ... - - -class TypeHandlerPlugin(Protocol): - """Protocol for type-specific serialization""" - - def can_serialize_type(self, obj: Any, obj_type: type) -> bool: ... - def serialize_to_dict(self, obj: Any) -> Any: ... - - -class SimpleRegistry: - """Simplified registry using dict instead of complex TypeProviderRegistry""" - - def __init__(self): - # Simple dict mapping: type -> handler function - self.type_handlers: dict[type, Callable[..., Any]] = {} - self.plugins: list[TypeHandlerPlugin] = [] - - def register_type_handler(self, obj_type: type, handler: Callable[..., Any]) -> None: - """Register a simple type handler function""" - self.type_handlers[obj_type] = handler - - def register_plugin(self, plugin: TypeHandlerPlugin) -> None: - """Register a plugin""" - self.plugins.append(plugin) - - def get_handler(self, obj: Any) -> Optional[Callable[..., Any]]: - """Get handler for object - check dict first, then plugins""" - obj_type = type(obj) - if obj_type in self.type_handlers: - return self.type_handlers[obj_type] - - for plugin in self.plugins: - if plugin.can_serialize_type(obj, obj_type): - return plugin.serialize_to_dict - return None - - -class SerializationException(Exception): - """Exception raised during serialization""" - - pass - - -class DeserializationException(Exception): 
- """Exception raised during deserialization""" - - pass diff --git a/src/dubbo/codec/json_codec/json_transport_codec.py b/src/dubbo/codec/json_codec/json_transport_codec.py deleted file mode 100644 index 075efef..0000000 --- a/src/dubbo/codec/json_codec/json_transport_codec.py +++ /dev/null @@ -1,231 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one or more -# contributor license agreements. See the NOTICE file distributed with -# this work for additional information regarding copyright ownership. -# The ASF licenses this file to You under the Apache License, Version 2.0 -# (the "License"); you may not use this file except in compliance with -# the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. 
- -import json -from dataclasses import asdict, is_dataclass -from datetime import date, datetime, time -from decimal import Decimal -from enum import Enum -from pathlib import Path -from typing import Any, Union -from uuid import UUID -from typing import Optional - - -class StandardJsonPlugin: - """Standard library JSON codec""" - - def encode(self, obj: Any) -> bytes: - return json.dumps(obj, ensure_ascii=False, separators=(",", ":")).encode("utf-8") - - def decode(self, data: bytes) -> Any: - return json.loads(data.decode("utf-8")) - - def can_handle(self, obj: Any) -> bool: - return True - - -class OrJsonPlugin: - """orjson Codec independent implementation""" - - def __init__(self): - try: - import orjson - - self.orjson = orjson - self.available = True - except ImportError: - self.available = False - - def encode(self, obj: Any) -> bytes: - if not self.available: - raise ImportError("orjson not available") - return self.orjson.dumps(obj, default=self._default_handler) - - def decode(self, data: bytes) -> Any: - if not self.available: - raise ImportError("orjson not available") - return self.orjson.loads(data) - - def can_handle(self, obj: Any) -> bool: - return self.available - - def _default_handler(self, obj): - """Handle types not supported natively by orjson""" - if isinstance(obj, datetime): - return {"__datetime__": obj.isoformat(), "__timezone__": str(obj.tzinfo) if obj.tzinfo else None} - elif isinstance(obj, date): - return {"__date__": obj.isoformat()} - elif isinstance(obj, time): - return {"__time__": obj.isoformat()} - elif isinstance(obj, Decimal): - return {"__decimal__": str(obj)} - elif isinstance(obj, set): - return {"__set__": list(obj)} - elif isinstance(obj, frozenset): - return {"__frozenset__": list(obj)} - elif isinstance(obj, UUID): - return {"__uuid__": str(obj)} - elif isinstance(obj, Path): - return {"__path__": str(obj)} - return {"__fallback__": str(obj), "__type__": type(obj).__name__} - - -class UJsonPlugin: - """ujson plugin 
implementation""" - - def __init__(self): - try: - import ujson - - self.ujson = ujson - self.available = True - except ImportError: - self.available = False - - def encode(self, obj: Any) -> bytes: - if not self.available: - raise ImportError("ujson not available") - return self.ujson.dumps(obj, ensure_ascii=False, default=self._default_handler).encode("utf-8") - - def decode(self, data: bytes) -> Any: - if not self.available: - raise ImportError("ujson not available") - return self.ujson.loads(data.decode("utf-8")) - - def can_handle(self, obj: Any) -> bool: - return self.available - - def _default_handler(self, obj): - """Handle types not supported natively by ujson""" - if isinstance(obj, datetime): - return {"__datetime__": obj.isoformat(), "__timezone__": str(obj.tzinfo) if obj.tzinfo else None} - elif isinstance(obj, date): - return {"__date__": obj.isoformat()} - elif isinstance(obj, time): - return {"__time__": obj.isoformat()} - elif isinstance(obj, Decimal): - return {"__decimal__": str(obj)} - elif isinstance(obj, set): - return {"__set__": list(obj)} - elif isinstance(obj, frozenset): - return {"__frozenset__": list(obj)} - elif isinstance(obj, UUID): - return {"__uuid__": str(obj)} - elif isinstance(obj, Path): - return {"__path__": str(obj)} - return {"__fallback__": str(obj), "__type__": type(obj).__name__} - - -class DateTimeHandler: - """DateTime handler - implements TypeHandler Codec""" - - def can_serialize_type(self, obj: Any, obj_type: type) -> bool: - return isinstance(obj, (datetime, date, time)) - - def serialize_to_dict(self, obj: Union[datetime, date, time]) -> dict[str, str | None]: - if isinstance(obj, datetime): - return {"__datetime__": obj.isoformat(), "__timezone__": str(obj.tzinfo) if obj.tzinfo else None} - elif isinstance(obj, date): - return {"__date__": obj.isoformat()} - else: - return {"__time__": obj.isoformat()} - - -class DecimalHandler: - def can_serialize_type(self, obj: Any, obj_type: type) -> bool: - return obj_type is 
Decimal - - def serialize_to_dict(self, obj: Decimal) -> dict[str, str]: - return {"__decimal__": str(obj)} - - -class CollectionHandler: - def can_serialize_type(self, obj: Any, obj_type: type) -> bool: - return obj_type in (set, frozenset) - - def serialize_to_dict(self, obj: Union[set, frozenset]) -> dict[str, list]: - return {"__frozenset__" if isinstance(obj, frozenset) else "__set__": list(obj)} - - -class EnumHandler: - """Handles serialization of Enum types""" - - def can_serialize_type(self, obj: Any, obj_type: type) -> bool: - return isinstance(obj, Enum) - - def serialize_to_dict(self, obj: Enum) -> dict[str, Any]: - return {"__enum__": f"{obj.__class__.__module__}.{obj.__class__.__qualname__}", "value": obj.value} - - -class DataclassHandler: - """Handles serialization of dataclass types""" - - def can_serialize_type(self, obj: Any, obj_type: type) -> bool: - return is_dataclass(obj) - - def serialize_to_dict(self, obj: Any) -> dict[str, Any]: - return {"__dataclass__": f"{obj.__class__.__module__}.{obj.__class__.__qualname__}", "fields": asdict(obj)} - - -class SimpleTypeHandler: - def can_serialize_type(self, obj: Any, obj_type: type) -> bool: - return obj_type in (UUID, Path) or isinstance(obj, Path) - - def serialize_to_dict(self, obj: Union[UUID, Path]) -> dict[str, str]: - if isinstance(obj, UUID): - return {"__uuid__": str(obj)} - elif isinstance(obj, Path): - return {"__path__": str(obj)} - - -class PydanticHandler: - """Pydantic codec for handling advance serialization""" - - def __init__(self): - try: - from pydantic import BaseModel, create_model - - self.BaseModel = BaseModel - self.create_model = create_model - self.available = True - except ImportError: - self.available = False - - def can_serialize_type(self, obj: Any, obj_type: type) -> bool: - return self.available and isinstance(obj, self.BaseModel) - - def serialize_to_dict(self, obj) -> dict[str, Any]: - if hasattr(obj, "model_dump"): - return { - "__pydantic_model__": 
f"{obj.__class__.__module__}.{obj.__class__.__qualname__}", - "__model_data__": obj.model_dump(), - } - return { - "__pydantic_model__": f"{obj.__class__.__module__}.{obj.__class__.__qualname__}", - "__model_data__": obj.dict(), - } - - def create_parameter_model(self, parameter_types: Optional[list[type]] = None): - """Enhanced parameter handling for both positional and keyword args""" - if not self.available or parameter_types is None: - return None - - model_fields = {} - for i, param_type in enumerate(parameter_types): - model_fields[f"param_{i}"] = (param_type, ...) - return self.create_model("ParametersModel", **model_fields) diff --git a/src/dubbo/codec/json_codec/orjson_codec.py b/src/dubbo/codec/json_codec/orjson_codec.py new file mode 100644 index 0000000..1e692ed --- /dev/null +++ b/src/dubbo/codec/json_codec/orjson_codec.py @@ -0,0 +1,104 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from datetime import date, datetime, time +from decimal import Decimal +from pathlib import Path +from typing import Any +from uuid import UUID + +from dubbo.codec.json_codec import JsonCodec + +__all__ = ["OrJsonCodec"] + + +class OrJsonCodec(JsonCodec): + """ + orjson codec implementation for high-performance JSON encoding/decoding. 
+ + Uses the orjson library if available, otherwise falls back gracefully. + """ + + def __init__(self): + try: + import orjson + + self.orjson = orjson + self.available = True + except ImportError: + self.available = False + + def encode(self, obj: Any) -> bytes: + """ + Encode an object to JSON bytes using orjson. + + :param obj: The object to encode. + :type obj: Any + :return: The encoded JSON bytes. + :rtype: bytes + """ + if not self.available: + raise ImportError("orjson not available") + return self.orjson.dumps(obj, default=self._default_handler) + + def decode(self, data: bytes) -> Any: + """ + Decode JSON bytes to an object using orjson. + + :param data: The JSON bytes to decode. + :type data: bytes + :return: The decoded object. + :rtype: Any + """ + if not self.available: + raise ImportError("orjson not available") + return self.orjson.loads(data) + + def can_handle(self, obj: Any) -> bool: + """ + Check if this codec can handle the given object. + + :param obj: The object to check. + :type obj: Any + :return: True if orjson is available. + :rtype: bool + """ + return self.available + + def _default_handler(self, obj): + """ + Handle types not supported natively by orjson. + + :param obj: The object to serialize. + :return: Serialized representation. 
+ """ + if isinstance(obj, datetime): + return {"__datetime__": obj.isoformat(), "__timezone__": str(obj.tzinfo) if obj.tzinfo else None} + elif isinstance(obj, date): + return {"__date__": obj.isoformat()} + elif isinstance(obj, time): + return {"__time__": obj.isoformat()} + elif isinstance(obj, Decimal): + return {"__decimal__": str(obj)} + elif isinstance(obj, set): + return {"__set__": list(obj)} + elif isinstance(obj, frozenset): + return {"__frozenset__": list(obj)} + elif isinstance(obj, UUID): + return {"__uuid__": str(obj)} + elif isinstance(obj, Path): + return {"__path__": str(obj)} + return {"__fallback__": str(obj), "__type__": type(obj).__name__} diff --git a/src/dubbo/codec/json_codec/pydantic_handler.py b/src/dubbo/codec/json_codec/pydantic_handler.py new file mode 100644 index 0000000..90a661c --- /dev/null +++ b/src/dubbo/codec/json_codec/pydantic_handler.py @@ -0,0 +1,90 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Any, Dict, Optional, List, Type + +from dubbo.codec.json_codec import TypeHandler + +__all__ = ["PydanticHandler"] + + +class PydanticHandler(TypeHandler): + """ + Type handler for Pydantic models. 
+ + Handles serialization of Pydantic BaseModel instances with proper + model reconstruction support. + """ + + def __init__(self): + try: + from pydantic import BaseModel, create_model + + self.BaseModel = BaseModel + self.create_model = create_model + self.available = True + except ImportError: + self.available = False + + def can_serialize_type(self, obj: Any, obj_type: type) -> bool: + """ + Check if this handler can serialize Pydantic models. + + :param obj: The object to check. + :type obj: Any + :param obj_type: The type of the object. + :type obj_type: type + :return: True if object is a Pydantic BaseModel and library is available. + :rtype: bool + """ + return self.available and isinstance(obj, self.BaseModel) + + def serialize_to_dict(self, obj: Any) -> Dict[str, Any]: + """ + Serialize Pydantic model to dictionary representation. + + :param obj: The Pydantic model to serialize. + :type obj: BaseModel + :return: Dictionary representation with model metadata. + :rtype: Dict[str, Any] + """ + if not self.available: + raise ImportError("Pydantic not available") + + # Use model_dump if available (Pydantic v2), otherwise use dict (Pydantic v1) + if hasattr(obj, "model_dump"): + model_data = obj.model_dump() + else: + model_data = obj.dict() + + return { + "__pydantic_model__": f"{obj.__class__.__module__}.{obj.__class__.__qualname__}", + "__model_data__": model_data, + } + + def create_parameter_model(self, parameter_types: Optional[List[Type]] = None): + """ + Create a Pydantic model for parameter wrapping. + + :param parameter_types: List of parameter types to wrap. + :type parameter_types: Optional[List[Type]] + :return: Dynamically created Pydantic model or None. + """ + if not self.available or parameter_types is None: + return None + + model_fields = {f"param_{i}": (param_type, ...) 
for i, param_type in enumerate(parameter_types)} + return self.create_model("ParametersModel", **model_fields) diff --git a/src/dubbo/codec/json_codec/simple_types_handler.py b/src/dubbo/codec/json_codec/simple_types_handler.py new file mode 100644 index 0000000..05dfb15 --- /dev/null +++ b/src/dubbo/codec/json_codec/simple_types_handler.py @@ -0,0 +1,60 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from pathlib import Path +from typing import Any, Dict, Union +from uuid import UUID + +from dubbo.codec.json_codec import TypeHandler + +__all__ = ["SimpleTypesHandler"] + + +class SimpleTypesHandler(TypeHandler): + """ + Type handler for simple types like UUID and Path. + + Handles serialization of UUID and Path objects to string representations. + """ + + def can_serialize_type(self, obj: Any, obj_type: type) -> bool: + """ + Check if this handler can serialize simple types. + + :param obj: The object to check. + :type obj: Any + :param obj_type: The type of the object. + :type obj_type: type + :return: True if object is UUID or Path. 
+ :rtype: bool + """ + return obj_type in (UUID, Path) or isinstance(obj, Path) + + def serialize_to_dict(self, obj: Union[UUID, Path]) -> Dict[str, str]: + """ + Serialize UUID or Path to dictionary representation. + + :param obj: The object to serialize. + :type obj: Union[UUID, Path] + :return: Dictionary representation with type marker. + :rtype: Dict[str, str] + """ + if isinstance(obj, UUID): + return {"__uuid__": str(obj)} + elif isinstance(obj, Path): + return {"__path__": str(obj)} + else: + raise ValueError(f"Unsupported simple type: {type(obj)}") diff --git a/src/dubbo/codec/json_codec/standard_json.py b/src/dubbo/codec/json_codec/standard_json.py new file mode 100644 index 0000000..150ccfc --- /dev/null +++ b/src/dubbo/codec/json_codec/standard_json.py @@ -0,0 +1,65 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json +from typing import Any + +from dubbo.codec.json_codec import JsonCodec + +__all__ = ["StandardJsonCodec"] + + +class StandardJsonCodec(JsonCodec): + """ + Standard library JSON codec implementation. + + Uses Python's built-in json module for encoding and decoding. + This is the fallback codec that can handle any object. 
+ """ + + def encode(self, obj: Any) -> bytes: + """ + Encode an object to JSON bytes using standard library. + + :param obj: The object to encode. + :type obj: Any + :return: The encoded JSON bytes. + :rtype: bytes + """ + return json.dumps(obj, ensure_ascii=False, separators=(",", ":")).encode("utf-8") + + def decode(self, data: bytes) -> Any: + """ + Decode JSON bytes to an object using standard library. + + :param data: The JSON bytes to decode. + :type data: bytes + :return: The decoded object. + :rtype: Any + """ + return json.loads(data.decode("utf-8")) + + def can_handle(self, obj: Any) -> bool: + """ + Check if this codec can handle the given object. + Standard JSON can handle any object as fallback. + + :param obj: The object to check. + :type obj: Any + :return: Always True (fallback codec). + :rtype: bool + """ + return True diff --git a/src/dubbo/codec/json_codec/ujson_codec.py b/src/dubbo/codec/json_codec/ujson_codec.py new file mode 100644 index 0000000..1b6fe47 --- /dev/null +++ b/src/dubbo/codec/json_codec/ujson_codec.py @@ -0,0 +1,104 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from datetime import date, datetime, time +from decimal import Decimal +from pathlib import Path +from typing import Any +from uuid import UUID + +from dubbo.codec.json_codec import JsonCodec + +__all__ = ["UJsonCodec"] + + +class UJsonCodec(JsonCodec): + """ + ujson codec implementation for high-performance JSON encoding/decoding. + + Uses the ujson library if available, otherwise falls back gracefully. + """ + + def __init__(self): + try: + import ujson + + self.ujson = ujson + self.available = True + except ImportError: + self.available = False + + def encode(self, obj: Any) -> bytes: + """ + Encode an object to JSON bytes using ujson. + + :param obj: The object to encode. + :type obj: Any + :return: The encoded JSON bytes. + :rtype: bytes + """ + if not self.available: + raise ImportError("ujson not available") + return self.ujson.dumps(obj, ensure_ascii=False, default=self._default_handler).encode("utf-8") + + def decode(self, data: bytes) -> Any: + """ + Decode JSON bytes to an object using ujson. + + :param data: The JSON bytes to decode. + :type data: bytes + :return: The decoded object. + :rtype: Any + """ + if not self.available: + raise ImportError("ujson not available") + return self.ujson.loads(data.decode("utf-8")) + + def can_handle(self, obj: Any) -> bool: + """ + Check if this codec can handle the given object. + + :param obj: The object to check. + :type obj: Any + :return: True if ujson is available. + :rtype: bool + """ + return self.available + + def _default_handler(self, obj): + """ + Handle types not supported natively by ujson. + + :param obj: The object to serialize. + :return: Serialized representation. 
+ """ + if isinstance(obj, datetime): + return {"__datetime__": obj.isoformat(), "__timezone__": str(obj.tzinfo) if obj.tzinfo else None} + elif isinstance(obj, date): + return {"__date__": obj.isoformat()} + elif isinstance(obj, time): + return {"__time__": obj.isoformat()} + elif isinstance(obj, Decimal): + return {"__decimal__": str(obj)} + elif isinstance(obj, set): + return {"__set__": list(obj)} + elif isinstance(obj, frozenset): + return {"__frozenset__": list(obj)} + elif isinstance(obj, UUID): + return {"__uuid__": str(obj)} + elif isinstance(obj, Path): + return {"__path__": str(obj)} + return {"__fallback__": str(obj), "__type__": type(obj).__name__} diff --git a/src/dubbo/extension/registries.py b/src/dubbo/extension/registries.py index bed9d0c..ba42e9b 100644 --- a/src/dubbo/extension/registries.py +++ b/src/dubbo/extension/registries.py @@ -14,90 +14,17 @@ # See the License for the specific language governing permissions and # limitations under the License. -import importlib from dataclasses import dataclass from typing import Any -from dubbo.classes import Codec, SingletonBase - -# Import all the required interface classes from dubbo.cluster import LoadBalance from dubbo.compression import Compressor, Decompressor -from dubbo.extension import registries as registries_module from dubbo.protocol import Protocol from dubbo.registry import RegistryFactory from dubbo.remoting import Transporter - - -class ExtensionError(Exception): - """ - Extension error. - """ - - def __init__(self, message: str): - """ - Initialize the extension error. - :param message: The error message. - :type message: str - """ - super().__init__(message) - - -class ExtensionLoader(SingletonBase): - """ - Singleton class for loading extension implementations. - """ - - def __init__(self): - """ - Initialize the extension loader. - - Load all the registries from the registries module. 
- """ - if not hasattr(self, "_initialized"): # Ensure __init__ runs only once - self._registries = {} - for name in registries_module.registries: - registry = getattr(registries_module, name) - self._registries[registry.interface] = registry.impls - self._initialized = True - - def get_extension(self, interface: Any, impl_name: str) -> Any: - """ - Get the extension implementation for the interface. - - :param interface: Interface class. - :type interface: Any - :param impl_name: Implementation name. - :type impl_name: str - :return: Extension implementation class. - :rtype: Any - :raises ExtensionError: If the interface or implementation is not found. - """ - # Get the registry for the interface - impls = self._registries.get(interface) - print("value is ", impls, interface) - if not impls: - raise ExtensionError(f"Interface '{interface.__name__}' is not supported.") - - # Get the full name of the implementation - full_name = impls.get(impl_name) - if not full_name: - raise ExtensionError(f"Implementation '{impl_name}' for interface '{interface.__name__}' is not exist.") - - try: - # Split the full name into module and class - module_name, class_name = full_name.rsplit(".", 1) - - # Load the module and get the class - module = importlib.import_module(module_name) - subclass = getattr(module, class_name) - - # Return the subclass - return subclass - except Exception as e: - raise ExtensionError( - f"Failed to load extension '{impl_name}' for interface '{interface.__name__}'. 
\nDetail: {e}" - ) +from dubbo.classes import Codec +from dubbo.codec.json_codec import TypeHandler +from dubbo.codec.protobuf_codec import EncodingStrategy, DecodingStrategy @dataclass @@ -123,7 +50,10 @@ class ExtendedRegistry: "compressorRegistry", "decompressorRegistry", "transporterRegistry", + "encodingHandlerRegistry", + "decodingHandlerRegistry", "codecRegistry", + "typeHandlerRegistry", ] # RegistryFactory registry @@ -179,11 +109,42 @@ class ExtendedRegistry: }, ) +# Encoding Strategy Registries +encodingHandlerRegistry = ExtendedRegistry( + interface=EncodingStrategy, + impls={ + "message": "dubbo.codec.protobuf_codec.MessageEncodingStrategy", + "primitive": "dubbo.codec.protobuf_codec.PrimitiveEncodingStrategy", + }, +) + +# Decoding Strategy Registries +decodingHandlerRegistry = ExtendedRegistry( + interface=DecodingStrategy, + impls={ + "message": "dubbo.codec.protobuf_codec.MessageDecodingStrategy", + "primitive": "dubbo.codec.protobuf_codec.PrimitiveDecodingStrategy", + }, +) + # Codec Registry codecRegistry = ExtendedRegistry( interface=Codec, impls={ - "json": "dubbo.codec.json_codec.JsonTransportCodec", + "json": "dubbo.codec.json_codec.JsonTransportCodecBridge", "protobuf": "dubbo.codec.protobuf_codec.ProtobufTransportCodec", }, ) + +typeHandlerRegistry = ExtendedRegistry( + interface=TypeHandler, + impls={ + "datetime": "dubbo.codec.json_codec.DateTimeHandler", + "decimal": "dubbo.codec.json_codec.DecimalHandler", + "collection": "dubbo.codec.json_codec.CollectionHandler", + "enum": "dubbo.codec.json_codec.EnumHandler", + "dataclass": "dubbo.codec.json_codec.DataclassHandler", + "simple": "dubbo.codec.json_codec.SimpleTypeHandler", + "pydantic": "dubbo.codec.json_codec.PydanticHandler", + }, +) From 1393855678c63f032ab6975629045a15cab7a6dc Mon Sep 17 00:00:00 2001 From: Aditya Yadav <166515021+aditya0yadav@users.noreply.github.com> Date: Thu, 4 Sep 2025 21:48:43 +0000 Subject: [PATCH 22/40] fixing bugs --- src/dubbo/codec/dubbo_codec.py | 156 
++++++- src/dubbo/codec/protobuf_codec/__init__.py | 2 +- .../protobuf_codec/betterproto_handler.py | 196 +++++++++ .../codec/protobuf_codec/protobuf_base.py | 103 +++++ .../codec/protobuf_codec/protobuf_codec.py | 165 +++++++ .../protobuf_codec/protobuf_codec_handler.py | 405 ------------------ src/dubbo/extension/registries.py | 23 +- src/dubbo/proxy/handlers.py | 2 + 8 files changed, 605 insertions(+), 447 deletions(-) create mode 100644 src/dubbo/codec/protobuf_codec/betterproto_handler.py create mode 100644 src/dubbo/codec/protobuf_codec/protobuf_base.py create mode 100644 src/dubbo/codec/protobuf_codec/protobuf_codec.py delete mode 100644 src/dubbo/codec/protobuf_codec/protobuf_codec_handler.py diff --git a/src/dubbo/codec/dubbo_codec.py b/src/dubbo/codec/dubbo_codec.py index f127a16..740ac54 100644 --- a/src/dubbo/codec/dubbo_codec.py +++ b/src/dubbo/codec/dubbo_codec.py @@ -14,10 +14,20 @@ # See the License for the specific language governing permissions and # limitations under the License. +import abc import inspect import logging from dataclasses import dataclass -from typing import Any, Callable, Optional +from typing import Any, Callable, Optional, List, Tuple + +__all__ = [ + "ParameterDescriptor", + "MethodDescriptor", + "TransportCodec", + "SerializationEncoder", + "SerializationDecoder", + "DubboSerializationService", +] logger = logging.getLogger(__name__) @@ -38,23 +48,100 @@ class MethodDescriptor: function: Callable name: str - parameters: list[ParameterDescriptor] + parameters: List[ParameterDescriptor] return_parameter: ParameterDescriptor documentation: Optional[str] = None +class TransportCodec(abc.ABC): + """ + The transport codec interface. + """ + + @classmethod + @abc.abstractmethod + def get_transport_type(cls) -> str: + """ + Get transport type of current codec + :return: The transport type. 
+ :rtype: str
+ """
+ raise NotImplementedError()
+
+ @abc.abstractmethod
+ def encoder(self) -> "SerializationEncoder":
+ """
+ Get encoder instance
+ :return: The encoder.
+ :rtype: SerializationEncoder
+ """
+ raise NotImplementedError()
+
+ @abc.abstractmethod
+ def decoder(self) -> "SerializationDecoder":
+ """
+ Get decoder instance
+ :return: The decoder.
+ :rtype: SerializationDecoder
+ """
+ raise NotImplementedError()
+
+
+class SerializationEncoder(abc.ABC):
+ """
+ The serialization encoder interface.
+ """
+
+ @abc.abstractmethod
+ def encode(self, arguments: Tuple[Any, ...]) -> bytes:
+ """
+ Encode arguments to bytes.
+ :param arguments: The arguments to encode.
+ :type arguments: Tuple[Any, ...]
+ :return: The encoded bytes.
+ :rtype: bytes
+ """
+ raise NotImplementedError()
+
+
+class SerializationDecoder(abc.ABC):
+ """
+ The serialization decoder interface.
+ """
+
+ @abc.abstractmethod
+ def decode(self, data: bytes) -> Any:
+ """
+ Decode bytes to object.
+ :param data: The data to decode.
+ :type data: bytes
+ :return: The decoded object.
+ :rtype: Any + """ + raise NotImplementedError() + + class DubboSerializationService: """Dubbo serialization service with type handling""" @staticmethod def create_transport_codec( transport_type: str = "json", - parameter_types: Optional[list[type]] = None, + parameter_types: Optional[List[type]] = None, return_type: Optional[type] = None, **codec_options, - ): - """Create transport codec""" - + ) -> TransportCodec: + """ + Create transport codec + + :param transport_type: The transport type (e.g., 'json', 'protobuf') + :param parameter_types: List of parameter types + :param return_type: Return value type + :param codec_options: Additional codec options + :return: Transport codec instance + :raises ImportError: If required modules cannot be imported + :raises Exception: If codec creation fails + """ try: from dubbo.classes import CodecHelper from dubbo.extension.extension_loader import ExtensionLoader @@ -71,12 +158,21 @@ def create_transport_codec( @staticmethod def create_encoder_decoder_pair( transport_type: str, - parameter_types: Optional[list[type]] = None, + parameter_types: Optional[List[type]] = None, return_type: Optional[type] = None, **codec_options, - ) -> tuple[Any, Any]: - """Create encoder and decoder instances""" - + ) -> Tuple[SerializationEncoder, SerializationDecoder]: + """ + Create encoder and decoder instances + + :param transport_type: The transport type + :param parameter_types: List of parameter types + :param return_type: Return value type + :param codec_options: Additional codec options + :return: Tuple of (encoder, decoder) + :raises ValueError: If codec returns None encoder/decoder + :raises Exception: If creation fails + """ try: codec_instance = DubboSerializationService.create_transport_codec( transport_type=transport_type, @@ -85,8 +181,8 @@ def create_encoder_decoder_pair( **codec_options, ) - encoder = codec_instance.get_encoder() - decoder = codec_instance.get_decoder() + encoder = codec_instance.encoder() + decoder = 
codec_instance.decoder() if encoder is None or decoder is None: raise ValueError(f"Codec for transport type '{transport_type}' returned None encoder/decoder") @@ -100,12 +196,20 @@ def create_encoder_decoder_pair( @staticmethod def create_serialization_functions( transport_type: str, - parameter_types: Optional[list[type]] = None, + parameter_types: Optional[List[type]] = None, return_type: Optional[type] = None, **codec_options, - ) -> tuple[Callable, Callable]: - """Create serializer and deserializer functions""" - + ) -> Tuple[Callable[..., bytes], Callable[[bytes], Any]]: + """ + Create serializer and deserializer functions + + :param transport_type: The transport type + :param parameter_types: List of parameter types + :param return_type: Return value type + :param codec_options: Additional codec options + :return: Tuple of (serializer_function, deserializer_function) + :raises Exception: If creation fails + """ try: parameter_encoder, return_decoder = DubboSerializationService.create_encoder_decoder_pair( transport_type=transport_type, @@ -115,13 +219,15 @@ def create_serialization_functions( ) def serialize_method_parameters(*args) -> bytes: + """Serialize method parameters to bytes""" try: return parameter_encoder.encode(args) except Exception as e: logger.error("Failed to serialize parameters: %s", e) raise - def deserialize_method_return(data: bytes): + def deserialize_method_return(data: bytes) -> Any: + """Deserialize bytes to return value""" if not isinstance(data, bytes): raise TypeError(f"Expected bytes, got {type(data)}") try: @@ -140,12 +246,22 @@ def deserialize_method_return(data: bytes): def create_method_descriptor( func: Callable, method_name: Optional[str] = None, - parameter_types: Optional[list[type]] = None, + parameter_types: Optional[List[type]] = None, return_type: Optional[type] = None, interface: Optional[Callable[..., Any]] = None, ) -> MethodDescriptor: - """Create a method descriptor from function and configuration""" - + """ + 
Create a method descriptor from function and configuration + + :param func: The function to create descriptor for + :param method_name: Override method name + :param parameter_types: Override parameter types + :param return_type: Override return type + :param interface: Interface to use for signature inspection + :return: Method descriptor + :raises TypeError: If func is not callable + :raises ValueError: If signature cannot be inspected + """ if not callable(func): raise TypeError("func must be callable") diff --git a/src/dubbo/codec/protobuf_codec/__init__.py b/src/dubbo/codec/protobuf_codec/__init__.py index d224a47..ad98c72 100644 --- a/src/dubbo/codec/protobuf_codec/__init__.py +++ b/src/dubbo/codec/protobuf_codec/__init__.py @@ -14,7 +14,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from .protobuf_codec_handler import ProtobufTransportCodec, ProtobufTransportDecoder, ProtobufTransportEncoder +from .protobuf_codec import ProtobufTransportCodec, ProtobufTransportDecoder, ProtobufTransportEncoder __all__ = [ "ProtobufTransportCodec", diff --git a/src/dubbo/codec/protobuf_codec/betterproto_handler.py b/src/dubbo/codec/protobuf_codec/betterproto_handler.py new file mode 100644 index 0000000..f7edc84 --- /dev/null +++ b/src/dubbo/codec/protobuf_codec/betterproto_handler.py @@ -0,0 +1,196 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json +from typing import Any, Optional + +from .protobuf_base import ProtobufEncoder, ProtobufDecoder, SerializationException, DeserializationException + +try: + import betterproto + + HAS_BETTERPROTO = True +except ImportError: + HAS_BETTERPROTO = False + +__all__ = ["BetterprotoMessageHandler", "PrimitiveHandler"] + + +class BetterprotoMessageHandler(ProtobufEncoder, ProtobufDecoder): + """ + The BetterProto message handler for protobuf messages. + """ + + _SERIALIZATION_TYPE = "betterproto" + + def __init__(self): + if not HAS_BETTERPROTO: + raise ImportError("betterproto library is required for BetterprotoMessageHandler") + + @classmethod + def get_serialization_type(cls) -> str: + """ + Get serialization type of current implementation + :return: The serialization type. + :rtype: str + """ + return cls._SERIALIZATION_TYPE + + def can_handle(self, obj: Any, obj_type: Optional[type] = None) -> bool: + """ + Check if this handler can handle the given object/type + :param obj: The object to check + :param obj_type: The type to check + :return: True if can handle, False otherwise + :rtype: bool + """ + if obj is not None and isinstance(obj, betterproto.Message): + return True + if obj_type is not None: + return self._is_betterproto_message(obj_type) + return False + + def encode(self, obj: Any, obj_type: Optional[type] = None) -> bytes: + """ + Encode the betterproto message to bytes. + :param obj: The message to encode. + :param obj_type: The type hint for encoding. + :return: The encoded bytes. 
+ :rtype: bytes + """ + try: + if isinstance(obj, betterproto.Message): + return bytes(obj) + + if obj_type and self._is_betterproto_message(obj_type): + if isinstance(obj, obj_type): + return bytes(obj) + elif isinstance(obj, dict): + message = obj_type().from_dict(obj) + return bytes(message) + else: + raise SerializationException(f"Cannot convert {type(obj)} to {obj_type}") + + raise SerializationException(f"Cannot encode {type(obj)} as betterproto message") + except Exception as e: + raise SerializationException(f"BetterProto encoding failed: {e}") from e + + def decode(self, data: bytes, target_type: type) -> Any: + """ + Decode the data to betterproto message. + :param data: The data to decode. + :param target_type: The target message type. + :return: The decoded message. + :rtype: Any + """ + try: + if not self._is_betterproto_message(target_type): + raise DeserializationException(f"{target_type} is not a betterproto message type") + return target_type().parse(data) + except Exception as e: + raise DeserializationException(f"BetterProto decoding failed: {e}") from e + + def _is_betterproto_message(self, obj_type: type) -> bool: + """Check if the type is a betterproto message""" + try: + return hasattr(obj_type, "__dataclass_fields__") and issubclass(obj_type, betterproto.Message) + except (TypeError, AttributeError): + return False + + +class PrimitiveHandler(ProtobufEncoder, ProtobufDecoder): + """ + The primitive type handler for basic Python types. + """ + + _SERIALIZATION_TYPE = "primitive" + + @classmethod + def get_serialization_type(cls) -> str: + """ + Get serialization type of current implementation + :return: The serialization type. 
+ :rtype: str + """ + return cls._SERIALIZATION_TYPE + + def can_handle(self, obj: Any, obj_type: Optional[type] = None) -> bool: + """ + Check if this handler can handle the given object/type + :param obj: The object to check + :param obj_type: The type to check + :return: True if can handle, False otherwise + :rtype: bool + """ + if obj is not None: + return isinstance(obj, (str, int, float, bool, bytes)) + if obj_type is not None: + return obj_type in (str, int, float, bool, bytes) + return False + + def encode(self, obj: Any, obj_type: Optional[type] = None) -> bytes: + """ + Encode the primitive object to bytes. + :param obj: The object to encode. + :param obj_type: The type hint for encoding. + :return: The encoded bytes. + :rtype: bytes + """ + try: + if not isinstance(obj, (str, int, float, bool, bytes)): + raise SerializationException(f"Cannot encode {type(obj)} as primitive") + + json_str = json.dumps({"value": obj, "type": type(obj).__name__}) + return json_str.encode("utf-8") + except Exception as e: + raise SerializationException(f"Primitive encoding failed: {e}") from e + + def decode(self, data: bytes, target_type: type) -> Any: + """ + Decode the data to primitive object. + :param data: The data to decode. + :param target_type: The target primitive type. + :return: The decoded object. 
+ :rtype: Any + """ + try: + if target_type not in (str, int, float, bool, bytes): + raise DeserializationException(f"{target_type} is not a supported primitive type") + + json_str = data.decode("utf-8") + parsed = json.loads(json_str) + value = parsed.get("value") + + if target_type is str: + return str(value) + elif target_type is int: + return int(value) + elif target_type is float: + return float(value) + elif target_type is bool: + return bool(value) + elif target_type is bytes: + if isinstance(value, bytes): + return value + elif isinstance(value, list): + return bytes(value) + else: + return str(value).encode() + else: + return value + + except Exception as e: + raise DeserializationException(f"Primitive decoding failed: {e}") from e diff --git a/src/dubbo/codec/protobuf_codec/protobuf_base.py b/src/dubbo/codec/protobuf_codec/protobuf_base.py new file mode 100644 index 0000000..df188b7 --- /dev/null +++ b/src/dubbo/codec/protobuf_codec/protobuf_base.py @@ -0,0 +1,103 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import abc +from typing import Any, Optional + +__all__ = [ + "SerializationException", + "DeserializationException", + "ProtobufSerialization", + "ProtobufEncoder", + "ProtobufDecoder", +] + + +class SerializationException(Exception): + """Exception raised when encoding or serialization fails.""" + + def __init__(self, message: str, *, cause: Optional[Exception] = None): + super().__init__(message) + self.cause = cause + + +class DeserializationException(Exception): + """Exception raised when decoding or deserialization fails.""" + + def __init__(self, message: str, *, cause: Optional[Exception] = None): + super().__init__(message) + self.cause = cause + + +class ProtobufSerialization(abc.ABC): + """ + The protobuf serialization interface. + """ + + @classmethod + @abc.abstractmethod + def get_serialization_type(cls) -> str: + """ + Get serialization type of current implementation + :return: The serialization type. + :rtype: str + """ + raise NotImplementedError() + + @abc.abstractmethod + def can_handle(self, obj: Any, obj_type: Optional[type] = None) -> bool: + """ + Check if this serialization can handle the given object/type + :param obj: The object to check + :param obj_type: The type to check + :return: True if can handle, False otherwise + :rtype: bool + """ + raise NotImplementedError() + + +class ProtobufEncoder(ProtobufSerialization, abc.ABC): + """ + The protobuf encoding interface. + """ + + @abc.abstractmethod + def encode(self, obj: Any, obj_type: Optional[type] = None) -> bytes: + """ + Encode the object to bytes. + :param obj: The object to encode. + :param obj_type: The type hint for encoding. + :return: The encoded bytes. + :rtype: bytes + """ + raise NotImplementedError() + + +class ProtobufDecoder(ProtobufSerialization, abc.ABC): + """ + The protobuf decoding interface. + """ + + @abc.abstractmethod + def decode(self, data: bytes, target_type: type) -> Any: + """ + Decode the data to object. + :param data: The data to decode. 
+ :param target_type: The target type for decoding. + :return: The decoded object. + :rtype: Any + """ + raise NotImplementedError() diff --git a/src/dubbo/codec/protobuf_codec/protobuf_codec.py b/src/dubbo/codec/protobuf_codec/protobuf_codec.py new file mode 100644 index 0000000..df1fe11 --- /dev/null +++ b/src/dubbo/codec/protobuf_codec/protobuf_codec.py @@ -0,0 +1,165 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Any, Optional, List + +from .protobuf_base import ProtobufEncoder, ProtobufDecoder, SerializationException, DeserializationException +from .betterproto_handler import BetterprotoMessageHandler, PrimitiveHandler + +__all__ = ["ProtobufTransportCodec"] + + +class ProtobufTransportEncoder: + """Protobuf encoder for parameters""" + + def __init__(self, handlers: List[ProtobufEncoder], parameter_types: Optional[List[type]] = None): + self._handlers = handlers + self._parameter_types = parameter_types or [] + + def encode(self, arguments: tuple) -> bytes: + """Encode arguments tuple to bytes""" + try: + if not arguments: + return b"" + if len(arguments) == 1: + return self._encode_single(arguments[0]) + raise SerializationException( + f"Multiple parameters not supported. 
Got {len(arguments)} arguments, expected 1." + ) + except Exception as e: + raise SerializationException(f"Parameter encoding failed: {e}") from e + + def _encode_single(self, argument: Any) -> bytes: + """Encode a single argument""" + if argument is None: + return b"" + + # Try to get parameter type from configuration + param_type = self._parameter_types[0] if self._parameter_types else None + + for handler in self._handlers: + if handler.can_handle(argument, param_type): + return handler.encode(argument, param_type) + + raise SerializationException(f"No handler found for {type(argument)}") + + +class ProtobufTransportDecoder: + """Protobuf decoder for return values""" + + def __init__(self, handlers: List[ProtobufDecoder], return_type: Optional[type] = None): + self._handlers = handlers + self._return_type = return_type + + def decode(self, data: bytes) -> Any: + """Decode bytes to return value""" + try: + if not data: + return None + if not self._return_type: + raise DeserializationException("No return_type specified for decoding") + + for handler in self._handlers: + if handler.can_handle(None, self._return_type): + return handler.decode(data, self._return_type) + + raise DeserializationException(f"No handler found for {self._return_type}") + except Exception as e: + raise DeserializationException(f"Return value decoding failed: {e}") from e + + +class ProtobufTransportCodec: + """ + Main protobuf codec class compatible with extension loader. + This class provides encoder() and decoder() methods as expected by the extension loader. 
+ """ + + def __init__( + self, + parameter_types: Optional[List[type]] = None, + return_type: Optional[type] = None, + **kwargs, + ): + self._parameter_types = parameter_types or [] + self._return_type = return_type + + # Initialize handlers + self._encoders: List[ProtobufEncoder] = [] + self._decoders: List[ProtobufDecoder] = [] + + # Load default handlers + self._load_default_handlers() + + def _load_default_handlers(self): + """Load default encoding and decoding handlers""" + try: + # Try to load BetterProto handler + betterproto_handler = BetterprotoMessageHandler() + self._encoders.append(betterproto_handler) + self._decoders.append(betterproto_handler) + except ImportError: + print("Warning: BetterProto handler not available") + + # Load primitive handler + primitive_handler = PrimitiveHandler() + self._encoders.append(primitive_handler) + self._decoders.append(primitive_handler) + + def encoder(self) -> ProtobufTransportEncoder: + """ + Create and return an encoder instance. + This method is called by the extension loader / DubboSerializationService. + """ + return ProtobufTransportEncoder(self._encoders, self._parameter_types) + + def decoder(self) -> ProtobufTransportDecoder: + """ + Create and return a decoder instance. + This method is called by the extension loader / DubboSerializationService. 
+ """ + return ProtobufTransportDecoder(self._decoders, self._return_type) + + # Convenience methods for direct usage (backward compatibility) + def encode_parameter(self, argument: Any) -> bytes: + """Encode a single parameter""" + encoder = self.encoder() + return encoder.encode((argument,)) + + def encode_parameters(self, arguments: tuple) -> bytes: + """Encode parameters tuple""" + encoder = self.encoder() + return encoder.encode(arguments) + + def decode_return_value(self, data: bytes) -> Any: + """Decode return value""" + decoder = self.decoder() + return decoder.decode(data) + + def register_encoder(self, encoder: ProtobufEncoder): + """Register a custom encoder""" + self._encoders.append(encoder) + + def register_decoder(self, decoder: ProtobufDecoder): + """Register a custom decoder""" + self._decoders.append(decoder) + + def get_encoders(self) -> List[ProtobufEncoder]: + """Get all registered encoders""" + return self._encoders.copy() + + def get_decoders(self) -> List[ProtobufDecoder]: + """Get all registered decoders""" + return self._decoders.copy() diff --git a/src/dubbo/codec/protobuf_codec/protobuf_codec_handler.py b/src/dubbo/codec/protobuf_codec/protobuf_codec_handler.py deleted file mode 100644 index 53acd35..0000000 --- a/src/dubbo/codec/protobuf_codec/protobuf_codec_handler.py +++ /dev/null @@ -1,405 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one or more -# contributor license agreements. See the NOTICE file distributed with -# this work for additional information regarding copyright ownership. -# The ASF licenses this file to You under the Apache License, Version 2.0 -# (the "License"); you may not use this file except in compliance with -# the License. 
You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an "AS IS" -# BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import json -from abc import ABC, abstractmethod -from dataclasses import dataclass -from typing import Any, Optional, Protocol - -# Betterproto imports -try: - import betterproto - - HAS_BETTERPROTO = True -except ImportError: - HAS_BETTERPROTO = False - - -class SerializationException(Exception): - """Exception raised when encoding or serialization fails.""" - - def __init__(self, message: str, *, cause: Optional[Exception] = None): - super().__init__(message) - self.cause = cause - - -class DeserializationException(Exception): - """Exception raised when decoding or deserialization fails.""" - - def __init__(self, message: str, *, cause: Optional[Exception] = None): - super().__init__(message) - self.cause = cause - - -class ProtobufEncodingFunction(Protocol): - def __call__(self, obj: Any) -> bytes: ... - - -class ProtobufDecodingFunction(Protocol): - def __call__(self, data: bytes) -> Any: ... - - -@dataclass -class ProtobufMethodDescriptor: - """Protobuf-specific method descriptor for single parameter""" - - parameter_type: Optional[type] - return_type: Optional[type] - protobuf_message_type: Optional[type] = None - - -# Abstract base classes for pluggable architecture -class TypeHandler(ABC): - """Abstract base class for type handlers""" - - @abstractmethod - def is_message(self, obj_type: type) -> bool: ... - - @abstractmethod - def is_message_instance(self, obj: Any) -> bool: ... - - @abstractmethod - def is_compatible(self, obj_type: type) -> bool: ... 
- - -class EncodingStrategy(ABC): - """Abstract base class for encoding strategies""" - - @abstractmethod - def can_encode(self, parameter: Any, parameter_type: Optional[type] = None) -> bool: ... - - @abstractmethod - def encode(self, parameter: Any, parameter_type: Optional[type] = None) -> bytes: ... - - -class DecodingStrategy(ABC): - """Abstract base class for decoding strategies""" - - @abstractmethod - def can_decode(self, data: bytes, target_type: type) -> bool: ... - - @abstractmethod - def decode(self, data: bytes, target_type: type) -> Any: ... - - -# Concrete implementations -class ProtobufTypeHandler(TypeHandler): - """Handles type conversion between Python types and Betterproto""" - - def is_message(self, obj_type: type) -> bool: - if not HAS_BETTERPROTO: - return False - try: - return hasattr(obj_type, "__dataclass_fields__") and issubclass(obj_type, betterproto.Message) - except (TypeError, AttributeError): - return False - - def is_message_instance(self, obj: Any) -> bool: - if not HAS_BETTERPROTO: - return False - return isinstance(obj, betterproto.Message) - - def is_compatible(self, obj_type: type) -> bool: - return obj_type in (str, int, float, bool, bytes) or self.is_message(obj_type) - - # Static methods for backward compatibility - @staticmethod - def is_betterproto_message(obj_type: type) -> bool: - handler = ProtobufTypeHandler() - return handler.is_message(obj_type) - - @staticmethod - def is_betterproto_message_instance(obj: Any) -> bool: - handler = ProtobufTypeHandler() - return handler.is_message_instance(obj) - - @staticmethod - def is_protobuf_compatible(obj_type: type) -> bool: - handler = ProtobufTypeHandler() - return handler.is_compatible(obj_type) - - -class MessageEncodingStrategy(EncodingStrategy): - """Encoding strategy for protobuf messages""" - - def __init__(self, type_handler: TypeHandler): - self.type_handler = type_handler - - def can_encode(self, parameter: Any, parameter_type: Optional[type] = None) -> bool: - return 
self.type_handler.is_message_instance(parameter) or ( - parameter_type is not None and self.type_handler.is_message(parameter_type) - ) - - def encode(self, parameter: Any, parameter_type: Optional[type] = None) -> bytes: - if self.type_handler.is_message_instance(parameter): - return bytes(parameter) - - if parameter_type and self.type_handler.is_message(parameter_type): - if isinstance(parameter, parameter_type): - return bytes(parameter) - elif isinstance(parameter, dict): - try: - message = parameter_type().from_dict(parameter) - return bytes(message) - except Exception as e: - raise SerializationException(f"Cannot convert dict to {parameter_type}: {e}") - else: - raise SerializationException(f"Cannot convert {type(parameter)} to {parameter_type}") - - raise SerializationException(f"Cannot encode {type(parameter)} as protobuf message") - - -class PrimitiveEncodingStrategy(EncodingStrategy): - """Encoding strategy for primitive types""" - - def can_encode(self, parameter: Any, parameter_type: Optional[type] = None) -> bool: - return isinstance(parameter, (str, int, float, bool, bytes)) - - def encode(self, parameter: Any, parameter_type: Optional[type] = None) -> bytes: - try: - json_str = json.dumps({"value": parameter, "type": type(parameter).__name__}) - return json_str.encode("utf-8") - except Exception as e: - raise SerializationException(f"Failed to encode primitive {parameter}: {e}") - - -class MessageDecodingStrategy(DecodingStrategy): - """Decoding strategy for protobuf messages""" - - def __init__(self, type_handler: TypeHandler): - self.type_handler = type_handler - - def can_decode(self, data: bytes, target_type: type) -> bool: - return self.type_handler.is_message(target_type) - - def decode(self, data: bytes, target_type: type) -> Any: - try: - return target_type().parse(data) - except Exception as e: - raise DeserializationException(f"Failed to parse betterproto message: {e}") - - -class PrimitiveDecodingStrategy(DecodingStrategy): - """Decoding 
strategy for primitive types""" - - def can_decode(self, data: bytes, target_type: type) -> bool: - return target_type in (str, int, float, bool, bytes) - - def decode(self, data: bytes, target_type: type) -> Any: - try: - json_str = data.decode("utf-8") - parsed = json.loads(json_str) - value = parsed.get("value") - - if target_type is str: - return str(value) - elif target_type is int: - return int(value) - elif target_type is float: - return float(value) - elif target_type is bool: - return bool(value) - elif target_type is bytes: - if isinstance(value, bytes): - return value - elif isinstance(value, list): - return bytes(value) - else: - return str(value).encode() - else: - return value - - except Exception as e: - raise DeserializationException(f"Failed to decode primitive: {e}") - - -class StrategyRegistry: - """Registry for managing encoding/decoding strategies""" - - def __init__(self): - self.encoding_strategies: list[EncodingStrategy] = [] - self.decoding_strategies: list[DecodingStrategy] = [] - - def register_encoding_strategy(self, strategy: EncodingStrategy) -> None: - self.encoding_strategies.append(strategy) - - def register_decoding_strategy(self, strategy: DecodingStrategy) -> None: - self.decoding_strategies.append(strategy) - - def find_encoding_strategy( - self, parameter: Any, parameter_type: Optional[type] = None - ) -> Optional[EncodingStrategy]: - for strategy in self.encoding_strategies: - if strategy.can_encode(parameter, parameter_type): - return strategy - return None - - def find_decoding_strategy(self, data: bytes, target_type: type) -> Optional[DecodingStrategy]: - for strategy in self.decoding_strategies: - if strategy.can_decode(data, target_type): - return strategy - return None - - -class ProtobufTransportEncoder: - """Protobuf encoder for single parameters using pluggable strategies""" - - def __init__( - self, - parameter_type: Optional[type] = None, - type_handler: Optional[TypeHandler] = None, - strategy_registry: 
Optional[StrategyRegistry] = None, - **kwargs, - ): - if not HAS_BETTERPROTO: - raise ImportError("betterproto library is required for ProtobufTransportEncoder") - self.parameter_type = parameter_type - self.descriptor = ProtobufMethodDescriptor(parameter_type=parameter_type, return_type=None) - - self.type_handler = type_handler or ProtobufTypeHandler() - self.strategy_registry = strategy_registry or self._create_default_registry() - - def _create_default_registry(self) -> StrategyRegistry: - registry = StrategyRegistry() - registry.register_encoding_strategy(MessageEncodingStrategy(self.type_handler)) - registry.register_encoding_strategy(PrimitiveEncodingStrategy()) - return registry - - def encode(self, parameter: Any, parameter_type: Optional[type] = None) -> bytes: - try: - if parameter is None: - return b"" - - effective_type = parameter_type or self.parameter_type - - if isinstance(parameter, tuple): - if len(parameter) == 0: - return b"" - elif len(parameter) == 1: - return self._encode_single_parameter(parameter[0], effective_type) - else: - raise SerializationException( - f"Multiple parameters not supported. Got tuple with {len(parameter)} elements, expected 1." 
- ) - - return self._encode_single_parameter(parameter, effective_type) - - except Exception as e: - raise SerializationException(f"Protobuf encoding failed: {e}") from e - - def _encode_single_parameter(self, parameter: Any, parameter_type: Optional[type]) -> bytes: - strategy = self.strategy_registry.find_encoding_strategy(parameter, parameter_type) - if strategy: - return strategy.encode(parameter, parameter_type) - raise SerializationException(f"No encoding strategy found for {type(parameter)}") - - def _encode_primitive(self, value: Any) -> bytes: - strategy = PrimitiveEncodingStrategy() - return strategy.encode(value) - - -class ProtobufTransportDecoder: - """Protobuf decoder for single parameters using pluggable strategies""" - - def __init__( - self, - target_type: Optional[type] = None, - type_handler: Optional[TypeHandler] = None, - strategy_registry: Optional[StrategyRegistry] = None, - **kwargs, - ): - if not HAS_BETTERPROTO: - raise ImportError("betterproto library is required for ProtobufTransportDecoder") - - self.target_type = target_type - self.type_handler = type_handler or ProtobufTypeHandler() - self.strategy_registry = strategy_registry or self._create_default_registry() - - def _create_default_registry(self) -> StrategyRegistry: - registry = StrategyRegistry() - registry.register_decoding_strategy(MessageDecodingStrategy(self.type_handler)) - registry.register_decoding_strategy(PrimitiveDecodingStrategy()) - return registry - - def decode(self, data: bytes) -> Any: - try: - if not data: - return None - if not self.target_type: - raise DeserializationException("No target_type specified for decoding") - return self._decode_single_parameter(data, self.target_type) - except Exception as e: - raise DeserializationException(f"Protobuf decoding failed: {e}") from e - - def _decode_single_parameter(self, data: bytes, target_type: type) -> Any: - strategy = self.strategy_registry.find_decoding_strategy(data, target_type) - if strategy: - return 
strategy.decode(data, target_type) - raise DeserializationException(f"No decoding strategy found for {target_type}") - - def _decode_primitive(self, data: bytes, target_type: type) -> Any: - strategy = PrimitiveDecodingStrategy() - return strategy.decode(data, target_type) - - -class ProtobufTransportCodec: - """Main protobuf codec class for single parameters""" - - def __init__( - self, - parameter_type: Optional[type] = None, - return_type: Optional[type] = None, - type_handler: Optional[TypeHandler] = None, - encoder_registry: Optional[StrategyRegistry] = None, - decoder_registry: Optional[StrategyRegistry] = None, - **kwargs, - ): - if not HAS_BETTERPROTO: - raise ImportError("betterproto library is required for ProtobufTransportCodec") - - shared_registry = encoder_registry or decoder_registry - - self._encoder = ProtobufTransportEncoder( - parameter_type=parameter_type, - type_handler=type_handler, - strategy_registry=encoder_registry or shared_registry, - **kwargs, - ) - self._decoder = ProtobufTransportDecoder( - target_type=return_type, - type_handler=type_handler, - strategy_registry=decoder_registry or shared_registry, - **kwargs, - ) - - def encode_parameter(self, argument: Any) -> bytes: - return self._encoder.encode(argument, self._encoder.parameter_type) - - def encode_parameters(self, arguments: tuple) -> bytes: - if not arguments: - return b"" - if len(arguments) == 1: - return self._encoder.encode(arguments[0], self._encoder.parameter_type) - raise SerializationException(f"Multiple parameters not supported. 
Got {len(arguments)} arguments, expected 1.") - - def decode_return_value(self, data: bytes) -> Any: - return self._decoder.decode(data) - - def get_encoder(self) -> ProtobufTransportEncoder: - return self._encoder - - def get_decoder(self) -> ProtobufTransportDecoder: - return self._decoder diff --git a/src/dubbo/extension/registries.py b/src/dubbo/extension/registries.py index ba42e9b..38d482e 100644 --- a/src/dubbo/extension/registries.py +++ b/src/dubbo/extension/registries.py @@ -24,7 +24,7 @@ from dubbo.remoting import Transporter from dubbo.classes import Codec from dubbo.codec.json_codec import TypeHandler -from dubbo.codec.protobuf_codec import EncodingStrategy, DecodingStrategy +# from dubbo.codec.protobuf_codec import EncodingStrategy, DecodingStrategy @dataclass @@ -50,8 +50,6 @@ class ExtendedRegistry: "compressorRegistry", "decompressorRegistry", "transporterRegistry", - "encodingHandlerRegistry", - "decodingHandlerRegistry", "codecRegistry", "typeHandlerRegistry", ] @@ -109,24 +107,6 @@ class ExtendedRegistry: }, ) -# Encoding Strategy Registries -encodingHandlerRegistry = ExtendedRegistry( - interface=EncodingStrategy, - impls={ - "message": "dubbo.codec.protobuf_codec.MessageEncodingStrategy", - "primitive": "dubbo.codec.protobuf_codec.PrimitiveEncodingStrategy", - }, -) - -# Decoding Strategy Registries -decodingHandlerRegistry = ExtendedRegistry( - interface=DecodingStrategy, - impls={ - "message": "dubbo.codec.protobuf_codec.MessageDecodingStrategy", - "primitive": "dubbo.codec.protobuf_codec.PrimitiveDecodingStrategy", - }, -) - # Codec Registry codecRegistry = ExtendedRegistry( interface=Codec, @@ -136,6 +116,7 @@ class ExtendedRegistry: }, ) +# TypeHandler registry typeHandlerRegistry = ExtendedRegistry( interface=TypeHandler, impls={ diff --git a/src/dubbo/proxy/handlers.py b/src/dubbo/proxy/handlers.py index 140cc31..ff3251a 100644 --- a/src/dubbo/proxy/handlers.py +++ b/src/dubbo/proxy/handlers.py @@ -85,6 +85,7 @@ def 
_infer_types_from_method(cls, method: Callable) -> tuple: method_name = method.__name__ params = list(sig.parameters.values()) + # Check for 'self' parameter which indicates an unbound method if params and params[0].name == "self": raise RpcMethodConfigurationError( f"Method '{method_name}' appears to be an unbound method with 'self' parameter. " @@ -93,6 +94,7 @@ def _infer_types_from_method(cls, method: Callable) -> tuple: "RpcMethodHandler.unary(instance.method) not RpcMethodHandler.unary(Class.method)" ) + # For bound methods or standalone functions, all parameters are RPC parameters params_types = [type_hints.get(p.name, Any) for p in params] return_type = type_hints.get("return", Any) return method_name, params_types, return_type From 23f21d97fb3303d38ecb2fab64ccf23ab0aedba3 Mon Sep 17 00:00:00 2001 From: Aditya Yadav <166515021+aditya0yadav@users.noreply.github.com> Date: Thu, 4 Sep 2025 22:01:36 +0000 Subject: [PATCH 23/40] remove the debug statement --- src/dubbo/extension/registries.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/dubbo/extension/registries.py b/src/dubbo/extension/registries.py index 38d482e..47cf6e6 100644 --- a/src/dubbo/extension/registries.py +++ b/src/dubbo/extension/registries.py @@ -24,7 +24,6 @@ from dubbo.remoting import Transporter from dubbo.classes import Codec from dubbo.codec.json_codec import TypeHandler -# from dubbo.codec.protobuf_codec import EncodingStrategy, DecodingStrategy @dataclass From e3a5244bb972f56d0702655166e906102b2121a2 Mon Sep 17 00:00:00 2001 From: Aditya Yadav <166515021+aditya0yadav@users.noreply.github.com> Date: Fri, 5 Sep 2025 17:15:15 +0000 Subject: [PATCH 24/40] added the use of extension loader in the betterproto handler --- src/dubbo/codec/protobuf_codec/__init__.py | 6 ++++++ src/dubbo/codec/protobuf_codec/protobuf_codec.py | 11 ++++++++--- src/dubbo/extension/registries.py | 13 ++++++++++++- 3 files changed, 26 insertions(+), 4 deletions(-) diff --git 
a/src/dubbo/codec/protobuf_codec/__init__.py b/src/dubbo/codec/protobuf_codec/__init__.py index ad98c72..ded719d 100644 --- a/src/dubbo/codec/protobuf_codec/__init__.py +++ b/src/dubbo/codec/protobuf_codec/__init__.py @@ -15,9 +15,15 @@ # limitations under the License. from .protobuf_codec import ProtobufTransportCodec, ProtobufTransportDecoder, ProtobufTransportEncoder +from .protobuf_base import ProtobufEncoder, ProtobufDecoder +from .betterproto_handler import BetterprotoMessageHandler, PrimitiveHandler __all__ = [ "ProtobufTransportCodec", "ProtobufTransportDecoder", "ProtobufTransportEncoder", + "ProtobufEncoder", + "ProtobufDecoder" + "BetterprotoMessageHandler", + "PrimitiveHandler" ] diff --git a/src/dubbo/codec/protobuf_codec/protobuf_codec.py b/src/dubbo/codec/protobuf_codec/protobuf_codec.py index df1fe11..8e3e451 100644 --- a/src/dubbo/codec/protobuf_codec/protobuf_codec.py +++ b/src/dubbo/codec/protobuf_codec/protobuf_codec.py @@ -17,7 +17,7 @@ from typing import Any, Optional, List from .protobuf_base import ProtobufEncoder, ProtobufDecoder, SerializationException, DeserializationException -from .betterproto_handler import BetterprotoMessageHandler, PrimitiveHandler +from .protobuf_base import ProtobufEncoder __all__ = ["ProtobufTransportCodec"] @@ -105,16 +105,21 @@ def __init__( def _load_default_handlers(self): """Load default encoding and decoding handlers""" + from dubbo.extension import extensionLoader try: # Try to load BetterProto handler - betterproto_handler = BetterprotoMessageHandler() + name = "message" + message_handler = extensionLoader.get_extension(ProtobufEncoder, name) + betterproto_handler = message_handler() self._encoders.append(betterproto_handler) self._decoders.append(betterproto_handler) except ImportError: print("Warning: BetterProto handler not available") # Load primitive handler - primitive_handler = PrimitiveHandler() + from dubbo.extension import extensionLoader + name = "primitive" + primitive_handler = 
extensionLoader.get_extension(ProtobufEncoder, name)() self._encoders.append(primitive_handler) self._decoders.append(primitive_handler) diff --git a/src/dubbo/extension/registries.py b/src/dubbo/extension/registries.py index 47cf6e6..86cd566 100644 --- a/src/dubbo/extension/registries.py +++ b/src/dubbo/extension/registries.py @@ -24,6 +24,7 @@ from dubbo.remoting import Transporter from dubbo.classes import Codec from dubbo.codec.json_codec import TypeHandler +from dubbo.codec.protobuf_codec import ProtobufEncoder @dataclass @@ -41,7 +42,7 @@ class ExtendedRegistry: impls: dict[str, Any] -# All Extension Registries - FIXED: Added codecRegistry to the list +# All Extension Registries registries = [ "registryFactoryRegistry", "loadBalanceRegistry", @@ -51,6 +52,7 @@ class ExtendedRegistry: "transporterRegistry", "codecRegistry", "typeHandlerRegistry", + "betterprotoRegistry" ] # RegistryFactory registry @@ -115,6 +117,15 @@ class ExtendedRegistry: }, ) +# BetterProtoHandler Registry +betterprotoRegistry = ExtendedRegistry( + interface=ProtobufEncoder, + impls={ + "message" : "dubbo.codec.protobuf_codec.BetterprotoMessageHandler", + "primitive" : "dubbo.codec.protobuf_codec.PrimitiveHandler", + } +) + # TypeHandler registry typeHandlerRegistry = ExtendedRegistry( interface=TypeHandler, From 1b5d66804e336d51ab33d96fcabe0e2a9cdfdae8 Mon Sep 17 00:00:00 2001 From: Aditya Yadav <166515021+aditya0yadav@users.noreply.github.com> Date: Fri, 5 Sep 2025 17:15:56 +0000 Subject: [PATCH 25/40] using ruff format for formatting --- src/dubbo/codec/protobuf_codec/__init__.py | 5 ++--- src/dubbo/codec/protobuf_codec/protobuf_codec.py | 2 ++ src/dubbo/extension/registries.py | 8 ++++---- 3 files changed, 8 insertions(+), 7 deletions(-) diff --git a/src/dubbo/codec/protobuf_codec/__init__.py b/src/dubbo/codec/protobuf_codec/__init__.py index ded719d..7694a9f 100644 --- a/src/dubbo/codec/protobuf_codec/__init__.py +++ b/src/dubbo/codec/protobuf_codec/__init__.py @@ -23,7 +23,6 @@ 
"ProtobufTransportDecoder", "ProtobufTransportEncoder", "ProtobufEncoder", - "ProtobufDecoder" - "BetterprotoMessageHandler", - "PrimitiveHandler" + "ProtobufDecoderBetterprotoMessageHandler", + "PrimitiveHandler", ] diff --git a/src/dubbo/codec/protobuf_codec/protobuf_codec.py b/src/dubbo/codec/protobuf_codec/protobuf_codec.py index 8e3e451..5ba71a7 100644 --- a/src/dubbo/codec/protobuf_codec/protobuf_codec.py +++ b/src/dubbo/codec/protobuf_codec/protobuf_codec.py @@ -106,6 +106,7 @@ def __init__( def _load_default_handlers(self): """Load default encoding and decoding handlers""" from dubbo.extension import extensionLoader + try: # Try to load BetterProto handler name = "message" @@ -118,6 +119,7 @@ def _load_default_handlers(self): # Load primitive handler from dubbo.extension import extensionLoader + name = "primitive" primitive_handler = extensionLoader.get_extension(ProtobufEncoder, name)() self._encoders.append(primitive_handler) diff --git a/src/dubbo/extension/registries.py b/src/dubbo/extension/registries.py index 86cd566..3fab385 100644 --- a/src/dubbo/extension/registries.py +++ b/src/dubbo/extension/registries.py @@ -52,7 +52,7 @@ class ExtendedRegistry: "transporterRegistry", "codecRegistry", "typeHandlerRegistry", - "betterprotoRegistry" + "betterprotoRegistry", ] # RegistryFactory registry @@ -121,9 +121,9 @@ class ExtendedRegistry: betterprotoRegistry = ExtendedRegistry( interface=ProtobufEncoder, impls={ - "message" : "dubbo.codec.protobuf_codec.BetterprotoMessageHandler", - "primitive" : "dubbo.codec.protobuf_codec.PrimitiveHandler", - } + "message": "dubbo.codec.protobuf_codec.BetterprotoMessageHandler", + "primitive": "dubbo.codec.protobuf_codec.PrimitiveHandler", + }, ) # TypeHandler registry From c2f14242b9eee84fb10517d983cf0e343b0fd7cc Mon Sep 17 00:00:00 2001 From: Aditya Yadav <166515021+aditya0yadav@users.noreply.github.com> Date: Fri, 5 Sep 2025 19:14:08 +0000 Subject: [PATCH 26/40] added the protoc handler if the google protoc being used 
---
 src/dubbo/codec/protobuf_codec/__init__.py       |   8 +-
 .../codec/protobuf_codec/primitive_handler.py    | 108 ++++++++++++++++++
 .../codec/protobuf_codec/protobuf_codec.py       |  25 ++--
 .../codec/protobuf_codec/protoc_handler.py       |  82 +++++++++++++
 4 files changed, 208 insertions(+), 15 deletions(-)
 create mode 100644 src/dubbo/codec/protobuf_codec/primitive_handler.py
 create mode 100644 src/dubbo/codec/protobuf_codec/protoc_handler.py

diff --git a/src/dubbo/codec/protobuf_codec/__init__.py b/src/dubbo/codec/protobuf_codec/__init__.py
index 7694a9f..73c403e 100644
--- a/src/dubbo/codec/protobuf_codec/__init__.py
+++ b/src/dubbo/codec/protobuf_codec/__init__.py
@@ -16,13 +16,17 @@
 from .protobuf_codec import ProtobufTransportCodec, ProtobufTransportDecoder, ProtobufTransportEncoder
 from .protobuf_base import ProtobufEncoder, ProtobufDecoder
-from .betterproto_handler import BetterprotoMessageHandler, PrimitiveHandler
+from .betterproto_handler import BetterprotoMessageHandler
+from .protoc_handler import GoogleProtobufMessageHandler
+from .primitive_handler import PrimitiveHandler
 
 __all__ = [
     "ProtobufTransportCodec",
     "ProtobufTransportDecoder",
     "ProtobufTransportEncoder",
     "ProtobufEncoder",
-    "ProtobufDecoderBetterprotoMessageHandler",
+    "ProtobufDecoder",
+    "BetterprotoMessageHandler",
+    "GoogleProtobufMessageHandler",
     "PrimitiveHandler",
 ]
 
diff --git a/src/dubbo/codec/protobuf_codec/primitive_handler.py b/src/dubbo/codec/protobuf_codec/primitive_handler.py
new file mode 100644
index 0000000..ad2d0c8
--- /dev/null
+++ b/src/dubbo/codec/protobuf_codec/primitive_handler.py
@@ -0,0 +1,108 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json +from typing import Any, Optional + +from .protobuf_base import ProtobufEncoder, ProtobufDecoder, SerializationException, DeserializationException + + +__all__ = ["PrimitiveHandler"] + + +class PrimitiveHandler(ProtobufEncoder, ProtobufDecoder): + """ + The primitive type handler for basic Python types. + """ + + _SERIALIZATION_TYPE = "primitive" + + @classmethod + def get_serialization_type(cls) -> str: + """ + Get serialization type of current implementation + :return: The serialization type. + :rtype: str + """ + return cls._SERIALIZATION_TYPE + + def can_handle(self, obj: Any, obj_type: Optional[type] = None) -> bool: + """ + Check if this handler can handle the given object/type + :param obj: The object to check + :param obj_type: The type to check + :return: True if can handle, False otherwise + :rtype: bool + """ + if obj is not None: + return isinstance(obj, (str, int, float, bool, bytes)) + if obj_type is not None: + return obj_type in (str, int, float, bool, bytes) + return False + + def encode(self, obj: Any, obj_type: Optional[type] = None) -> bytes: + """ + Encode the primitive object to bytes. + :param obj: The object to encode. + :param obj_type: The type hint for encoding. + :return: The encoded bytes. 
+ :rtype: bytes + """ + try: + if not isinstance(obj, (str, int, float, bool, bytes)): + raise SerializationException(f"Cannot encode {type(obj)} as primitive") + + json_str = json.dumps({"value": obj, "type": type(obj).__name__}) + return json_str.encode("utf-8") + except Exception as e: + raise SerializationException(f"Primitive encoding failed: {e}") from e + + def decode(self, data: bytes, target_type: type) -> Any: + """ + Decode the data to primitive object. + :param data: The data to decode. + :param target_type: The target primitive type. + :return: The decoded object. + :rtype: Any + """ + try: + if target_type not in (str, int, float, bool, bytes): + raise DeserializationException(f"{target_type} is not a supported primitive type") + + json_str = data.decode("utf-8") + parsed = json.loads(json_str) + value = parsed.get("value") + + if target_type is str: + return str(value) + elif target_type is int: + return int(value) + elif target_type is float: + return float(value) + elif target_type is bool: + return bool(value) + elif target_type is bytes: + if isinstance(value, bytes): + return value + elif isinstance(value, list): + return bytes(value) + else: + return str(value).encode() + else: + return value + + except Exception as e: + raise DeserializationException(f"Primitive decoding failed: {e}") from e diff --git a/src/dubbo/codec/protobuf_codec/protobuf_codec.py b/src/dubbo/codec/protobuf_codec/protobuf_codec.py index 5ba71a7..3cbaf4b 100644 --- a/src/dubbo/codec/protobuf_codec/protobuf_codec.py +++ b/src/dubbo/codec/protobuf_codec/protobuf_codec.py @@ -17,7 +17,6 @@ from typing import Any, Optional, List from .protobuf_base import ProtobufEncoder, ProtobufDecoder, SerializationException, DeserializationException -from .protobuf_base import ProtobufEncoder __all__ = ["ProtobufTransportCodec"] @@ -83,8 +82,7 @@ def decode(self, data: bytes) -> Any: class ProtobufTransportCodec: """ - Main protobuf codec class compatible with extension loader. 
- This class provides encoder() and decoder() methods as expected by the extension loader. + Main protobuf codec class """ def __init__( @@ -107,35 +105,36 @@ def _load_default_handlers(self): """Load default encoding and decoding handlers""" from dubbo.extension import extensionLoader + # Try BetterProto handler try: - # Try to load BetterProto handler - name = "message" - message_handler = extensionLoader.get_extension(ProtobufEncoder, name) - betterproto_handler = message_handler() + betterproto_handler = extensionLoader.get_extension(ProtobufEncoder, "betterproto")() self._encoders.append(betterproto_handler) self._decoders.append(betterproto_handler) except ImportError: print("Warning: BetterProto handler not available") - # Load primitive handler - from dubbo.extension import extensionLoader + # Try Google Protoc handler + try: + protoc_handler = extensionLoader.get_extension(ProtobufEncoder, "googleproto")() + self._encoders.append(protoc_handler) + self._decoders.append(protoc_handler) + except ImportError: + print("Warning: Protoc handler not available") - name = "primitive" - primitive_handler = extensionLoader.get_extension(ProtobufEncoder, name)() + # Always load primitive handler + primitive_handler = extensionLoader.get_extension(ProtobufEncoder, "primitive")() self._encoders.append(primitive_handler) self._decoders.append(primitive_handler) def encoder(self) -> ProtobufTransportEncoder: """ Create and return an encoder instance. - This method is called by the extension loader / DubboSerializationService. """ return ProtobufTransportEncoder(self._encoders, self._parameter_types) def decoder(self) -> ProtobufTransportDecoder: """ Create and return a decoder instance. - This method is called by the extension loader / DubboSerializationService. 
""" return ProtobufTransportDecoder(self._decoders, self._return_type) diff --git a/src/dubbo/codec/protobuf_codec/protoc_handler.py b/src/dubbo/codec/protobuf_codec/protoc_handler.py new file mode 100644 index 0000000..1c32a1a --- /dev/null +++ b/src/dubbo/codec/protobuf_codec/protoc_handler.py @@ -0,0 +1,82 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Any, Optional +from .protobuf_base import ProtobufEncoder, ProtobufDecoder, SerializationException, DeserializationException + +try: + from google.protobuf.message import Message as GoogleMessage + + HAS_PROTOC = True +except ImportError: + HAS_PROTOC = False + + +class GoogleProtobufMessageHandler(ProtobufEncoder, ProtobufDecoder): + """ + The Google protoc message handler for protobuf messages. 
+ """ + + _SERIALIZATION_TYPE = "protoc" + + def __init__(self): + if not HAS_PROTOC: + raise ImportError("google.protobuf is required for GoogleProtobufMessageHandler") + + @classmethod + def get_serialization_type(cls) -> str: + return cls._SERIALIZATION_TYPE + + def can_handle(self, obj: Any, obj_type: Optional[type] = None) -> bool: + if obj is not None and HAS_PROTOC and isinstance(obj, GoogleMessage): + return True + if obj_type is not None: + return self._is_protoc_message(obj_type) + return False + + def encode(self, obj: Any, obj_type: Optional[type] = None) -> bytes: + try: + if isinstance(obj, GoogleMessage): + return obj.SerializeToString() + + if obj_type and self._is_protoc_message(obj_type): + if isinstance(obj, obj_type): + return obj.SerializeToString() + elif isinstance(obj, dict): + message = obj_type(**obj) + return message.SerializeToString() + else: + raise SerializationException(f"Cannot convert {type(obj)} to {obj_type}") + + raise SerializationException(f"Cannot encode {type(obj)} as protoc message") + except Exception as e: + raise SerializationException(f"Protoc encoding failed: {e}") from e + + def decode(self, data: bytes, target_type: type) -> Any: + try: + if not self._is_protoc_message(target_type): + raise DeserializationException(f"{target_type} is not a protoc message type") + message = target_type() + message.ParseFromString(data) + return message + except Exception as e: + raise DeserializationException(f"Protoc decoding failed: {e}") from e + + def _is_protoc_message(self, obj_type: type) -> bool: + try: + return HAS_PROTOC and issubclass(obj_type, GoogleMessage) + except (TypeError, AttributeError): + return False From 25a0496ba81d28555d615e471a5537dc8221fb6a Mon Sep 17 00:00:00 2001 From: Aditya Yadav <166515021+aditya0yadav@users.noreply.github.com> Date: Fri, 5 Sep 2025 22:03:42 +0000 Subject: [PATCH 27/40] divide the base line code from dubbo codec in interface --- src/dubbo/codec/_interface.py | 119 
+++++++++++++++++++++++++++++++++ src/dubbo/codec/dubbo_codec.py | 90 +------------------------ 2 files changed, 120 insertions(+), 89 deletions(-) create mode 100644 src/dubbo/codec/_interface.py diff --git a/src/dubbo/codec/_interface.py b/src/dubbo/codec/_interface.py new file mode 100644 index 0000000..6f4e552 --- /dev/null +++ b/src/dubbo/codec/_interface.py @@ -0,0 +1,119 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import abc +import logging +from dataclasses import dataclass +from typing import Any, Callable, Optional, List, Tuple + +__all__ = [ + "ParameterDescriptor", + "MethodDescriptor", + "TransportCodec", + "SerializationEncoder", + "SerializationDecoder", +] + +logger = logging.getLogger(__name__) + + +@dataclass +class ParameterDescriptor: + """Information about a method parameter""" + + name: str + annotation: Any + is_required: bool = True + default_value: Any = None + + +@dataclass +class MethodDescriptor: + """Method descriptor with function details""" + + function: Callable + name: str + parameters: List[ParameterDescriptor] + return_parameter: ParameterDescriptor + documentation: Optional[str] = None + + +class TransportCodec(abc.ABC): + """ + The transport codec interface. 
+ """ + + @classmethod + @abc.abstractmethod + def get_transport_type(cls) -> str: + """ + Get transport type of current codec + :return: The transport type. + :rtype: str + """ + raise NotImplementedError() + + @abc.abstractmethod + def get_encoder(self) -> "SerializationEncoder": + """ + Get encoder instance + :return: The encoder. + :rtype: SerializationEncoder + """ + raise NotImplementedError() + + @abc.abstractmethod + def get_decoder(self) -> "SerializationDecoder": + """ + Get decoder instance + :return: The decoder. + :rtype: SerializationDecoder + """ + raise NotImplementedError() + + +class SerializationEncoder(abc.ABC): + """ + The serialization encoder interface. + """ + + @abc.abstractmethod + def encode(self, arguments: Tuple[Any, ...]) -> bytes: + """ + Encode arguments to bytes. + :param arguments: The arguments to encode. + :type arguments: Tuple[Any, ...] + :return: The encoded bytes. + :rtype: bytes + """ + raise NotImplementedError() + + +class SerializationDecoder(abc.ABC): + """ + The serialization decoder interface. + """ + + @abc.abstractmethod + def decode(self, data: bytes) -> Any: + """ + Decode bytes to object. + :param data: The data to decode. + :type data: bytes + :return: The decoded object. 
+ :rtype: Any + """ + raise NotImplementedError() diff --git a/src/dubbo/codec/dubbo_codec.py b/src/dubbo/codec/dubbo_codec.py index 740ac54..c4437f1 100644 --- a/src/dubbo/codec/dubbo_codec.py +++ b/src/dubbo/codec/dubbo_codec.py @@ -19,6 +19,7 @@ import logging from dataclasses import dataclass from typing import Any, Callable, Optional, List, Tuple +from _interface import ParameterDescriptor, MethodDescriptor, SerializationDecoder, SerializationEncoder, TransportCodec __all__ = [ "ParameterDescriptor", @@ -32,95 +33,6 @@ logger = logging.getLogger(__name__) -@dataclass -class ParameterDescriptor: - """Information about a method parameter""" - - name: str - annotation: Any - is_required: bool = True - default_value: Any = None - - -@dataclass -class MethodDescriptor: - """Method descriptor with function details""" - - function: Callable - name: str - parameters: List[ParameterDescriptor] - return_parameter: ParameterDescriptor - documentation: Optional[str] = None - - -class TransportCodec(abc.ABC): - """ - The transport codec interface. - """ - - @classmethod - @abc.abstractmethod - def get_transport_type(cls) -> str: - """ - Get transport type of current codec - :return: The transport type. - :rtype: str - """ - raise NotImplementedError() - - @abc.abstractmethod - def get_encoder(self) -> "SerializationEncoder": - """ - Get encoder instance - :return: The encoder. - :rtype: SerializationEncoder - """ - raise NotImplementedError() - - @abc.abstractmethod - def get_decoder(self) -> "SerializationDecoder": - """ - Get decoder instance - :return: The decoder. - :rtype: SerializationDecoder - """ - raise NotImplementedError() - - -class SerializationEncoder(abc.ABC): - """ - The serialization encoder interface. - """ - - @abc.abstractmethod - def encode(self, arguments: Tuple[Any, ...]) -> bytes: - """ - Encode arguments to bytes. - :param arguments: The arguments to encode. - :type arguments: Tuple[Any, ...] - :return: The encoded bytes. 
- :rtype: bytes - """ - raise NotImplementedError() - - -class SerializationDecoder(abc.ABC): - """ - The serialization decoder interface. - """ - - @abc.abstractmethod - def decode(self, data: bytes) -> Any: - """ - Decode bytes to object. - :param data: The data to decode. - :type data: bytes - :return: The decoded object. - :rtype: Any - """ - raise NotImplementedError() - - class DubboSerializationService: """Dubbo serialization service with type handling""" From bd2b42bc62b481dcf8ceb09417a37b9922d104c5 Mon Sep 17 00:00:00 2001 From: Aditya Yadav <166515021+aditya0yadav@users.noreply.github.com> Date: Fri, 5 Sep 2025 22:05:39 +0000 Subject: [PATCH 28/40] fixing the import issue --- src/dubbo/codec/dubbo_codec.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/dubbo/codec/dubbo_codec.py b/src/dubbo/codec/dubbo_codec.py index c4437f1..4153ede 100644 --- a/src/dubbo/codec/dubbo_codec.py +++ b/src/dubbo/codec/dubbo_codec.py @@ -19,7 +19,7 @@ import logging from dataclasses import dataclass from typing import Any, Callable, Optional, List, Tuple -from _interface import ParameterDescriptor, MethodDescriptor, SerializationDecoder, SerializationEncoder, TransportCodec +from ._interface import ParameterDescriptor, MethodDescriptor, SerializationDecoder, SerializationEncoder, TransportCodec __all__ = [ "ParameterDescriptor", From a5fd49e91ea7f3630ecb2ec8dbc2ef7ed0b8f61b Mon Sep 17 00:00:00 2001 From: Aditya Yadav <166515021+aditya0yadav@users.noreply.github.com> Date: Fri, 5 Sep 2025 22:35:31 +0000 Subject: [PATCH 29/40] add _interface in every file and remove the codehelper from the classes and introduce it in the codec --- src/dubbo/classes.py | 19 --- src/dubbo/client.py | 100 ++++++++++--- src/dubbo/codec/__init__.py | 3 +- src/dubbo/codec/_interface.py | 40 +++++- src/dubbo/codec/dubbo_codec.py | 19 ++- src/dubbo/codec/protobuf_codec/__init__.py | 2 +- .../{protobuf_base.py => _interface.py} | 0 .../protobuf_codec/betterproto_handler.py | 2 +- 
.../codec/protobuf_codec/primitive_handler.py | 2 +- .../codec/protobuf_codec/protobuf_codec.py | 2 +- .../codec/protobuf_codec/protoc_handler.py | 2 +- src/dubbo/extension/registries.py | 19 +-- src/dubbo/proxy/handlers.py | 135 ++++++++---------- 13 files changed, 210 insertions(+), 135 deletions(-) rename src/dubbo/codec/protobuf_codec/{protobuf_base.py => _interface.py} (100%) diff --git a/src/dubbo/classes.py b/src/dubbo/classes.py index 14728da..754b348 100644 --- a/src/dubbo/classes.py +++ b/src/dubbo/classes.py @@ -246,22 +246,3 @@ class ReadWriteStream(ReadStream, WriteStream, abc.ABC): """ pass - - -class Codec(ABC): - def __init__(self, model_type: Optional[type[Any]] = None, **kwargs): - self.model_type = model_type - - @abstractmethod - def encode(self, data: Any) -> bytes: - pass - - @abstractmethod - def decode(self, data: bytes) -> Any: - pass - - -class CodecHelper: - @staticmethod - def get_class(): - return Codec diff --git a/src/dubbo/client.py b/src/dubbo/client.py index 4cbfb56..f947abf 100644 --- a/src/dubbo/client.py +++ b/src/dubbo/client.py @@ -56,20 +56,19 @@ def __init__(self, reference: ReferenceConfig, dubbo: Optional[Dubbo] = None): def _initialize(self): """ - Initialize the invoker. + Initialize the invoker with protocol and URL. 
""" with self._global_lock: if self._initialized: return - # get the protocol + # get the protocol extension protocol = extensionLoader.get_extension(Protocol, self._reference.protocol)() registry_config = self._dubbo.registry_config - self._protocol = RegistryProtocol(registry_config, protocol) if registry_config else protocol - # build url + # build the reference URL reference_url = self._reference.to_url() if registry_config: self._url = registry_config.to_url().copy() @@ -79,7 +78,7 @@ def _initialize(self): else: self._url = reference_url - # create invoker + # create the invoker using the protocol self._invoker = self._protocol.refer(self._url) self._initialized = True @@ -95,9 +94,19 @@ def _create_rpc_callable( response_deserializer: Optional[DeserializingFunction] = None, ) -> RpcCallable: """ - Create RPC callable with the specified type. + Create an RPC callable with the specified type. + + :param rpc_type: Type of RPC (unary, client_stream, server_stream, bi_stream) + :param method_name: Name of the method to call + :param params_types: List of parameter types + :param return_type: Return type of the method + :param codec: Optional codec to use for serialization + :param request_serializer: Optional custom request serializer + :param response_deserializer: Optional custom response deserializer + :return: RPC callable proxy + :rtype: RpcCallable """ - # Determine serializers + # determine serializers if request_serializer and response_deserializer: req_ser = request_serializer res_deser = response_deserializer @@ -108,7 +117,7 @@ def _create_rpc_callable( return_type=return_type, ) - # Create MethodDescriptor + # create method descriptor descriptor = MethodDescriptor( method_name=method_name, arg_serialization=(req_ser, None), @@ -118,47 +127,106 @@ def _create_rpc_callable( return self._callable(descriptor) - def unary(self, method_name: str, params_types: List[Type], return_type: Type, **kwargs) -> RpcCallable: + def unary( + self, + method_name: str, 
+ params_types: List[Type], + return_type: Type, + codec: Optional[str] = None, + request_serializer: Optional[SerializingFunction] = None, + response_deserializer: Optional[DeserializingFunction] = None, + ) -> RpcCallable: + """ + Create a unary RPC callable. + """ return self._create_rpc_callable( rpc_type=RpcTypes.UNARY.value, method_name=method_name, params_types=params_types, return_type=return_type, - **kwargs, + codec=codec, + request_serializer=request_serializer, + response_deserializer=response_deserializer, ) - def client_stream(self, method_name: str, params_types: List[Type], return_type: Type, **kwargs) -> RpcCallable: + def client_stream( + self, + method_name: str, + params_types: List[Type], + return_type: Type, + codec: Optional[str] = None, + request_serializer: Optional[SerializingFunction] = None, + response_deserializer: Optional[DeserializingFunction] = None, + ) -> RpcCallable: + """ + Create a client-streaming RPC callable. + """ return self._create_rpc_callable( rpc_type=RpcTypes.CLIENT_STREAM.value, method_name=method_name, params_types=params_types, return_type=return_type, - **kwargs, + codec=codec, + request_serializer=request_serializer, + response_deserializer=response_deserializer, ) - def server_stream(self, method_name: str, params_types: List[Type], return_type: Type, **kwargs) -> RpcCallable: + def server_stream( + self, + method_name: str, + params_types: List[Type], + return_type: Type, + codec: Optional[str] = None, + request_serializer: Optional[SerializingFunction] = None, + response_deserializer: Optional[DeserializingFunction] = None, + ) -> RpcCallable: + """ + Create a server-streaming RPC callable. 
+ """ return self._create_rpc_callable( rpc_type=RpcTypes.SERVER_STREAM.value, method_name=method_name, params_types=params_types, return_type=return_type, - **kwargs, + codec=codec, + request_serializer=request_serializer, + response_deserializer=response_deserializer, ) - def bi_stream(self, method_name: str, params_types: List[Type], return_type: Type, **kwargs) -> RpcCallable: + def bi_stream( + self, + method_name: str, + params_types: List[Type], + return_type: Type, + codec: Optional[str] = None, + request_serializer: Optional[SerializingFunction] = None, + response_deserializer: Optional[DeserializingFunction] = None, + ) -> RpcCallable: + """ + Create a bidirectional-streaming RPC callable. + """ return self._create_rpc_callable( rpc_type=RpcTypes.BI_STREAM.value, method_name=method_name, params_types=params_types, return_type=return_type, - **kwargs, + codec=codec, + request_serializer=request_serializer, + response_deserializer=response_deserializer, ) def _callable(self, method_descriptor: MethodDescriptor) -> RpcCallable: """ Generate a proxy for the given method. + + :param method_descriptor: The method descriptor. + :return: The RPC callable proxy. + :rtype: RpcCallable """ + # get invoker URL and clone it url = self._invoker.get_url().copy() url.parameters[common_constants.METHOD_KEY] = method_descriptor.get_method_name() url.attributes[common_constants.METHOD_DESCRIPTOR_KEY] = method_descriptor + + # create proxy callable return self._callable_factory.get_callable(self._invoker, url) diff --git a/src/dubbo/codec/__init__.py b/src/dubbo/codec/__init__.py index c3061bc..88fbaa8 100644 --- a/src/dubbo/codec/__init__.py +++ b/src/dubbo/codec/__init__.py @@ -15,5 +15,6 @@ # limitations under the License. 
from .dubbo_codec import DubboSerializationService +from ._interface import Codec -__all__ = ["DubboSerializationService"] +__all__ = ["DubboSerializationService", "Codec"] diff --git a/src/dubbo/codec/_interface.py b/src/dubbo/codec/_interface.py index 6f4e552..ad28448 100644 --- a/src/dubbo/codec/_interface.py +++ b/src/dubbo/codec/_interface.py @@ -17,7 +17,7 @@ import abc import logging from dataclasses import dataclass -from typing import Any, Callable, Optional, List, Tuple +from typing import Any, Callable, Optional, List, Tuple, Type __all__ = [ "ParameterDescriptor", @@ -25,6 +25,7 @@ "TransportCodec", "SerializationEncoder", "SerializationDecoder", + "Codec", ] logger = logging.getLogger(__name__) @@ -117,3 +118,40 @@ def decode(self, data: bytes) -> Any: :rtype: Any """ raise NotImplementedError() + + +class Codec(abc.ABC): + """ + Base codec interface for encoding and decoding data. + """ + + def __init__(self, model_type: Optional[Type[Any]] = None, **kwargs): + """ + Initialize a codec + :param model_type: Optional model type for structured encoding/decoding + :type model_type: Optional[Type[Any]] + :param kwargs: Additional codec configuration + """ + self.model_type = model_type + + @abc.abstractmethod + def encode(self, data: Any) -> bytes: + """ + Encode data into bytes + :param data: The data to encode + :type data: Any + :return: Encoded byte representation + :rtype: bytes + """ + raise NotImplementedError() + + @abc.abstractmethod + def decode(self, data: bytes) -> Any: + """ + Decode bytes into object + :param data: The bytes to decode + :type data: bytes + :return: Decoded object + :rtype: Any + """ + raise NotImplementedError() diff --git a/src/dubbo/codec/dubbo_codec.py b/src/dubbo/codec/dubbo_codec.py index 4153ede..c0c1e1c 100644 --- a/src/dubbo/codec/dubbo_codec.py +++ b/src/dubbo/codec/dubbo_codec.py @@ -14,19 +14,19 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-import abc import inspect import logging -from dataclasses import dataclass from typing import Any, Callable, Optional, List, Tuple -from ._interface import ParameterDescriptor, MethodDescriptor, SerializationDecoder, SerializationEncoder, TransportCodec +from ._interface import ( + ParameterDescriptor, + MethodDescriptor, + SerializationDecoder, + SerializationEncoder, + TransportCodec, + Codec, +) __all__ = [ - "ParameterDescriptor", - "MethodDescriptor", - "TransportCodec", - "SerializationEncoder", - "SerializationDecoder", "DubboSerializationService", ] @@ -55,10 +55,9 @@ def create_transport_codec( :raises Exception: If codec creation fails """ try: - from dubbo.classes import CodecHelper from dubbo.extension.extension_loader import ExtensionLoader - codec_class = ExtensionLoader().get_extension(CodecHelper.get_class(), transport_type) + codec_class = ExtensionLoader().get_extension(Codec, transport_type) return codec_class(parameter_types=parameter_types or [], return_type=return_type, **codec_options) except ImportError as e: logger.error("Failed to import required modules: %s", e) diff --git a/src/dubbo/codec/protobuf_codec/__init__.py b/src/dubbo/codec/protobuf_codec/__init__.py index 73c403e..dd17f37 100644 --- a/src/dubbo/codec/protobuf_codec/__init__.py +++ b/src/dubbo/codec/protobuf_codec/__init__.py @@ -15,7 +15,7 @@ # limitations under the License. 
from .protobuf_codec import ProtobufTransportCodec, ProtobufTransportDecoder, ProtobufTransportEncoder -from .protobuf_base import ProtobufEncoder, ProtobufDecoder +from ._interface import ProtobufEncoder, ProtobufDecoder from .betterproto_handler import BetterprotoMessageHandler from .protoc_handler import GoogleProtobufMessageHandler from .primitive_handler import PrimitiveHandler diff --git a/src/dubbo/codec/protobuf_codec/protobuf_base.py b/src/dubbo/codec/protobuf_codec/_interface.py similarity index 100% rename from src/dubbo/codec/protobuf_codec/protobuf_base.py rename to src/dubbo/codec/protobuf_codec/_interface.py diff --git a/src/dubbo/codec/protobuf_codec/betterproto_handler.py b/src/dubbo/codec/protobuf_codec/betterproto_handler.py index f7edc84..97b0487 100644 --- a/src/dubbo/codec/protobuf_codec/betterproto_handler.py +++ b/src/dubbo/codec/protobuf_codec/betterproto_handler.py @@ -17,7 +17,7 @@ import json from typing import Any, Optional -from .protobuf_base import ProtobufEncoder, ProtobufDecoder, SerializationException, DeserializationException +from ._interface import ProtobufEncoder, ProtobufDecoder, SerializationException, DeserializationException try: import betterproto diff --git a/src/dubbo/codec/protobuf_codec/primitive_handler.py b/src/dubbo/codec/protobuf_codec/primitive_handler.py index ad2d0c8..79142df 100644 --- a/src/dubbo/codec/protobuf_codec/primitive_handler.py +++ b/src/dubbo/codec/protobuf_codec/primitive_handler.py @@ -17,7 +17,7 @@ import json from typing import Any, Optional -from .protobuf_base import ProtobufEncoder, ProtobufDecoder, SerializationException, DeserializationException +from ._interface import ProtobufEncoder, ProtobufDecoder, SerializationException, DeserializationException __all__ = ["PrimitiveHandler"] diff --git a/src/dubbo/codec/protobuf_codec/protobuf_codec.py b/src/dubbo/codec/protobuf_codec/protobuf_codec.py index 3cbaf4b..871f53c 100644 --- a/src/dubbo/codec/protobuf_codec/protobuf_codec.py +++ 
b/src/dubbo/codec/protobuf_codec/protobuf_codec.py @@ -16,7 +16,7 @@ from typing import Any, Optional, List -from .protobuf_base import ProtobufEncoder, ProtobufDecoder, SerializationException, DeserializationException +from ._interface import ProtobufEncoder, ProtobufDecoder, SerializationException, DeserializationException __all__ = ["ProtobufTransportCodec"] diff --git a/src/dubbo/codec/protobuf_codec/protoc_handler.py b/src/dubbo/codec/protobuf_codec/protoc_handler.py index 1c32a1a..67629c7 100644 --- a/src/dubbo/codec/protobuf_codec/protoc_handler.py +++ b/src/dubbo/codec/protobuf_codec/protoc_handler.py @@ -15,7 +15,7 @@ # limitations under the License. from typing import Any, Optional -from .protobuf_base import ProtobufEncoder, ProtobufDecoder, SerializationException, DeserializationException +from ._interface import ProtobufEncoder, ProtobufDecoder, SerializationException, DeserializationException try: from google.protobuf.message import Message as GoogleMessage diff --git a/src/dubbo/extension/registries.py b/src/dubbo/extension/registries.py index 3fab385..6d3edb1 100644 --- a/src/dubbo/extension/registries.py +++ b/src/dubbo/extension/registries.py @@ -22,7 +22,7 @@ from dubbo.protocol import Protocol from dubbo.registry import RegistryFactory from dubbo.remoting import Transporter -from dubbo.classes import Codec +from dubbo.codec import Codec from dubbo.codec.json_codec import TypeHandler from dubbo.codec.protobuf_codec import ProtobufEncoder @@ -51,8 +51,8 @@ class ExtendedRegistry: "decompressorRegistry", "transporterRegistry", "codecRegistry", - "typeHandlerRegistry", - "betterprotoRegistry", + "jsonTypeHandlerRegistry", + "protoHandlerRegistry", ] # RegistryFactory registry @@ -108,7 +108,7 @@ class ExtendedRegistry: }, ) -# Codec Registry +# Codec registry codecRegistry = ExtendedRegistry( interface=Codec, impls={ @@ -117,17 +117,18 @@ class ExtendedRegistry: }, ) -# BetterProtoHandler Registry -betterprotoRegistry = ExtendedRegistry( +# Protobuf 
handler registry +protoHandlerRegistry = ExtendedRegistry( interface=ProtobufEncoder, impls={ - "message": "dubbo.codec.protobuf_codec.BetterprotoMessageHandler", + "betterproto": "dubbo.codec.protobuf_codec.BetterprotoMessageHandler", "primitive": "dubbo.codec.protobuf_codec.PrimitiveHandler", + "googleproto": "dubbo.codec.protobuf_codec.GoogleProtobufMessageHandler", }, ) -# TypeHandler registry -typeHandlerRegistry = ExtendedRegistry( +# JSON type handler registry +jsonTypeHandlerRegistry = ExtendedRegistry( interface=TypeHandler, impls={ "datetime": "dubbo.codec.json_codec.DateTimeHandler", diff --git a/src/dubbo/proxy/handlers.py b/src/dubbo/proxy/handlers.py index ff3251a..578d654 100644 --- a/src/dubbo/proxy/handlers.py +++ b/src/dubbo/proxy/handlers.py @@ -30,7 +30,7 @@ class RpcMethodConfigurationError(Exception): """ - Raised when RPC method is configured incorrectly. + Raised when an RPC method is configured incorrectly. """ pass @@ -46,7 +46,7 @@ class RpcMethodHandler: def __init__(self, method_descriptor: MethodDescriptor): """ Initialize the RpcMethodHandler - :param method_descriptor: the method descriptor. 
+ :param method_descriptor: the method descriptor :type method_descriptor: MethodDescriptor """ self._method_descriptor = method_descriptor @@ -63,8 +63,8 @@ def method_descriptor(self) -> MethodDescriptor: @staticmethod def get_codec(**kwargs) -> tuple: """ - Get the serialization and deserialization functions based on codec - :param kwargs: codec settings like transport_type, parameter_types, return_type + Get serialization and deserialization functions + :param kwargs: codec configuration like transport_type, parameter_types, return_type :return: serializer and deserializer functions :rtype: Tuple[SerializingFunction, DeserializingFunction] """ @@ -73,11 +73,10 @@ def get_codec(**kwargs) -> tuple: @classmethod def _infer_types_from_method(cls, method: Callable) -> tuple: """ - Infer method name, parameter types, and return type from a callable - :param method: the method to analyze + Infer method name, parameter types, and return type + :param method: the callable method :type method: Callable - :return: tuple of method name, parameter types, return type - :rtype: Tuple[str, list[type], type] + :return: tuple(method_name, param_types, return_type) """ try: type_hints = get_type_hints(method) @@ -85,16 +84,12 @@ def _infer_types_from_method(cls, method: Callable) -> tuple: method_name = method.__name__ params = list(sig.parameters.values()) - # Check for 'self' parameter which indicates an unbound method + # Detect unbound methods if params and params[0].name == "self": raise RpcMethodConfigurationError( - f"Method '{method_name}' appears to be an unbound method with 'self' parameter. " - "RPC methods should be bound methods (e.g., instance.method) or standalone functions. " - "If you're registering a class method, ensure you pass a bound method: " - "RpcMethodHandler.unary(instance.method) not RpcMethodHandler.unary(Class.method)" + f"Method '{method_name}' appears unbound with 'self'. Pass a bound method or standalone function." 
) - # For bound methods or standalone functions, all parameters are RPC parameters params_types = [type_hints.get(p.name, Any) for p in params] return_type = type_hints.get("return", Any) return method_name, params_types, return_type @@ -112,25 +107,25 @@ def _create_method_descriptor( return_type: type, rpc_type: str, codec: Optional[str] = None, - param_encoder: Optional[DeserializingFunction] = None, - return_decoder: Optional[SerializingFunction] = None, + request_deserializer: Optional[DeserializingFunction] = None, + response_serializer: Optional[SerializingFunction] = None, **kwargs, ) -> MethodDescriptor: """ Create a MethodDescriptor with serialization configuration - :param method: the actual function/method + :param method: callable method :param method_name: RPC method name - :param params_types: parameter type hints - :param return_type: return type hint - :param rpc_type: type of RPC (unary, stream, etc.) - :param codec: serialization codec (json, pb, etc.) - :param param_encoder: deserialization function - :param return_decoder: serialization function - :param kwargs: additional codec args + :param params_types: parameter types + :param return_type: return type + :param rpc_type: RPC type (unary, client_stream, server_stream, bi_stream) + :param codec: serialization codec + :param request_deserializer: request deserialization function + :param response_serializer: response serialization function + :param kwargs: additional codec arguments :return: MethodDescriptor instance :rtype: MethodDescriptor """ - if param_encoder is None or return_decoder is None: + if request_deserializer is None or response_serializer is None: codec_kwargs = { "transport_type": codec or "json", "parameter_types": params_types, @@ -138,8 +133,8 @@ def _create_method_descriptor( **kwargs, } serializer, deserializer = cls.get_codec(**codec_kwargs) - request_deserializer = param_encoder or deserializer - response_serializer = return_decoder or serializer + request_deserializer = 
request_deserializer or deserializer + response_serializer = response_serializer or serializer return MethodDescriptor( callable_method=method, @@ -162,22 +157,17 @@ def unary( **kwargs, ) -> "RpcMethodHandler": """ - Register a unary RPC method handler + Create a unary method handler """ - inferred_name, inferred_param_types, inferred_return_type = cls._infer_types_from_method(method) - resolved_method_name = method_name or inferred_name - resolved_param_types = params_types or inferred_param_types - resolved_return_type = return_type or inferred_return_type - codec = codec or "json" - + name, param_types, ret_type = cls._infer_types_from_method(method) return cls( cls._create_method_descriptor( method=method, - method_name=resolved_method_name, - params_types=resolved_param_types, - return_type=resolved_return_type, + method_name=method_name or name, + params_types=params_types or param_types, + return_type=return_type or ret_type, rpc_type=RpcTypes.UNARY.value, - codec=codec, + codec=codec or "json", request_deserializer=request_deserializer, response_serializer=response_serializer, **kwargs, @@ -197,22 +187,17 @@ def client_stream( **kwargs, ) -> "RpcMethodHandler": """ - Register a client-streaming RPC method handler + Create a client-streaming method handler """ - inferred_name, inferred_param_types, inferred_return_type = cls._infer_types_from_method(method) - resolved_method_name = method_name or inferred_name - resolved_param_types = params_types or inferred_param_types - resolved_return_type = return_type or inferred_return_type - resolved_codec = codec or "json" - + name, param_types, ret_type = cls._infer_types_from_method(method) return cls( cls._create_method_descriptor( method=method, - method_name=resolved_method_name, - params_types=resolved_param_types, - return_type=resolved_return_type, + method_name=method_name or name, + params_types=params_types or param_types, + return_type=return_type or ret_type, rpc_type=RpcTypes.CLIENT_STREAM.value, - 
codec=resolved_codec, + codec=codec or "json", request_deserializer=request_deserializer, response_serializer=response_serializer, **kwargs, @@ -232,22 +217,17 @@ def server_stream( **kwargs, ) -> "RpcMethodHandler": """ - Register a server-streaming RPC method handler + Create a server-streaming method handler """ - inferred_name, inferred_param_types, inferred_return_type = cls._infer_types_from_method(method) - resolved_method_name = method_name or inferred_name - resolved_param_types = params_types or inferred_param_types - resolved_return_type = return_type or inferred_return_type - resolved_codec = codec or "json" - + name, param_types, ret_type = cls._infer_types_from_method(method) return cls( cls._create_method_descriptor( method=method, - method_name=resolved_method_name, - params_types=resolved_param_types, - return_type=resolved_return_type, + method_name=method_name or name, + params_types=params_types or param_types, + return_type=return_type or ret_type, rpc_type=RpcTypes.SERVER_STREAM.value, - codec=resolved_codec, + codec=codec or "json", request_deserializer=request_deserializer, response_serializer=response_serializer, **kwargs, @@ -267,22 +247,17 @@ def bi_stream( **kwargs, ) -> "RpcMethodHandler": """ - Register a bidirectional streaming RPC method handler + Create a bidirectional-streaming method handler """ - inferred_name, inferred_param_types, inferred_return_type = cls._infer_types_from_method(method) - resolved_method_name = method_name or inferred_name - resolved_param_types = params_types or inferred_param_types - resolved_return_type = return_type or inferred_return_type - resolved_codec = codec or "json" - + name, param_types, ret_type = cls._infer_types_from_method(method) return cls( cls._create_method_descriptor( method=method, - method_name=resolved_method_name, - params_types=resolved_param_types, - return_type=resolved_return_type, + method_name=method_name or name, + params_types=params_types or param_types, + 
return_type=return_type or ret_type, rpc_type=RpcTypes.BI_STREAM.value, - codec=resolved_codec, + codec=codec or "json", request_deserializer=request_deserializer, response_serializer=response_serializer, **kwargs, @@ -292,7 +267,7 @@ def bi_stream( class RpcServiceHandler: """ - Rpc service handler that maps method names to their corresponding RpcMethodHandler. + Rpc service handler that maps method names to their corresponding RpcMethodHandler """ __slots__ = ["_service_name", "_method_handlers"] @@ -300,6 +275,10 @@ class RpcServiceHandler: def __init__(self, service_name: str, method_handlers: list[RpcMethodHandler]): """ Initialize the RpcServiceHandler + :param service_name: the name of the service + :type service_name: str + :param method_handlers: the list of RPC method handlers + :type method_handlers: list[RpcMethodHandler] """ self._service_name = service_name self._method_handlers: dict[str, RpcMethodHandler] = {} @@ -310,10 +289,18 @@ def __init__(self, service_name: str, method_handlers: list[RpcMethodHandler]): @property def service_name(self) -> str: - """Get the service name""" + """ + Get the service name + :return: the service name + :rtype: str + """ return self._service_name @property def method_handlers(self) -> dict[str, RpcMethodHandler]: - """Get the registered RPC method handlers""" + """ + Get the registered method handlers + :return: method handlers dictionary + :rtype: dict[str, RpcMethodHandler] + """ return self._method_handlers From 5a4b4eb3f095ead09df8e77fd46d6f3602f3f175 Mon Sep 17 00:00:00 2001 From: Aditya Yadav <166515021+aditya0yadav@users.noreply.github.com> Date: Sat, 6 Sep 2025 02:12:46 +0000 Subject: [PATCH 30/40] added more visibility in the codebase --- samples/llm/chat_pb2.py | 12 +- src/dubbo/codec/json_codec/__init__.py | 4 +- .../codec/json_codec/json_codec_handler.py | 113 ++++++++++++------ 3 files changed, 80 insertions(+), 49 deletions(-) diff --git a/samples/llm/chat_pb2.py b/samples/llm/chat_pb2.py index 
90de97f..de9488e 100644 --- a/samples/llm/chat_pb2.py +++ b/samples/llm/chat_pb2.py @@ -1,15 +1,13 @@ +# -*- coding: utf-8 -*- # Generated by the protocol buffer compiler. DO NOT EDIT! # source: chat.proto # Protobuf Python Version: 4.25.3 """Generated protocol buffer code.""" -from google.protobuf import ( - descriptor as _descriptor, - descriptor_pool as _descriptor_pool, - symbol_database as _symbol_database, -) +from google.protobuf import descriptor as _descriptor +from google.protobuf import descriptor_pool as _descriptor_pool +from google.protobuf import symbol_database as _symbol_database from google.protobuf.internal import builder as _builder - # @@protoc_insertion_point(imports) _sym_db = _symbol_database.Default() @@ -22,7 +20,7 @@ _globals = globals() _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) _builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, "chat_pb2", _globals) -if not _descriptor._USE_C_DESCRIPTORS: +if _descriptor._USE_C_DESCRIPTORS == False: _globals["DESCRIPTOR"]._options = None _globals["DESCRIPTOR"]._serialized_options = b"B\tChatProtoP\001" _globals["_CHATREQUEST"]._serialized_start = 48 diff --git a/src/dubbo/codec/json_codec/__init__.py b/src/dubbo/codec/json_codec/__init__.py index 55458d5..15469f7 100644 --- a/src/dubbo/codec/json_codec/__init__.py +++ b/src/dubbo/codec/json_codec/__init__.py @@ -26,7 +26,6 @@ from .enum_handler import EnumHandler from .dataclass_handler import DataclassHandler from .json_codec_handler import JsonTransportCodec -from .json_codec import JsonTransportCodecBridge __all__ = [ "JsonCodec", @@ -41,6 +40,5 @@ "SimpleTypesHandler", "EnumHandler", "DataclassHandler", - "JsonTransportCodec", - "JsonTransportCodecBridge", + "JsonTransportCodec" ] diff --git a/src/dubbo/codec/json_codec/json_codec_handler.py b/src/dubbo/codec/json_codec/json_codec_handler.py index 15df8bf..f1e6f4e 100644 --- a/src/dubbo/codec/json_codec/json_codec_handler.py +++ b/src/dubbo/codec/json_codec/json_codec_handler.py 
@@ -21,13 +21,6 @@ StandardJsonCodec, OrJsonCodec, UJsonCodec, - DateTimeHandler, - PydanticHandler, - CollectionHandler, - DecimalHandler, - SimpleTypesHandler, - EnumHandler, - DataclassHandler, ) __all__ = ["JsonTransportCodec", "SerializationException", "DeserializationException"] @@ -35,19 +28,25 @@ class SerializationException(Exception): """Exception raised during serialization""" - - pass + + def __init__(self, message: str): + super().__init__(message) + self.message = message class DeserializationException(Exception): """Exception raised during deserialization""" - - pass - + + def __init__(self, message: str): + super().__init__(message) + self.message = message class JsonTransportCodec: """ - JSON Transport Codec + JSON Transport Codec with integrated encoder/decoder functionality. + + This class serves as both a transport codec and provides encoder/decoder + interface compatibility for services that expect separate encoder/decoder objects. """ def __init__( @@ -58,6 +57,14 @@ def __init__( strict_validation: bool = True, **kwargs, ): + """ + Initialize the JSON transport codec. + + :param parameter_types: List of parameter types for the method. + :param return_type: Return type for the method. + :param maximum_depth: Maximum serialization depth. + :param strict_validation: Whether to use strict validation. + """ self.parameter_types = parameter_types or [] self.return_type = return_type self.maximum_depth = maximum_depth @@ -70,7 +77,6 @@ def __init__( def _setup_json_codecs(self) -> List[JsonCodec]: """ Setup JSON codecs in priority order. - Following the compression pattern: try fastest first, fallback to standard. """ codecs = [] @@ -92,29 +98,24 @@ def _setup_json_codecs(self) -> List[JsonCodec]: def _setup_type_handlers(self) -> List[TypeHandler]: """ Setup type handlers for different object types. - Similar to compression - each handler is independent and focused. 
""" handlers = [] - # Add all available handlers - handlers.append(DateTimeHandler()) + from dubbo.extension import extensionLoader - pydantic_handler = PydanticHandler() - if pydantic_handler.available: - handlers.append(pydantic_handler) - - handlers.extend( - [ - DecimalHandler(), - CollectionHandler(), - SimpleTypesHandler(), - EnumHandler(), - DataclassHandler(), - ] - ) + handler_names = ["datetime", "pydantic", "decimal", "enum", "simple", "dataclass", "collection"] + for name in handler_names: + try: + plugin_class = extensionLoader.get_extension(TypeHandler, name) + if plugin_class: + plugin_instance = plugin_class() + handlers.append(plugin_instance) + except Exception as e: + print(f"Warning: Could not load type handler plugin '{name}': {e}") return handlers + # Core encoding/decoding methods def encode_parameters(self, *arguments) -> bytes: """ Encode parameters to JSON bytes. @@ -168,6 +169,48 @@ def decode_return_value(self, data: bytes) -> Any: except Exception as e: raise DeserializationException(f"Return value decoding failed: {e}") from e + # Encoder/Decoder interface compatibility methods + def encoder(self): + """ + Get the parameter encoder instance (returns self for compatibility). + + :return: Self as encoder. + :rtype: JsonTransportCodec + """ + return self + + def decoder(self): + """ + Get the return value decoder instance (returns self for compatibility). + + :return: Self as decoder. + :rtype: JsonTransportCodec + """ + return self + + def encode(self, arguments: tuple) -> bytes: + """ + Encode method for encoder interface compatibility. + + :param arguments: The method arguments to encode. + :type arguments: tuple + :return: Encoded parameter bytes. + :rtype: bytes + """ + return self.encode_parameters(*arguments) + + def decode(self, data: bytes) -> Any: + """ + Decode method for decoder interface compatibility. + + :param data: The bytes to decode. + :type data: bytes + :return: Decoded return value. 
+ :rtype: Any + """ + return self.decode_return_value(data) + + # Internal serialization methods def _serialize_object(self, obj: Any, depth: int = 0) -> Any: """ Serialize an object using the appropriate type handler. @@ -268,19 +311,15 @@ def _reconstruct_objects(self, data: Any) -> Any: # Handle special serialized objects if "__datetime__" in data: from datetime import datetime - return datetime.fromisoformat(data["__datetime__"]) elif "__date__" in data: from datetime import date - return date.fromisoformat(data["__date__"]) elif "__time__" in data: from datetime import time - return time.fromisoformat(data["__time__"]) elif "__decimal__" in data: from decimal import Decimal - return Decimal(data["__decimal__"]) elif "__set__" in data: return set(self._reconstruct_objects(item) for item in data["__set__"]) @@ -288,11 +327,9 @@ def _reconstruct_objects(self, data: Any) -> Any: return frozenset(self._reconstruct_objects(item) for item in data["__frozenset__"]) elif "__uuid__" in data: from uuid import UUID - return UUID(data["__uuid__"]) elif "__path__" in data: from pathlib import Path - return Path(data["__path__"]) elif "__pydantic_model__" in data and "__model_data__" in data: return self._reconstruct_pydantic_model(data) @@ -312,7 +349,6 @@ def _reconstruct_pydantic_model(self, data: dict) -> Any: module_name, class_name = model_path.rsplit(".", 1) import importlib - module = importlib.import_module(module_name) model_class = getattr(module, class_name) @@ -326,7 +362,6 @@ def _reconstruct_dataclass(self, data: dict) -> Any: module_name, class_name = data["__dataclass__"].rsplit(".", 1) import importlib - module = importlib.import_module(module_name) cls = getattr(module, class_name) @@ -338,8 +373,8 @@ def _reconstruct_enum(self, data: dict) -> Any: module_name, class_name = data["__enum__"].rsplit(".", 1) import importlib - module = importlib.import_module(module_name) cls = getattr(module, class_name) return cls(data["value"]) + From 
fc96853d931f6fd35368a4faed2b6ea81aa96091 Mon Sep 17 00:00:00 2001 From: Aditya Yadav <166515021+aditya0yadav@users.noreply.github.com> Date: Sat, 6 Sep 2025 02:13:03 +0000 Subject: [PATCH 31/40] remove the file json_codec --- src/dubbo/codec/json_codec/json_codec.py | 130 ----------------------- 1 file changed, 130 deletions(-) delete mode 100644 src/dubbo/codec/json_codec/json_codec.py diff --git a/src/dubbo/codec/json_codec/json_codec.py b/src/dubbo/codec/json_codec/json_codec.py deleted file mode 100644 index f881e42..0000000 --- a/src/dubbo/codec/json_codec/json_codec.py +++ /dev/null @@ -1,130 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one or more -# contributor license agreements. See the NOTICE file distributed with -# this work for additional information regarding copyright ownership. -# The ASF licenses this file to You under the Apache License, Version 2.0 -# (the "License"); you may not use this file except in compliance with -# the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from typing import Any, List, Optional, Type -from dubbo.codec.json_codec import JsonTransportCodec - -__all__ = ["JsonTransportCodecBridge", "JsonParameterEncoder", "JsonReturnDecoder"] - - -class JsonParameterEncoder: - """ - Parameter encoder wrapper for JsonTransportCodec. - """ - - def __init__(self, codec: JsonTransportCodec): - self._codec = codec - - def encode(self, arguments: tuple) -> bytes: - """ - Encode method parameters. - - :param arguments: The method arguments to encode. - :type arguments: tuple - :return: Encoded parameter bytes. 
- :rtype: bytes - """ - return self._codec.encode_parameters(*arguments) - - -class JsonReturnDecoder: - """ - Return value decoder wrapper for JsonTransportCodec. - """ - - def __init__(self, codec: JsonTransportCodec): - self._codec = codec - - def decode(self, data: bytes) -> Any: - """ - Decode method return value. - - :param data: The bytes to decode. - :type data: bytes - :return: Decoded return value. - :rtype: Any - """ - return self._codec.decode_return_value(data) - - -class JsonTransportCodecBridge: - """ - Bridge class that adapts JsonTransportCodec to work with DubboSerializationService. - - This maintains compatibility with the existing extension loader system while - using the clean new codec architecture internally. - """ - - def __init__( - self, - parameter_types: Optional[List[Type]] = None, - return_type: Optional[Type] = None, - maximum_depth: int = 100, - strict_validation: bool = True, - **kwargs, - ): - """ - Initialize the codec bridge. - - :param parameter_types: List of parameter types for the method. - :param return_type: Return type for the method. - :param maximum_depth: Maximum serialization depth. - :param strict_validation: Whether to use strict validation. - """ - self._codec = JsonTransportCodec( - parameter_types=parameter_types, - return_type=return_type, - maximum_depth=maximum_depth, - strict_validation=strict_validation, - **kwargs, - ) - - # Create encoder and decoder instances - self._encoder = JsonParameterEncoder(self._codec) - self._decoder = JsonReturnDecoder(self._codec) - - def encoder(self) -> JsonParameterEncoder: - """ - Get the parameter encoder instance. - - :return: The parameter encoder. - :rtype: JsonParameterEncoder - """ - return self._encoder - - def decoder(self) -> JsonReturnDecoder: - """ - Get the return value decoder instance. - - :return: The return value decoder. 
- :rtype: JsonReturnDecoder - """ - return self._decoder - - # Direct access methods for convenience - def encode_parameters(self, *arguments) -> bytes: - """Direct parameter encoding.""" - return self._codec.encode_parameters(*arguments) - - def decode_return_value(self, data: bytes) -> Any: - """Direct return value decoding.""" - return self._codec.decode_return_value(data) - - # Properties for access to internal codec if needed - @property - def codec(self) -> JsonTransportCodec: - """Access to the underlying codec.""" - return self._codec From 1dc61a0ca92850483cf9fe5f0a6aea3c24a7703a Mon Sep 17 00:00:00 2001 From: Aditya Yadav <166515021+aditya0yadav@users.noreply.github.com> Date: Sat, 6 Sep 2025 02:16:29 +0000 Subject: [PATCH 32/40] chat_pb2.py --- samples/llm/chat_pb2.py | 2 +- src/dubbo/codec/json_codec/__init__.py | 2 +- .../codec/json_codec/json_codec_handler.py | 27 ++++++++++++------- src/dubbo/extension/registries.py | 2 +- 4 files changed, 21 insertions(+), 12 deletions(-) diff --git a/samples/llm/chat_pb2.py b/samples/llm/chat_pb2.py index de9488e..716bf7f 100644 --- a/samples/llm/chat_pb2.py +++ b/samples/llm/chat_pb2.py @@ -29,4 +29,4 @@ _globals["_CHATREPLY"]._serialized_end = 136 _globals["_DEEPSEEKAISERVICE"]._serialized_start = 138 _globals["_DEEPSEEKAISERVICE"]._serialized_end = 259 -# @@protoc_insertion_point(module_scope) +# @@protoc_insertion_point(module_scope) \ No newline at end of file diff --git a/src/dubbo/codec/json_codec/__init__.py b/src/dubbo/codec/json_codec/__init__.py index 15469f7..f623aa4 100644 --- a/src/dubbo/codec/json_codec/__init__.py +++ b/src/dubbo/codec/json_codec/__init__.py @@ -40,5 +40,5 @@ "SimpleTypesHandler", "EnumHandler", "DataclassHandler", - "JsonTransportCodec" + "JsonTransportCodec", ] diff --git a/src/dubbo/codec/json_codec/json_codec_handler.py b/src/dubbo/codec/json_codec/json_codec_handler.py index f1e6f4e..04b7fd3 100644 --- a/src/dubbo/codec/json_codec/json_codec_handler.py +++ 
b/src/dubbo/codec/json_codec/json_codec_handler.py @@ -28,7 +28,7 @@ class SerializationException(Exception): """Exception raised during serialization""" - + def __init__(self, message: str): super().__init__(message) self.message = message @@ -36,15 +36,16 @@ def __init__(self, message: str): class DeserializationException(Exception): """Exception raised during deserialization""" - + def __init__(self, message: str): super().__init__(message) self.message = message + class JsonTransportCodec: """ JSON Transport Codec with integrated encoder/decoder functionality. - + This class serves as both a transport codec and provides encoder/decoder interface compatibility for services that expect separate encoder/decoder objects. """ @@ -59,7 +60,7 @@ def __init__( ): """ Initialize the JSON transport codec. - + :param parameter_types: List of parameter types for the method. :param return_type: Return type for the method. :param maximum_depth: Maximum serialization depth. @@ -173,7 +174,7 @@ def decode_return_value(self, data: bytes) -> Any: def encoder(self): """ Get the parameter encoder instance (returns self for compatibility). - + :return: Self as encoder. :rtype: JsonTransportCodec """ @@ -182,7 +183,7 @@ def encoder(self): def decoder(self): """ Get the return value decoder instance (returns self for compatibility). - + :return: Self as decoder. :rtype: JsonTransportCodec """ @@ -191,7 +192,7 @@ def decoder(self): def encode(self, arguments: tuple) -> bytes: """ Encode method for encoder interface compatibility. - + :param arguments: The method arguments to encode. :type arguments: tuple :return: Encoded parameter bytes. @@ -202,7 +203,7 @@ def encode(self, arguments: tuple) -> bytes: def decode(self, data: bytes) -> Any: """ Decode method for decoder interface compatibility. - + :param data: The bytes to decode. :type data: bytes :return: Decoded return value. 
@@ -311,15 +312,19 @@ def _reconstruct_objects(self, data: Any) -> Any: # Handle special serialized objects if "__datetime__" in data: from datetime import datetime + return datetime.fromisoformat(data["__datetime__"]) elif "__date__" in data: from datetime import date + return date.fromisoformat(data["__date__"]) elif "__time__" in data: from datetime import time + return time.fromisoformat(data["__time__"]) elif "__decimal__" in data: from decimal import Decimal + return Decimal(data["__decimal__"]) elif "__set__" in data: return set(self._reconstruct_objects(item) for item in data["__set__"]) @@ -327,9 +332,11 @@ def _reconstruct_objects(self, data: Any) -> Any: return frozenset(self._reconstruct_objects(item) for item in data["__frozenset__"]) elif "__uuid__" in data: from uuid import UUID + return UUID(data["__uuid__"]) elif "__path__" in data: from pathlib import Path + return Path(data["__path__"]) elif "__pydantic_model__" in data and "__model_data__" in data: return self._reconstruct_pydantic_model(data) @@ -349,6 +356,7 @@ def _reconstruct_pydantic_model(self, data: dict) -> Any: module_name, class_name = model_path.rsplit(".", 1) import importlib + module = importlib.import_module(module_name) model_class = getattr(module, class_name) @@ -362,6 +370,7 @@ def _reconstruct_dataclass(self, data: dict) -> Any: module_name, class_name = data["__dataclass__"].rsplit(".", 1) import importlib + module = importlib.import_module(module_name) cls = getattr(module, class_name) @@ -373,8 +382,8 @@ def _reconstruct_enum(self, data: dict) -> Any: module_name, class_name = data["__enum__"].rsplit(".", 1) import importlib + module = importlib.import_module(module_name) cls = getattr(module, class_name) return cls(data["value"]) - diff --git a/src/dubbo/extension/registries.py b/src/dubbo/extension/registries.py index 6d3edb1..7a2871b 100644 --- a/src/dubbo/extension/registries.py +++ b/src/dubbo/extension/registries.py @@ -136,7 +136,7 @@ class ExtendedRegistry: 
"collection": "dubbo.codec.json_codec.CollectionHandler", "enum": "dubbo.codec.json_codec.EnumHandler", "dataclass": "dubbo.codec.json_codec.DataclassHandler", - "simple": "dubbo.codec.json_codec.SimpleTypeHandler", + "simple": "dubbo.codec.json_codec.SimpleTypesHandler", "pydantic": "dubbo.codec.json_codec.PydanticHandler", }, ) From bfee72d590e8e00c4f848d6294560490310a71cf Mon Sep 17 00:00:00 2001 From: Aditya Yadav <166515021+aditya0yadav@users.noreply.github.com> Date: Sat, 6 Sep 2025 05:14:01 +0000 Subject: [PATCH 33/40] changing the deprecated type to inbuilt one --- src/dubbo/classes.py | 1 - src/dubbo/client.py | 26 ++++++++--------- src/dubbo/codec/__init__.py | 2 +- src/dubbo/codec/_interface.py | 16 +++++----- src/dubbo/codec/dubbo_codec.py | 29 ++++++++++--------- src/dubbo/codec/json_codec/__init__.py | 14 ++++----- src/dubbo/codec/json_codec/_interfaces.py | 6 ++-- .../codec/json_codec/collections_handler.py | 8 ++--- .../codec/json_codec/dataclass_handler.py | 8 ++--- .../codec/json_codec/datetime_handler.py | 10 +++---- src/dubbo/codec/json_codec/decimal_handler.py | 8 ++--- src/dubbo/codec/json_codec/enum_handler.py | 6 ++-- .../codec/json_codec/json_codec_handler.py | 24 +++++++-------- .../codec/json_codec/pydantic_handler.py | 12 ++++---- .../codec/json_codec/simple_types_handler.py | 4 +-- src/dubbo/codec/protobuf_codec/__init__.py | 6 ++-- .../protobuf_codec/betterproto_handler.py | 2 +- .../codec/protobuf_codec/primitive_handler.py | 3 +- .../codec/protobuf_codec/protobuf_codec.py | 20 ++++++------- .../codec/protobuf_codec/protoc_handler.py | 3 +- src/dubbo/extension/registries.py | 6 ++-- tests/json/json_test.py | 21 ++++---------- 22 files changed, 108 insertions(+), 127 deletions(-) diff --git a/src/dubbo/classes.py b/src/dubbo/classes.py index 754b348..a07f56f 100644 --- a/src/dubbo/classes.py +++ b/src/dubbo/classes.py @@ -16,7 +16,6 @@ import abc import threading -from abc import ABC, abstractmethod from typing import Any, Callable, 
Optional, Union from dubbo.types import DeserializingFunction, RpcType, RpcTypes, SerializingFunction diff --git a/src/dubbo/client.py b/src/dubbo/client.py index f947abf..390a28c 100644 --- a/src/dubbo/client.py +++ b/src/dubbo/client.py @@ -3,8 +3,6 @@ # contributor license agreements. See the NOTICE file distributed with # this work for additional information regarding copyright ownership. # The ASF licenses this file to You under the Apache License, Version 2.0 -# (the "License"); you may not use this file except in compliance with -# the License. You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # @@ -15,10 +13,11 @@ # limitations under the License. import threading -from typing import Optional, List, Type +from typing import Optional from dubbo.bootstrap import Dubbo from dubbo.classes import MethodDescriptor +from dubbo.codec import DubboSerializationService from dubbo.configs import ReferenceConfig from dubbo.constants import common_constants from dubbo.extension import extensionLoader @@ -32,7 +31,6 @@ SerializingFunction, ) from dubbo.url import URL -from dubbo.codec import DubboSerializationService __all__ = ["Client"] @@ -87,8 +85,8 @@ def _create_rpc_callable( self, rpc_type: str, method_name: str, - params_types: List[Type], - return_type: Type, + params_types: list[type], + return_type: type, codec: Optional[str] = None, request_serializer: Optional[SerializingFunction] = None, response_deserializer: Optional[DeserializingFunction] = None, @@ -130,8 +128,8 @@ def _create_rpc_callable( def unary( self, method_name: str, - params_types: List[Type], - return_type: Type, + params_types: list[type], + return_type: type, codec: Optional[str] = None, request_serializer: Optional[SerializingFunction] = None, response_deserializer: Optional[DeserializingFunction] = None, @@ -152,8 +150,8 @@ def unary( def client_stream( self, method_name: str, - params_types: List[Type], - return_type: Type, + params_types: list[type], + 
return_type: type, codec: Optional[str] = None, request_serializer: Optional[SerializingFunction] = None, response_deserializer: Optional[DeserializingFunction] = None, @@ -174,8 +172,8 @@ def client_stream( def server_stream( self, method_name: str, - params_types: List[Type], - return_type: Type, + params_types: list[type], + return_type: type, codec: Optional[str] = None, request_serializer: Optional[SerializingFunction] = None, response_deserializer: Optional[DeserializingFunction] = None, @@ -196,8 +194,8 @@ def server_stream( def bi_stream( self, method_name: str, - params_types: List[Type], - return_type: Type, + params_types: list[type], + return_type: type, codec: Optional[str] = None, request_serializer: Optional[SerializingFunction] = None, response_deserializer: Optional[DeserializingFunction] = None, diff --git a/src/dubbo/codec/__init__.py b/src/dubbo/codec/__init__.py index 88fbaa8..fb17b9d 100644 --- a/src/dubbo/codec/__init__.py +++ b/src/dubbo/codec/__init__.py @@ -14,7 +14,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-from .dubbo_codec import DubboSerializationService from ._interface import Codec +from .dubbo_codec import DubboSerializationService __all__ = ["DubboSerializationService", "Codec"] diff --git a/src/dubbo/codec/_interface.py b/src/dubbo/codec/_interface.py index ad28448..4350b82 100644 --- a/src/dubbo/codec/_interface.py +++ b/src/dubbo/codec/_interface.py @@ -17,7 +17,7 @@ import abc import logging from dataclasses import dataclass -from typing import Any, Callable, Optional, List, Tuple, Type +from typing import Any, Callable, Optional __all__ = [ "ParameterDescriptor", @@ -47,7 +47,7 @@ class MethodDescriptor: function: Callable name: str - parameters: List[ParameterDescriptor] + parameters: list[ParameterDescriptor] return_parameter: ParameterDescriptor documentation: Optional[str] = None @@ -92,11 +92,11 @@ class SerializationEncoder(abc.ABC): """ @abc.abstractmethod - def encode(self, arguments: Tuple[Any, ...]) -> bytes: + def encode(self, arguments: tuple[Any, ...]) -> bytes: """ Encode arguments to bytes. :param arguments: The arguments to encode. - :type arguments: Tuple[Any, ...] + :type arguments: tuple[Any, ...] :return: The encoded bytes. :rtype: bytes """ @@ -125,11 +125,11 @@ class Codec(abc.ABC): Base codec interface for encoding and decoding data. """ - def __init__(self, model_type: Optional[Type[Any]] = None, **kwargs): + def __init__(self, model_type: Optional[type[Any]] = None, **kwargs): """ Initialize a codec :param model_type: Optional model type for structured encoding/decoding - :type model_type: Optional[Type[Any]] + :type model_type: Optional[type[Any]] :param kwargs: Additional codec configuration """ self.model_type = model_type @@ -138,7 +138,7 @@ def __init__(self, model_type: Optional[Type[Any]] = None, **kwargs): def encode(self, data: Any) -> bytes: """ Encode data into bytes - :param data: The data to encode + :param data: The data to encode. 
:type data: Any :return: Encoded byte representation :rtype: bytes @@ -149,7 +149,7 @@ def encode(self, data: Any) -> bytes: def decode(self, data: bytes) -> Any: """ Decode bytes into object - :param data: The bytes to decode + :param data: The bytes to decode. :type data: bytes :return: Decoded object :rtype: Any diff --git a/src/dubbo/codec/dubbo_codec.py b/src/dubbo/codec/dubbo_codec.py index c0c1e1c..ae5f51f 100644 --- a/src/dubbo/codec/dubbo_codec.py +++ b/src/dubbo/codec/dubbo_codec.py @@ -16,14 +16,15 @@ import inspect import logging -from typing import Any, Callable, Optional, List, Tuple +from typing import Any, Callable, Optional + from ._interface import ( - ParameterDescriptor, + Codec, MethodDescriptor, + ParameterDescriptor, SerializationDecoder, SerializationEncoder, TransportCodec, - Codec, ) __all__ = [ @@ -39,7 +40,7 @@ class DubboSerializationService: @staticmethod def create_transport_codec( transport_type: str = "json", - parameter_types: Optional[List[type]] = None, + parameter_types: Optional[list[type]] = None, return_type: Optional[type] = None, **codec_options, ) -> TransportCodec: @@ -47,7 +48,7 @@ def create_transport_codec( Create transport codec :param transport_type: The transport type (e.g., 'json', 'protobuf') - :param parameter_types: List of parameter types + :param parameter_types: list of parameter types :param return_type: Return value type :param codec_options: Additional codec options :return: Transport codec instance @@ -69,18 +70,18 @@ def create_transport_codec( @staticmethod def create_encoder_decoder_pair( transport_type: str, - parameter_types: Optional[List[type]] = None, + parameter_types: Optional[list[type]] = None, return_type: Optional[type] = None, **codec_options, - ) -> Tuple[SerializationEncoder, SerializationDecoder]: + ) -> tuple[SerializationEncoder, SerializationDecoder]: """ Create encoder and decoder instances :param transport_type: The transport type - :param parameter_types: List of parameter types + 
:param parameter_types: list of parameter types :param return_type: Return value type :param codec_options: Additional codec options - :return: Tuple of (encoder, decoder) + :return: tuple of (encoder, decoder) :raises ValueError: If codec returns None encoder/decoder :raises Exception: If creation fails """ @@ -107,18 +108,18 @@ def create_encoder_decoder_pair( @staticmethod def create_serialization_functions( transport_type: str, - parameter_types: Optional[List[type]] = None, + parameter_types: Optional[list[type]] = None, return_type: Optional[type] = None, **codec_options, - ) -> Tuple[Callable[..., bytes], Callable[[bytes], Any]]: + ) -> tuple[Callable[..., bytes], Callable[[bytes], Any]]: """ Create serializer and deserializer functions :param transport_type: The transport type - :param parameter_types: List of parameter types + :param parameter_types: list of parameter types :param return_type: Return value type :param codec_options: Additional codec options - :return: Tuple of (serializer_function, deserializer_function) + :return: tuple of (serializer_function, deserializer_function) :raises Exception: If creation fails """ try: @@ -157,7 +158,7 @@ def deserialize_method_return(data: bytes) -> Any: def create_method_descriptor( func: Callable, method_name: Optional[str] = None, - parameter_types: Optional[List[type]] = None, + parameter_types: Optional[list[type]] = None, return_type: Optional[type] = None, interface: Optional[Callable[..., Any]] = None, ) -> MethodDescriptor: diff --git a/src/dubbo/codec/json_codec/__init__.py b/src/dubbo/codec/json_codec/__init__.py index f623aa4..05b8599 100644 --- a/src/dubbo/codec/json_codec/__init__.py +++ b/src/dubbo/codec/json_codec/__init__.py @@ -15,17 +15,17 @@ # limitations under the License. 
from ._interfaces import JsonCodec, TypeHandler -from .standard_json import StandardJsonCodec -from .orjson_codec import OrJsonCodec -from .ujson_codec import UJsonCodec -from .datetime_handler import DateTimeHandler -from .pydantic_handler import PydanticHandler from .collections_handler import CollectionHandler +from .dataclass_handler import DataclassHandler +from .datetime_handler import DateTimeHandler from .decimal_handler import DecimalHandler -from .simple_types_handler import SimpleTypesHandler from .enum_handler import EnumHandler -from .dataclass_handler import DataclassHandler from .json_codec_handler import JsonTransportCodec +from .orjson_codec import OrJsonCodec +from .pydantic_handler import PydanticHandler +from .simple_types_handler import SimpleTypesHandler +from .standard_json import StandardJsonCodec +from .ujson_codec import UJsonCodec __all__ = [ "JsonCodec", diff --git a/src/dubbo/codec/json_codec/_interfaces.py b/src/dubbo/codec/json_codec/_interfaces.py index 4a3a36e..af2b1a7 100644 --- a/src/dubbo/codec/json_codec/_interfaces.py +++ b/src/dubbo/codec/json_codec/_interfaces.py @@ -15,7 +15,7 @@ # limitations under the License. import abc -from typing import Any, Dict +from typing import Any __all__ = ["JsonCodec", "TypeHandler"] @@ -86,13 +86,13 @@ def can_serialize_type(self, obj: Any, obj_type: type) -> bool: raise NotImplementedError() @abc.abstractmethod - def serialize_to_dict(self, obj: Any) -> Dict[str, Any]: + def serialize_to_dict(self, obj: Any) -> dict[str, Any]: """ Serialize the object into a dictionary representation. :param obj: The object to serialize. :type obj: Any :return: The dictionary representation of the object. 
- :rtype: Dict[str, Any] + :rtype: dict[str, Any] """ raise NotImplementedError() diff --git a/src/dubbo/codec/json_codec/collections_handler.py b/src/dubbo/codec/json_codec/collections_handler.py index ff18764..df710df 100644 --- a/src/dubbo/codec/json_codec/collections_handler.py +++ b/src/dubbo/codec/json_codec/collections_handler.py @@ -14,7 +14,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Dict, Union +from typing import Any, Union from dubbo.codec.json_codec import TypeHandler @@ -42,14 +42,14 @@ def can_serialize_type(self, obj: Any, obj_type: type) -> bool: """ return obj_type in (set, frozenset) - def serialize_to_dict(self, obj: Union[set, frozenset]) -> Dict[str, list]: + def serialize_to_dict(self, obj: Union[set, frozenset]) -> dict[str, list]: """ Serialize set/frozenset to dictionary representation. :param obj: The collection to serialize. :type obj: Union[set, frozenset] - :return: Dictionary representation with type marker. - :rtype: Dict[str, list] + :return: dictionary representation with type marker. + :rtype: dict[str, list] """ if isinstance(obj, frozenset): return {"__frozenset__": list(obj)} diff --git a/src/dubbo/codec/json_codec/dataclass_handler.py b/src/dubbo/codec/json_codec/dataclass_handler.py index 23fd226..95b4ac0 100644 --- a/src/dubbo/codec/json_codec/dataclass_handler.py +++ b/src/dubbo/codec/json_codec/dataclass_handler.py @@ -15,7 +15,7 @@ # limitations under the License. from dataclasses import asdict, is_dataclass -from typing import Any, Dict +from typing import Any from dubbo.codec.json_codec import TypeHandler @@ -43,13 +43,13 @@ def can_serialize_type(self, obj: Any, obj_type: type) -> bool: """ return is_dataclass(obj) - def serialize_to_dict(self, obj: Any) -> Dict[str, Any]: + def serialize_to_dict(self, obj: Any) -> dict[str, Any]: """ Serialize dataclass to dictionary representation. :param obj: The dataclass to serialize. 
:type obj: Any - :return: Dictionary with class path and field data. - :rtype: Dict[str, Any] + :return: dictionary with class path and field data. + :rtype: dict[str, Any] """ return {"__dataclass__": f"{obj.__class__.__module__}.{obj.__class__.__qualname__}", "fields": asdict(obj)} diff --git a/src/dubbo/codec/json_codec/datetime_handler.py b/src/dubbo/codec/json_codec/datetime_handler.py index 51eb416..cf560a6 100644 --- a/src/dubbo/codec/json_codec/datetime_handler.py +++ b/src/dubbo/codec/json_codec/datetime_handler.py @@ -14,8 +14,8 @@ # See the License for the specific language governing permissions and # limitations under the License. -from datetime import datetime, date, time -from typing import Any, Dict, Union +from datetime import date, datetime, time +from typing import Any, Union from dubbo.codec.json_codec import TypeHandler @@ -42,14 +42,14 @@ def can_serialize_type(self, obj: Any, obj_type: type) -> bool: """ return isinstance(obj, (datetime, date, time)) - def serialize_to_dict(self, obj: Union[datetime, date, time]) -> Dict[str, Any]: + def serialize_to_dict(self, obj: Union[datetime, date, time]) -> dict[str, Any]: """ Serialize datetime objects to dictionary representation. :param obj: The datetime object to serialize. :type obj: Union[datetime, date, time] - :return: Dictionary representation with type markers. - :rtype: Dict[str, Any] + :return: dictionary representation with type markers. + :rtype: dict[str, Any] """ if isinstance(obj, datetime): return {"__datetime__": obj.isoformat(), "__timezone__": str(obj.tzinfo) if obj.tzinfo else None} diff --git a/src/dubbo/codec/json_codec/decimal_handler.py b/src/dubbo/codec/json_codec/decimal_handler.py index 3ea87b1..0cadcbf 100644 --- a/src/dubbo/codec/json_codec/decimal_handler.py +++ b/src/dubbo/codec/json_codec/decimal_handler.py @@ -15,7 +15,7 @@ # limitations under the License. 
from decimal import Decimal -from typing import Any, Dict +from typing import Any from dubbo.codec.json_codec import TypeHandler @@ -43,13 +43,13 @@ def can_serialize_type(self, obj: Any, obj_type: type) -> bool: """ return obj_type is Decimal - def serialize_to_dict(self, obj: Decimal) -> Dict[str, str]: + def serialize_to_dict(self, obj: Decimal) -> dict[str, str]: """ Serialize Decimal to dictionary representation. :param obj: The Decimal to serialize. :type obj: Decimal - :return: Dictionary representation with string value. - :rtype: Dict[str, str] + :return: dictionary representation with string value. + :rtype: dict[str, str] """ return {"__decimal__": str(obj)} diff --git a/src/dubbo/codec/json_codec/enum_handler.py b/src/dubbo/codec/json_codec/enum_handler.py index 2cd4a6b..725ce90 100644 --- a/src/dubbo/codec/json_codec/enum_handler.py +++ b/src/dubbo/codec/json_codec/enum_handler.py @@ -15,7 +15,7 @@ # limitations under the License. from enum import Enum -from typing import Any, Dict +from typing import Any from dubbo.codec.json_codec import TypeHandler @@ -43,13 +43,13 @@ def can_serialize_type(self, obj: Any, obj_type: type) -> bool: """ return isinstance(obj, Enum) - def serialize_to_dict(self, obj: Enum) -> Dict[str, Any]: + def serialize_to_dict(self, obj: Enum) -> dict[str, Any]: """ Serialize Enum to dictionary representation. :param obj: The Enum to serialize. :type obj: Enum :return: Dictionary with enum class path and value. - :rtype: Dict[str, Any] + :rtype: dict[str, Any] """ return {"__enum__": f"{obj.__class__.__module__}.{obj.__class__.__qualname__}", "value": obj.value} diff --git a/src/dubbo/codec/json_codec/json_codec_handler.py b/src/dubbo/codec/json_codec/json_codec_handler.py index 04b7fd3..918cdb6 100644 --- a/src/dubbo/codec/json_codec/json_codec_handler.py +++ b/src/dubbo/codec/json_codec/json_codec_handler.py @@ -14,14 +14,12 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-from typing import Any, Type, List, Optional -from dubbo.codec.json_codec import ( - JsonCodec, - TypeHandler, - StandardJsonCodec, - OrJsonCodec, - UJsonCodec, -) +from typing import Any, Optional + +from .orjson_codec import OrJsonCodec +from .ujson_codec import UJsonCodec +from .standard_json import StandardJsonCodec +from ._interfaces import JsonCodec, TypeHandler __all__ = ["JsonTransportCodec", "SerializationException", "DeserializationException"] @@ -52,8 +50,8 @@ class JsonTransportCodec: def __init__( self, - parameter_types: Optional[List[Type]] = None, - return_type: Optional[Type] = None, + parameter_types: Optional[list[type]] = None, + return_type: Optional[type] = None, maximum_depth: int = 100, strict_validation: bool = True, **kwargs, @@ -61,7 +59,7 @@ def __init__( """ Initialize the JSON transport codec. - :param parameter_types: List of parameter types for the method. + :param parameter_types: list of parameter types for the method. :param return_type: Return type for the method. :param maximum_depth: Maximum serialization depth. :param strict_validation: Whether to use strict validation. @@ -75,7 +73,7 @@ def __init__( self._json_codecs = self._setup_json_codecs() self._type_handlers = self._setup_type_handlers() - def _setup_json_codecs(self) -> List[JsonCodec]: + def _setup_json_codecs(self) -> list[JsonCodec]: """ Setup JSON codecs in priority order. """ @@ -96,7 +94,7 @@ def _setup_json_codecs(self) -> List[JsonCodec]: return codecs - def _setup_type_handlers(self) -> List[TypeHandler]: + def _setup_type_handlers(self) -> list[TypeHandler]: """ Setup type handlers for different object types. """ diff --git a/src/dubbo/codec/json_codec/pydantic_handler.py b/src/dubbo/codec/json_codec/pydantic_handler.py index 90a661c..92410dd 100644 --- a/src/dubbo/codec/json_codec/pydantic_handler.py +++ b/src/dubbo/codec/json_codec/pydantic_handler.py @@ -3,8 +3,6 @@ # contributor license agreements. 
See the NOTICE file distributed with # this work for additional information regarding copyright ownership. # The ASF licenses this file to You under the Apache License, Version 2.0 -# (the "License"); you may not use this file except in compliance with -# the License. You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # @@ -14,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Dict, Optional, List, Type +from typing import Any, Optional from dubbo.codec.json_codec import TypeHandler @@ -52,14 +50,14 @@ def can_serialize_type(self, obj: Any, obj_type: type) -> bool: """ return self.available and isinstance(obj, self.BaseModel) - def serialize_to_dict(self, obj: Any) -> Dict[str, Any]: + def serialize_to_dict(self, obj: Any) -> dict[str, Any]: """ Serialize Pydantic model to dictionary representation. :param obj: The Pydantic model to serialize. :type obj: BaseModel :return: Dictionary representation with model metadata. - :rtype: Dict[str, Any] + :rtype: dict[str, Any] """ if not self.available: raise ImportError("Pydantic not available") @@ -75,12 +73,12 @@ def serialize_to_dict(self, obj: Any) -> Dict[str, Any]: "__model_data__": model_data, } - def create_parameter_model(self, parameter_types: Optional[List[Type]] = None): + def create_parameter_model(self, parameter_types: Optional[list[type]] = None): """ Create a Pydantic model for parameter wrapping. :param parameter_types: List of parameter types to wrap. - :type parameter_types: Optional[List[Type]] + :type parameter_types: Optional[list[type]] :return: Dynamically created Pydantic model or None. 
""" if not self.available or parameter_types is None: diff --git a/src/dubbo/codec/json_codec/simple_types_handler.py b/src/dubbo/codec/json_codec/simple_types_handler.py index 05dfb15..5dc77a1 100644 --- a/src/dubbo/codec/json_codec/simple_types_handler.py +++ b/src/dubbo/codec/json_codec/simple_types_handler.py @@ -15,7 +15,7 @@ # limitations under the License. from pathlib import Path -from typing import Any, Dict, Union +from typing import Any, Union from uuid import UUID from dubbo.codec.json_codec import TypeHandler @@ -43,7 +43,7 @@ def can_serialize_type(self, obj: Any, obj_type: type) -> bool: """ return obj_type in (UUID, Path) or isinstance(obj, Path) - def serialize_to_dict(self, obj: Union[UUID, Path]) -> Dict[str, str]: + def serialize_to_dict(self, obj: Union[UUID, Path]) -> dict[str, str]: """ Serialize UUID or Path to dictionary representation. diff --git a/src/dubbo/codec/protobuf_codec/__init__.py b/src/dubbo/codec/protobuf_codec/__init__.py index dd17f37..4bbf7b4 100644 --- a/src/dubbo/codec/protobuf_codec/__init__.py +++ b/src/dubbo/codec/protobuf_codec/__init__.py @@ -14,11 +14,11 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-from .protobuf_codec import ProtobufTransportCodec, ProtobufTransportDecoder, ProtobufTransportEncoder -from ._interface import ProtobufEncoder, ProtobufDecoder +from ._interface import ProtobufDecoder, ProtobufEncoder from .betterproto_handler import BetterprotoMessageHandler -from .protoc_handler import GoogleProtobufMessageHandler from .primitive_handler import PrimitiveHandler +from .protobuf_codec import ProtobufTransportCodec, ProtobufTransportDecoder, ProtobufTransportEncoder +from .protoc_handler import GoogleProtobufMessageHandler __all__ = [ "ProtobufTransportCodec", diff --git a/src/dubbo/codec/protobuf_codec/betterproto_handler.py b/src/dubbo/codec/protobuf_codec/betterproto_handler.py index 97b0487..605cad8 100644 --- a/src/dubbo/codec/protobuf_codec/betterproto_handler.py +++ b/src/dubbo/codec/protobuf_codec/betterproto_handler.py @@ -17,7 +17,7 @@ import json from typing import Any, Optional -from ._interface import ProtobufEncoder, ProtobufDecoder, SerializationException, DeserializationException +from ._interface import DeserializationException, ProtobufDecoder, ProtobufEncoder, SerializationException try: import betterproto diff --git a/src/dubbo/codec/protobuf_codec/primitive_handler.py b/src/dubbo/codec/protobuf_codec/primitive_handler.py index 79142df..8a68091 100644 --- a/src/dubbo/codec/protobuf_codec/primitive_handler.py +++ b/src/dubbo/codec/protobuf_codec/primitive_handler.py @@ -17,8 +17,7 @@ import json from typing import Any, Optional -from ._interface import ProtobufEncoder, ProtobufDecoder, SerializationException, DeserializationException - +from ._interface import DeserializationException, ProtobufDecoder, ProtobufEncoder, SerializationException __all__ = ["PrimitiveHandler"] diff --git a/src/dubbo/codec/protobuf_codec/protobuf_codec.py b/src/dubbo/codec/protobuf_codec/protobuf_codec.py index 871f53c..c175bb3 100644 --- a/src/dubbo/codec/protobuf_codec/protobuf_codec.py +++ b/src/dubbo/codec/protobuf_codec/protobuf_codec.py @@ -3,8 
+3,6 @@ # contributor license agreements. See the NOTICE file distributed with # this work for additional information regarding copyright ownership. # The ASF licenses this file to You under the Apache License, Version 2.0 -# (the "License"); you may not use this file except in compliance with -# the License. You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # @@ -14,9 +12,9 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Optional, List +from typing import Any, Optional -from ._interface import ProtobufEncoder, ProtobufDecoder, SerializationException, DeserializationException +from ._interface import DeserializationException, ProtobufDecoder, ProtobufEncoder, SerializationException __all__ = ["ProtobufTransportCodec"] @@ -24,7 +22,7 @@ class ProtobufTransportEncoder: """Protobuf encoder for parameters""" - def __init__(self, handlers: List[ProtobufEncoder], parameter_types: Optional[List[type]] = None): + def __init__(self, handlers: list[ProtobufEncoder], parameter_types: Optional[list[type]] = None): self._handlers = handlers self._parameter_types = parameter_types or [] @@ -59,7 +57,7 @@ def _encode_single(self, argument: Any) -> bytes: class ProtobufTransportDecoder: """Protobuf decoder for return values""" - def __init__(self, handlers: List[ProtobufDecoder], return_type: Optional[type] = None): + def __init__(self, handlers: list[ProtobufDecoder], return_type: Optional[type] = None): self._handlers = handlers self._return_type = return_type @@ -87,7 +85,7 @@ class ProtobufTransportCodec: def __init__( self, - parameter_types: Optional[List[type]] = None, + parameter_types: Optional[list[type]] = None, return_type: Optional[type] = None, **kwargs, ): @@ -95,8 +93,8 @@ def __init__( self._return_type = return_type # Initialize handlers - self._encoders: List[ProtobufEncoder] = [] - self._decoders: List[ProtobufDecoder] = [] + self._encoders: 
list[ProtobufEncoder] = [] + self._decoders: list[ProtobufDecoder] = [] # Load default handlers self._load_default_handlers() @@ -162,10 +160,10 @@ def register_decoder(self, decoder: ProtobufDecoder): """Register a custom decoder""" self._decoders.append(decoder) - def get_encoders(self) -> List[ProtobufEncoder]: + def get_encoders(self) -> list[ProtobufEncoder]: """Get all registered encoders""" return self._encoders.copy() - def get_decoders(self) -> List[ProtobufDecoder]: + def get_decoders(self) -> list[ProtobufDecoder]: """Get all registered decoders""" return self._decoders.copy() diff --git a/src/dubbo/codec/protobuf_codec/protoc_handler.py b/src/dubbo/codec/protobuf_codec/protoc_handler.py index 67629c7..5dda34e 100644 --- a/src/dubbo/codec/protobuf_codec/protoc_handler.py +++ b/src/dubbo/codec/protobuf_codec/protoc_handler.py @@ -15,7 +15,8 @@ # limitations under the License. from typing import Any, Optional -from ._interface import ProtobufEncoder, ProtobufDecoder, SerializationException, DeserializationException + +from ._interface import DeserializationException, ProtobufDecoder, ProtobufEncoder, SerializationException try: from google.protobuf.message import Message as GoogleMessage diff --git a/src/dubbo/extension/registries.py b/src/dubbo/extension/registries.py index 7a2871b..cc177df 100644 --- a/src/dubbo/extension/registries.py +++ b/src/dubbo/extension/registries.py @@ -18,13 +18,13 @@ from typing import Any from dubbo.cluster import LoadBalance +from dubbo.codec import Codec +from dubbo.codec.json_codec import TypeHandler +from dubbo.codec.protobuf_codec import ProtobufEncoder from dubbo.compression import Compressor, Decompressor from dubbo.protocol import Protocol from dubbo.registry import RegistryFactory from dubbo.remoting import Transporter -from dubbo.codec import Codec -from dubbo.codec.json_codec import TypeHandler -from dubbo.codec.protobuf_codec import ProtobufEncoder @dataclass diff --git a/tests/json/json_test.py 
b/tests/json/json_test.py index 5bc7b9d..e83316e 100644 --- a/tests/json/json_test.py +++ b/tests/json/json_test.py @@ -8,13 +8,13 @@ # # http://www.apache.org/licenses/LICENSE-2.0 # -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from dataclasses import dataclass +from dataclasses import dataclass, asdict from datetime import date, datetime, time from decimal import Decimal from enum import Enum @@ -22,7 +22,6 @@ from uuid import UUID import pytest -from pydantic import BaseModel from dubbo.codec.json_codec import JsonTransportCodec @@ -39,11 +38,6 @@ class Color(Enum): GREEN = "green" -class SamplePydanticModel(BaseModel): - name: str - value: int - - # List of test cases: (input_value, expected_type_after_decoding) test_cases = [ ("simple string", str), @@ -59,7 +53,7 @@ class SamplePydanticModel(BaseModel): (UUID("12345678-1234-5678-1234-567812345678"), UUID), (Path("/tmp/file.txt"), Path), (Color.RED, Color), - (SamplePydanticModel(name="test", value=42), SamplePydanticModel), + (SampleDataClass(field1=1, field2="abc"), SampleDataClass), ] @@ -74,13 +68,8 @@ def test_json_codec_roundtrip(value, expected_type): # Decode decoded = codec.decode_return_value(encoded) - # For pydantic models, compare dict representation - if hasattr(value, "dict") and callable(value.dict): - assert decoded.dict() == value.dict() # For dataclass, compare asdict - elif hasattr(value, "__dataclass_fields__"): - from dataclasses import asdict - + if hasattr(value, "__dataclass_fields__"): assert asdict(decoded) == asdict(value) # For sets/frozensets, compare as sets elif 
isinstance(value, (set, frozenset)): From 96b02ed2ec4d6c5eaddf17fa4f3cb11619dd55ce Mon Sep 17 00:00:00 2001 From: Aditya Yadav <166515021+aditya0yadav@users.noreply.github.com> Date: Sat, 6 Sep 2025 05:15:58 +0000 Subject: [PATCH 34/40] test being change --- tests/protobuf/protobuf_test.py | 16 +--------------- 1 file changed, 1 insertion(+), 15 deletions(-) diff --git a/tests/protobuf/protobuf_test.py b/tests/protobuf/protobuf_test.py index 6c331c5..0947a11 100644 --- a/tests/protobuf/protobuf_test.py +++ b/tests/protobuf/protobuf_test.py @@ -37,21 +37,7 @@ def test_protobuf_roundtrip_message(): decoded = codec.decode_return_value(reply_bytes) assert isinstance(decoded, GreeterReply) assert decoded.message == "Hello Alice" - - -def test_protobuf_from_dict(): - codec = ProtobufTransportCodec(parameter_type=GreeterRequest, return_type=GreeterReply) - - # Dict instead of message instance - encoded = codec.encode_parameter({"name": "Bob"}) - assert isinstance(encoded, bytes) - - # To decode back to the parameter type, we need a decoder configured for GreeterRequest - param_decoder = ProtobufTransportDecoder(target_type=GreeterRequest) - req = param_decoder.decode(encoded) - assert isinstance(req, GreeterRequest) - assert req.name == "Bob" - + def test_protobuf_primitive_fallback(): codec = ProtobufTransportCodec(parameter_type=str, return_type=str) From 91a4ba87acd2c5b1adca38fff6a759f2b5c80488 Mon Sep 17 00:00:00 2001 From: Aditya Yadav <166515021+aditya0yadav@users.noreply.github.com> Date: Sat, 6 Sep 2025 05:38:26 +0000 Subject: [PATCH 35/40] added more pytest for protobuf --- tests/protobuf/generated/greet_pb2.py | 27 ++++++++++ tests/protobuf/greet.proto | 16 ------ tests/protobuf/protobuf_test.py | 74 +++++++++++++++++++-------- 3 files changed, 79 insertions(+), 38 deletions(-) create mode 100644 tests/protobuf/generated/greet_pb2.py diff --git a/tests/protobuf/generated/greet_pb2.py b/tests/protobuf/generated/greet_pb2.py new file mode 100644 index 
0000000..4231dcb --- /dev/null +++ b/tests/protobuf/generated/greet_pb2.py @@ -0,0 +1,27 @@ +# -*- coding: utf-8 -*- +# Generated by the protocol buffer compiler. DO NOT EDIT! +# source: greet.proto +"""Generated protocol buffer code.""" + +from google.protobuf.internal import builder as _builder +from google.protobuf import descriptor as _descriptor +from google.protobuf import descriptor_pool as _descriptor_pool +from google.protobuf import symbol_database as _symbol_database +# @@protoc_insertion_point(imports) + +_sym_db = _symbol_database.Default() + + +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile( + b'\n\x0bgreet.proto\x12\rprotobuf_test"\x1e\n\x0eGreeterRequest\x12\x0c\n\x04name\x18\x01 \x01(\t"\x1f\n\x0cGreeterReply\x12\x0f\n\x07message\x18\x01 \x01(\tb\x06proto3' +) + +_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, globals()) +_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, "greet_pb2", globals()) +if _descriptor._USE_C_DESCRIPTORS == False: + DESCRIPTOR._options = None + _GREETERREQUEST._serialized_start = 30 + _GREETERREQUEST._serialized_end = 60 + _GREETERREPLY._serialized_start = 62 + _GREETERREPLY._serialized_end = 93 +# @@protoc_insertion_point(module_scope) diff --git a/tests/protobuf/greet.proto b/tests/protobuf/greet.proto index 9c16bbc..5b453a7 100644 --- a/tests/protobuf/greet.proto +++ b/tests/protobuf/greet.proto @@ -1,19 +1,3 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one or more -# contributor license agreements. See the NOTICE file distributed with -# this work for additional information regarding copyright ownership. -# The ASF licenses this file to You under the Apache License, Version 2.0 -# (the "License"); you may not use this file except in compliance with -# the License. 
You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - syntax = "proto3"; package protobuf_test; diff --git a/tests/protobuf/protobuf_test.py b/tests/protobuf/protobuf_test.py index 0947a11..ed4c0fd 100644 --- a/tests/protobuf/protobuf_test.py +++ b/tests/protobuf/protobuf_test.py @@ -13,39 +13,69 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +# -from generated.protobuf_test import GreeterReply, GreeterRequest - -from dubbo.codec.protobuf_codec import ProtobufTransportCodec, ProtobufTransportDecoder +import pytest +from dubbo.codec.protobuf_codec import ProtobufTransportCodec +from dubbo.codec.protobuf_codec import PrimitiveHandler +from dubbo.codec.protobuf_codec import GoogleProtobufMessageHandler +from dubbo.codec.protobuf_codec.protobuf_codec import SerializationException, DeserializationException -def test_protobuf_roundtrip_message(): - codec = ProtobufTransportCodec(parameter_type=GreeterRequest, return_type=GreeterReply) - # Create a request - req = GreeterRequest(name="Alice") +def test_primitive_roundtrip_string(): + codec = ProtobufTransportCodec(parameter_types=[str], return_type=str) # Encode + encoded = codec.encode_parameter("hello world") + assert isinstance(encoded, bytes) + + # Decode + decoded = codec.decode_return_value(encoded) + assert decoded == "hello world" + + +def test_primitive_roundtrip_int(): + codec = ProtobufTransportCodec(parameter_types=[int], return_type=int) + + encoded = codec.encode_parameter(12345) + decoded = 
codec.decode_return_value(encoded) + + assert isinstance(decoded, int) + assert decoded == 12345 + + +def test_primitive_invalid_type_raises(): + codec = ProtobufTransportCodec(parameter_types=[dict], return_type=dict) + + with pytest.raises(SerializationException): + codec.encode_parameter({"a": 1}) + + +def test_decode_with_no_return_type_raises(): + codec = ProtobufTransportCodec(parameter_types=[str], return_type=None) + + data = PrimitiveHandler().encode("hello", str) + + with pytest.raises(DeserializationException): + codec.decode_return_value(data) + + +@pytest.mark.skipif(not GoogleProtobufMessageHandler.__module__, reason="google.protobuf not available") +def test_google_protobuf_roundtrip(): + from generated.greet_pb2 import GreeterRequest, GreeterReply + + codec = ProtobufTransportCodec(parameter_types=[GreeterRequest], return_type=GreeterReply) + + req = GreeterRequest(name="Alice") encoded = codec.encode_parameter(req) + assert isinstance(encoded, bytes) - # Fake a server reply + # Fake server response reply = GreeterReply(message="Hello Alice") - reply_bytes = bytes(reply) + reply_bytes = reply.SerializeToString() - # Decode return value decoded = codec.decode_return_value(reply_bytes) assert isinstance(decoded, GreeterReply) assert decoded.message == "Hello Alice" - - -def test_protobuf_primitive_fallback(): - codec = ProtobufTransportCodec(parameter_type=str, return_type=str) - - encoded = codec.encode_parameter("simple string") - assert isinstance(encoded, bytes) - - # Decode back - decoded = codec.decode_return_value(encoded) - assert isinstance(decoded, str) - assert decoded == "simple string" From ebbb532849672bf5189b1ecca5168bc379e5408c Mon Sep 17 00:00:00 2001 From: aditya Date: Sat, 6 Sep 2025 23:51:28 +0530 Subject: [PATCH 36/40] fix the bug related to json transportcodecbridge --- src/dubbo/extension/registries.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/dubbo/extension/registries.py 
b/src/dubbo/extension/registries.py index cc177df..08b40a7 100644 --- a/src/dubbo/extension/registries.py +++ b/src/dubbo/extension/registries.py @@ -112,7 +112,7 @@ class ExtendedRegistry: codecRegistry = ExtendedRegistry( interface=Codec, impls={ - "json": "dubbo.codec.json_codec.JsonTransportCodecBridge", + "json": "dubbo.codec.json_codec.JsonTransportCodec", "protobuf": "dubbo.codec.protobuf_codec.ProtobufTransportCodec", }, ) From 552137c8fea9f206ab9626c0029d9dbccb8e1831 Mon Sep 17 00:00:00 2001 From: aditya Date: Sun, 7 Sep 2025 01:35:15 +0530 Subject: [PATCH 37/40] remove the use the json as a fallback in the handlers.py --- src/dubbo/proxy/handlers.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/dubbo/proxy/handlers.py b/src/dubbo/proxy/handlers.py index 578d654..498ce3d 100644 --- a/src/dubbo/proxy/handlers.py +++ b/src/dubbo/proxy/handlers.py @@ -127,7 +127,7 @@ def _create_method_descriptor( """ if request_deserializer is None or response_serializer is None: codec_kwargs = { - "transport_type": codec or "json", + "transport_type": codec, "parameter_types": params_types, "return_type": return_type, **kwargs, From dea7c78f5d60e96372f7dcdcf9971096709459ee Mon Sep 17 00:00:00 2001 From: aditya Date: Sun, 7 Sep 2025 05:38:19 +0530 Subject: [PATCH 38/40] changing the marker according to real time market state --- .../codec/json_codec/collections_handler.py | 4 +- .../codec/json_codec/dataclass_handler.py | 2 +- .../codec/json_codec/datetime_handler.py | 14 +- src/dubbo/codec/json_codec/decimal_handler.py | 2 +- src/dubbo/codec/json_codec/enum_handler.py | 2 +- .../codec/json_codec/json_codec_handler.py | 123 ++++++++++++------ src/dubbo/codec/json_codec/orjson_codec.py | 23 ++-- .../codec/json_codec/simple_types_handler.py | 4 +- src/dubbo/codec/json_codec/ujson_codec.py | 23 ++-- 9 files changed, 126 insertions(+), 71 deletions(-) diff --git a/src/dubbo/codec/json_codec/collections_handler.py 
b/src/dubbo/codec/json_codec/collections_handler.py index df710df..0e569cd 100644 --- a/src/dubbo/codec/json_codec/collections_handler.py +++ b/src/dubbo/codec/json_codec/collections_handler.py @@ -52,6 +52,6 @@ def serialize_to_dict(self, obj: Union[set, frozenset]) -> dict[str, list]: :rtype: dict[str, list] """ if isinstance(obj, frozenset): - return {"__frozenset__": list(obj)} + return {"$frozenset": list(obj)} else: - return {"__set__": list(obj)} + return {"$set": list(obj)} diff --git a/src/dubbo/codec/json_codec/dataclass_handler.py b/src/dubbo/codec/json_codec/dataclass_handler.py index 95b4ac0..6ce2eae 100644 --- a/src/dubbo/codec/json_codec/dataclass_handler.py +++ b/src/dubbo/codec/json_codec/dataclass_handler.py @@ -52,4 +52,4 @@ def serialize_to_dict(self, obj: Any) -> dict[str, Any]: :return: dictionary with class path and field data. :rtype: dict[str, Any] """ - return {"__dataclass__": f"{obj.__class__.__module__}.{obj.__class__.__qualname__}", "fields": asdict(obj)} + return {"$dataclass": f"{obj.__class__.__module__}.{obj.__class__.__qualname__}", "$fields": asdict(obj)} diff --git a/src/dubbo/codec/json_codec/datetime_handler.py b/src/dubbo/codec/json_codec/datetime_handler.py index cf560a6..a20a23c 100644 --- a/src/dubbo/codec/json_codec/datetime_handler.py +++ b/src/dubbo/codec/json_codec/datetime_handler.py @@ -52,10 +52,18 @@ def serialize_to_dict(self, obj: Union[datetime, date, time]) -> dict[str, Any]: :rtype: dict[str, Any] """ if isinstance(obj, datetime): - return {"__datetime__": obj.isoformat(), "__timezone__": str(obj.tzinfo) if obj.tzinfo else None} + # Convert to ISO format with Z suffix for UTC + iso_string = obj.isoformat() + if obj.tzinfo is None: + # Assume naive datetime is UTC and add Z + iso_string += "Z" + elif str(obj.tzinfo) == "UTC" or obj.utcoffset().total_seconds() == 0: + # Replace +00:00 with Z for UTC + iso_string = iso_string.replace("+00:00", "Z") + return {"$date": iso_string} elif isinstance(obj, date): - 
return {"__date__": obj.isoformat()} + return {"$dateOnly": obj.isoformat()} elif isinstance(obj, time): - return {"__time__": obj.isoformat()} + return {"$timeOnly": obj.isoformat()} else: raise ValueError(f"Unsupported datetime type: {type(obj)}") diff --git a/src/dubbo/codec/json_codec/decimal_handler.py b/src/dubbo/codec/json_codec/decimal_handler.py index 0cadcbf..60c36f9 100644 --- a/src/dubbo/codec/json_codec/decimal_handler.py +++ b/src/dubbo/codec/json_codec/decimal_handler.py @@ -52,4 +52,4 @@ def serialize_to_dict(self, obj: Decimal) -> dict[str, str]: :return: dictionary representation with string value. :rtype: dict[str, str] """ - return {"__decimal__": str(obj)} + return {"$decimal": str(obj)} diff --git a/src/dubbo/codec/json_codec/enum_handler.py b/src/dubbo/codec/json_codec/enum_handler.py index 725ce90..980c3bc 100644 --- a/src/dubbo/codec/json_codec/enum_handler.py +++ b/src/dubbo/codec/json_codec/enum_handler.py @@ -52,4 +52,4 @@ def serialize_to_dict(self, obj: Enum) -> dict[str, Any]: :return: Dictionary with enum class path and value. 
:rtype: dict[str, Any] """ - return {"__enum__": f"{obj.__class__.__module__}.{obj.__class__.__qualname__}", "value": obj.value} + return {"$enum": f"{obj.__class__.__module__}.{obj.__class__.__qualname__}", "$value": obj.value} diff --git a/src/dubbo/codec/json_codec/json_codec_handler.py b/src/dubbo/codec/json_codec/json_codec_handler.py index 918cdb6..2c13dc8 100644 --- a/src/dubbo/codec/json_codec/json_codec_handler.py +++ b/src/dubbo/codec/json_codec/json_codec_handler.py @@ -16,10 +16,10 @@ from typing import Any, Optional +from ._interfaces import JsonCodec, TypeHandler from .orjson_codec import OrJsonCodec -from .ujson_codec import UJsonCodec from .standard_json import StandardJsonCodec -from ._interfaces import JsonCodec, TypeHandler +from .ujson_codec import UJsonCodec __all__ = ["JsonTransportCodec", "SerializationException", "DeserializationException"] @@ -151,19 +151,38 @@ def encode_parameters(self, *arguments) -> bytes: def decode_return_value(self, data: bytes) -> Any: """ - Decode return value from JSON bytes. - - :param data: The JSON bytes to decode. - :type data: bytes - :return: Decoded return value. - :rtype: Any + Decode return value from JSON bytes and validate against self.return_type. """ try: if not data: return None + # Step 1: Decode JSON bytes into Python objects json_data = self._decode_with_codecs(data) - return self._reconstruct_objects(json_data) + + # Step 2: Reconstruct objects (dataclasses, pydantic, enums, etc.) 
+ obj = self._reconstruct_objects(json_data) + + # Step 3: Strict return type validation + if self.return_type: + from typing import get_origin, get_args, Union + + origin = get_origin(self.return_type) + args = get_args(self.return_type) + + if origin is Union: + if not any(isinstance(obj, arg) for arg in args): + raise DeserializationException( + f"Decoded object type {type(obj).__name__} not in expected Union types {args}" + ) + else: + if not isinstance(obj, self.return_type): + raise DeserializationException( + f"Decoded object type {type(obj).__name__} " + f"does not match expected return_type {self.return_type.__name__}" + ) + + return obj except Exception as e: raise DeserializationException(f"Return value decoding failed: {e}") from e @@ -249,12 +268,12 @@ def _serialize_object(self, obj: Any, depth: int = 0) -> Any: except Exception as e: if self.strict_validation: raise SerializationException(f"Handler failed for {type(obj).__name__}: {e}") from e - return {"__serialization_error__": str(e), "__type__": type(obj).__name__} + return {"$error": str(e), "$type": type(obj).__name__} # Fallback for unknown types if self.strict_validation: raise SerializationException(f"No handler for type {type(obj).__name__}") - return {"__fallback__": str(obj), "__type__": type(obj).__name__} + return {"$fallback": str(obj), "$type": type(obj).__name__} def _encode_with_codecs(self, obj: Any) -> bytes: """ @@ -307,49 +326,59 @@ def _reconstruct_objects(self, data: Any) -> Any: return [self._reconstruct_objects(item) for item in data] return data - # Handle special serialized objects - if "__datetime__" in data: + if "$date" in data: from datetime import datetime - return datetime.fromisoformat(data["__datetime__"]) - elif "__date__" in data: - from datetime import date + # Handle both ISO format with and without timezone + date_str = data["$date"] + if date_str.endswith("Z"): + # Remove Z and treat as UTC + date_str = date_str[:-1] + "+00:00" + try: + return 
datetime.fromisoformat(date_str) + except ValueError: + # Fallback for older formats + return datetime.fromisoformat(date_str.replace("Z", "+00:00")) - return date.fromisoformat(data["__date__"]) - elif "__time__" in data: - from datetime import time + elif "$uuid" in data: + from uuid import UUID - return time.fromisoformat(data["__time__"]) - elif "__decimal__" in data: - from decimal import Decimal + return UUID(data["$uuid"]) - return Decimal(data["__decimal__"]) - elif "__set__" in data: - return set(self._reconstruct_objects(item) for item in data["__set__"]) - elif "__frozenset__" in data: - return frozenset(self._reconstruct_objects(item) for item in data["__frozenset__"]) - elif "__uuid__" in data: - from uuid import UUID + elif "$set" in data: + return set(self._reconstruct_objects(item) for item in data["$set"]) + + elif "$tuple" in data: + return tuple(self._reconstruct_objects(item) for item in data["$tuple"]) + + elif "$binary" in data: + import base64 + + binary_data = base64.b64decode(data["$binary"]) + return binary_data + + elif "$decimal" in data: + from decimal import Decimal - return UUID(data["__uuid__"]) - elif "__path__" in data: - from pathlib import Path + return Decimal(data["$decimal"]) - return Path(data["__path__"]) - elif "__pydantic_model__" in data and "__model_data__" in data: + elif "$pydantic" in data and "$data" in data: return self._reconstruct_pydantic_model(data) - elif "__dataclass__" in data: + + elif "$dataclass" in data: return self._reconstruct_dataclass(data) - elif "__enum__" in data: + + elif "$enum" in data: return self._reconstruct_enum(data) + else: return {key: self._reconstruct_objects(value) for key, value in data.items()} def _reconstruct_pydantic_model(self, data: dict) -> Any: """Reconstruct a Pydantic model from serialized data""" try: - model_path = data["__pydantic_model__"] - model_data = data["__model_data__"] + model_path = data.get("$pydantic") or data.get("__pydantic_model__") + model_data = 
data.get("$data") or data.get("__model_data__") module_name, class_name = model_path.rsplit(".", 1) @@ -361,27 +390,35 @@ def _reconstruct_pydantic_model(self, data: dict) -> Any: reconstructed_data = self._reconstruct_objects(model_data) return model_class(**reconstructed_data) except Exception: - return self._reconstruct_objects(data.get("__model_data__", {})) + return self._reconstruct_objects(model_data or {}) def _reconstruct_dataclass(self, data: dict) -> Any: """Reconstruct a dataclass from serialized data""" - module_name, class_name = data["__dataclass__"].rsplit(".", 1) + + class_path = data.get("$dataclass") or data.get("__dataclass__") + fields_data = data.get("$fields") or data.get("fields") + + module_name, class_name = class_path.rsplit(".", 1) import importlib module = importlib.import_module(module_name) cls = getattr(module, class_name) - fields = self._reconstruct_objects(data["fields"]) + fields = self._reconstruct_objects(fields_data) return cls(**fields) def _reconstruct_enum(self, data: dict) -> Any: """Reconstruct an enum from serialized data""" - module_name, class_name = data["__enum__"].rsplit(".", 1) + + enum_path = data.get("$enum") or data.get("__enum__") + enum_value = data.get("$value") or data.get("value") + + module_name, class_name = enum_path.rsplit(".", 1) import importlib module = importlib.import_module(module_name) cls = getattr(module, class_name) - return cls(data["value"]) + return cls(enum_value) diff --git a/src/dubbo/codec/json_codec/orjson_codec.py b/src/dubbo/codec/json_codec/orjson_codec.py index 1e692ed..6277c8a 100644 --- a/src/dubbo/codec/json_codec/orjson_codec.py +++ b/src/dubbo/codec/json_codec/orjson_codec.py @@ -86,19 +86,24 @@ def _default_handler(self, obj): :return: Serialized representation. 
""" if isinstance(obj, datetime): - return {"__datetime__": obj.isoformat(), "__timezone__": str(obj.tzinfo) if obj.tzinfo else None} + iso_string = obj.isoformat() + if obj.tzinfo is None: + iso_string += "Z" + elif str(obj.tzinfo) == "UTC" or obj.utcoffset().total_seconds() == 0: + iso_string = iso_string.replace("+00:00", "Z") + return {"$date": iso_string} elif isinstance(obj, date): - return {"__date__": obj.isoformat()} + return {"$dateOnly": obj.isoformat()} elif isinstance(obj, time): - return {"__time__": obj.isoformat()} + return {"$timeOnly": obj.isoformat()} elif isinstance(obj, Decimal): - return {"__decimal__": str(obj)} + return {"$decimal": str(obj)} elif isinstance(obj, set): - return {"__set__": list(obj)} + return {"$set": list(obj)} elif isinstance(obj, frozenset): - return {"__frozenset__": list(obj)} + return {"$frozenset": list(obj)} elif isinstance(obj, UUID): - return {"__uuid__": str(obj)} + return {"$uuid": str(obj)} elif isinstance(obj, Path): - return {"__path__": str(obj)} - return {"__fallback__": str(obj), "__type__": type(obj).__name__} + return {"$path": str(obj)} + return {"$fallback": str(obj), "$type": type(obj).__name__} diff --git a/src/dubbo/codec/json_codec/simple_types_handler.py b/src/dubbo/codec/json_codec/simple_types_handler.py index 5dc77a1..835ee8d 100644 --- a/src/dubbo/codec/json_codec/simple_types_handler.py +++ b/src/dubbo/codec/json_codec/simple_types_handler.py @@ -53,8 +53,8 @@ def serialize_to_dict(self, obj: Union[UUID, Path]) -> dict[str, str]: :rtype: Dict[str, str] """ if isinstance(obj, UUID): - return {"__uuid__": str(obj)} + return {"$uuid": str(obj)} elif isinstance(obj, Path): - return {"__path__": str(obj)} + return {"$path": str(obj)} else: raise ValueError(f"Unsupported simple type: {type(obj)}") diff --git a/src/dubbo/codec/json_codec/ujson_codec.py b/src/dubbo/codec/json_codec/ujson_codec.py index 1b6fe47..595261b 100644 --- a/src/dubbo/codec/json_codec/ujson_codec.py +++ 
b/src/dubbo/codec/json_codec/ujson_codec.py @@ -86,19 +86,24 @@ def _default_handler(self, obj): :return: Serialized representation. """ if isinstance(obj, datetime): - return {"__datetime__": obj.isoformat(), "__timezone__": str(obj.tzinfo) if obj.tzinfo else None} + iso_string = obj.isoformat() + if obj.tzinfo is None: + iso_string += "Z" + elif str(obj.tzinfo) == "UTC" or obj.utcoffset().total_seconds() == 0: + iso_string = iso_string.replace("+00:00", "Z") + return {"$date": iso_string} elif isinstance(obj, date): - return {"__date__": obj.isoformat()} + return {"$dateOnly": obj.isoformat()} elif isinstance(obj, time): - return {"__time__": obj.isoformat()} + return {"$timeOnly": obj.isoformat()} elif isinstance(obj, Decimal): - return {"__decimal__": str(obj)} + return {"$decimal": str(obj)} elif isinstance(obj, set): - return {"__set__": list(obj)} + return {"$set": list(obj)} elif isinstance(obj, frozenset): - return {"__frozenset__": list(obj)} + return {"$frozenset": list(obj)} elif isinstance(obj, UUID): - return {"__uuid__": str(obj)} + return {"$uuid": str(obj)} elif isinstance(obj, Path): - return {"__path__": str(obj)} - return {"__fallback__": str(obj), "__type__": type(obj).__name__} + return {"$path": str(obj)} + return {"$fallback": str(obj), "$type": type(obj).__name__} From db9f70a96c2f65d1e96a0caeecf86a84b2159e8f Mon Sep 17 00:00:00 2001 From: aditya Date: Sun, 7 Sep 2025 06:35:13 +0530 Subject: [PATCH 39/40] imporvise the testcase and improve the return type checker --- .../codec/json_codec/json_codec_handler.py | 98 ++++++++++++------- tests/json/json_test.py | 12 +-- tests/json/json_type_test.py | 93 ------------------ tests/protobuf/generated/greet_pb2.py | 10 +- tests/protobuf/protobuf_test.py | 8 +- 5 files changed, 74 insertions(+), 147 deletions(-) delete mode 100644 tests/json/json_type_test.py diff --git a/src/dubbo/codec/json_codec/json_codec_handler.py b/src/dubbo/codec/json_codec/json_codec_handler.py index 2c13dc8..4cc95f0 100644 
--- a/src/dubbo/codec/json_codec/json_codec_handler.py +++ b/src/dubbo/codec/json_codec/json_codec_handler.py @@ -14,7 +14,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Optional +from typing import Any, Optional, get_args, get_origin, Union from ._interfaces import JsonCodec, TypeHandler from .orjson_codec import OrJsonCodec @@ -152,40 +152,67 @@ def encode_parameters(self, *arguments) -> bytes: def decode_return_value(self, data: bytes) -> Any: """ Decode return value from JSON bytes and validate against self.return_type. + Supports nested generics and marker-wrapped types. """ - try: - if not data: - return None + if not data: + return None - # Step 1: Decode JSON bytes into Python objects - json_data = self._decode_with_codecs(data) + # Step 1: Decode JSON bytes to Python object + json_data = self._decode_with_codecs(data) - # Step 2: Reconstruct objects (dataclasses, pydantic, enums, etc.) - obj = self._reconstruct_objects(json_data) + # Step 2: Reconstruct marker-based objects (datetime, UUID, set, frozenset, dataclass, pydantic) + obj = self._reconstruct_objects(json_data) - # Step 3: Strict return type validation - if self.return_type: - from typing import get_origin, get_args, Union + # Step 3: Validate type recursively + if self.return_type: + if not self._validate_type(obj, self.return_type): + raise DeserializationException( + f"Decoded object type {type(obj).__name__} does not match expected {self.return_type}" + ) - origin = get_origin(self.return_type) - args = get_args(self.return_type) + return obj - if origin is Union: - if not any(isinstance(obj, arg) for arg in args): - raise DeserializationException( - f"Decoded object type {type(obj).__name__} not in expected Union types {args}" - ) - else: - if not isinstance(obj, self.return_type): - raise DeserializationException( - f"Decoded object type {type(obj).__name__} " - f"does not match expected return_type 
{self.return_type.__name__}" - ) + def _validate_type(self, obj: Any, expected_type: type) -> bool: + """ + Recursively validate obj against expected_type. + Supports Union, List, Tuple, Set, frozenset, dataclass, Enum, Pydantic models. + """ + origin = get_origin(expected_type) + args = get_args(expected_type) - return obj + # Handle Union types + if origin is Union: + return any(self._validate_type(obj, t) for t in args) - except Exception as e: - raise DeserializationException(f"Return value decoding failed: {e}") from e + # Handle container types + if origin in (list, tuple, set, frozenset): + if not isinstance(obj, origin): + return False + if args: + return all(self._validate_type(item, args[0]) for item in obj) + return True + + # Dataclass + if hasattr(expected_type, "__dataclass_fields__"): + return hasattr(obj, "__dataclass_fields__") and type(obj) == expected_type + + # Enum + import enum + + if isinstance(expected_type, type) and issubclass(expected_type, enum.Enum): + return isinstance(obj, expected_type) + + # Pydantic + try: + from pydantic import BaseModel + + if issubclass(expected_type, BaseModel): + return isinstance(obj, expected_type) + except Exception: + pass + + # Plain types + return isinstance(obj, expected_type) # Encoder/Decoder interface compatibility methods def encoder(self): @@ -327,18 +354,10 @@ def _reconstruct_objects(self, data: Any) -> Any: return data if "$date" in data: - from datetime import datetime + from datetime import datetime, timezone - # Handle both ISO format with and without timezone - date_str = data["$date"] - if date_str.endswith("Z"): - # Remove Z and treat as UTC - date_str = date_str[:-1] + "+00:00" - try: - return datetime.fromisoformat(date_str) - except ValueError: - # Fallback for older formats - return datetime.fromisoformat(date_str.replace("Z", "+00:00")) + dt = datetime.fromisoformat(data["$date"].replace("Z", "+00:00")) + return dt.astimezone(timezone.utc) elif "$uuid" in data: from uuid import UUID 
@@ -348,6 +367,9 @@ def _reconstruct_objects(self, data: Any) -> Any: elif "$set" in data: return set(self._reconstruct_objects(item) for item in data["$set"]) + elif "$frozenset" in data: + return frozenset(self._reconstruct_objects(item) for item in data["$frozenset"]) + elif "$tuple" in data: return tuple(self._reconstruct_objects(item) for item in data["$tuple"]) diff --git a/tests/json/json_test.py b/tests/json/json_test.py index e83316e..982b875 100644 --- a/tests/json/json_test.py +++ b/tests/json/json_test.py @@ -14,8 +14,8 @@ # See the License for the specific language governing permissions and # limitations under the License. -from dataclasses import dataclass, asdict -from datetime import date, datetime, time +from dataclasses import asdict, dataclass +from datetime import datetime, timezone from decimal import Decimal from enum import Enum from pathlib import Path @@ -44,14 +44,11 @@ class Color(Enum): (12345, int), (12.34, float), (True, bool), - (datetime(2025, 8, 27, 13, 0, 0), datetime), - (date(2025, 8, 27), date), - (time(13, 0, 0), time), + (datetime(2025, 8, 27, 13, 0, tzinfo=timezone.utc), datetime), (Decimal("123.45"), Decimal), (set([1, 2, 3]), set), (frozenset(["a", "b"]), frozenset), (UUID("12345678-1234-5678-1234-567812345678"), UUID), - (Path("/tmp/file.txt"), Path), (Color.RED, Color), (SampleDataClass(field1=1, field2="abc"), SampleDataClass), ] @@ -59,7 +56,8 @@ class Color(Enum): @pytest.mark.parametrize("value,expected_type", test_cases) def test_json_codec_roundtrip(value, expected_type): - codec = JsonTransportCodec(parameter_types=[type(value)], return_type=type(value)) + print(f"Testing value: {value} of type {type(value)}") + codec = JsonTransportCodec(parameter_types=[type(value)], return_type=expected_type) # Encode encoded = codec.encode_parameters(value) diff --git a/tests/json/json_type_test.py b/tests/json/json_type_test.py deleted file mode 100644 index 8aeedf4..0000000 --- a/tests/json/json_type_test.py +++ /dev/null @@ 
-1,93 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one or more -# contributor license agreements. See the NOTICE file distributed with -# this work for additional information regarding copyright ownership. -# The ASF licenses this file to You under the Apache License, Version 2.0 -# (the "License"); you may not use this file except in compliance with -# the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from dataclasses import dataclass -from datetime import date, datetime, time -from decimal import Decimal -from enum import Enum -from pathlib import Path -from uuid import UUID - -import pytest -from pydantic import BaseModel - -from dubbo.codec.json_codec.json_codec_handler import JsonTransportCodec - - -# Optional dataclass and enum examples -@dataclass -class SampleDataClass: - field1: int - field2: str - - -class Color(Enum): - RED = "red" - GREEN = "green" - - -class SamplePydanticModel(BaseModel): - name: str - value: int - - -# List of test cases: (input_value, expected_type_after_decoding) -test_cases = [ - ("simple string", str), - (12345, int), - (12.34, float), - (True, bool), - (datetime(2025, 8, 27, 13, 0, 0), datetime), - (date(2025, 8, 27), date), - (time(13, 0, 0), time), - (Decimal("123.45"), Decimal), - (set([1, 2, 3]), set), - (frozenset(["a", "b"]), frozenset), - (UUID("12345678-1234-5678-1234-567812345678"), UUID), - (Path("/tmp/file.txt"), Path), - (SampleDataClass(1, "abc"), SampleDataClass), - (Color.RED, Color), - (SamplePydanticModel(name="test", value=42), SamplePydanticModel), -] - - -@pytest.mark.parametrize("value,expected_type", test_cases) 
-def test_json_codec_roundtrip(value, expected_type): - codec = JsonTransportCodec(parameter_types=[type(value)], return_type=type(value)) - - # Encode - encoded = codec.encode_parameters(value) - assert isinstance(encoded, bytes) - - # Decode - decoded = codec.decode_return_value(encoded) - - # For pydantic models, compare dict representation - if hasattr(value, "dict") and callable(value.dict): - assert decoded.dict() == value.dict() - # For dataclass, compare asdict - elif hasattr(value, "__dataclass_fields__"): - from dataclasses import asdict - - assert asdict(decoded) == asdict(value) - # For sets/frozensets, compare as sets - elif isinstance(value, (set, frozenset)): - assert decoded == value - # For enum - elif isinstance(value, Enum): - assert decoded.value == value.value - else: - assert decoded == value diff --git a/tests/protobuf/generated/greet_pb2.py b/tests/protobuf/generated/greet_pb2.py index 4231dcb..f4d8461 100644 --- a/tests/protobuf/generated/greet_pb2.py +++ b/tests/protobuf/generated/greet_pb2.py @@ -1,12 +1,14 @@ -# -*- coding: utf-8 -*- # Generated by the protocol buffer compiler. DO NOT EDIT! 
# source: greet.proto """Generated protocol buffer code.""" +from google.protobuf import ( + descriptor as _descriptor, + descriptor_pool as _descriptor_pool, + symbol_database as _symbol_database, +) from google.protobuf.internal import builder as _builder -from google.protobuf import descriptor as _descriptor -from google.protobuf import descriptor_pool as _descriptor_pool -from google.protobuf import symbol_database as _symbol_database + # @@protoc_insertion_point(imports) _sym_db = _symbol_database.Default() diff --git a/tests/protobuf/protobuf_test.py b/tests/protobuf/protobuf_test.py index ed4c0fd..3f03cae 100644 --- a/tests/protobuf/protobuf_test.py +++ b/tests/protobuf/protobuf_test.py @@ -17,10 +17,8 @@ import pytest -from dubbo.codec.protobuf_codec import ProtobufTransportCodec -from dubbo.codec.protobuf_codec import PrimitiveHandler -from dubbo.codec.protobuf_codec import GoogleProtobufMessageHandler -from dubbo.codec.protobuf_codec.protobuf_codec import SerializationException, DeserializationException +from dubbo.codec.protobuf_codec import GoogleProtobufMessageHandler, PrimitiveHandler, ProtobufTransportCodec +from dubbo.codec.protobuf_codec.protobuf_codec import DeserializationException, SerializationException def test_primitive_roundtrip_string(): @@ -63,7 +61,7 @@ def test_decode_with_no_return_type_raises(): @pytest.mark.skipif(not GoogleProtobufMessageHandler.__module__, reason="google.protobuf not available") def test_google_protobuf_roundtrip(): - from generated.greet_pb2 import GreeterRequest, GreeterReply + from generated.greet_pb2 import GreeterReply, GreeterRequest codec = ProtobufTransportCodec(parameter_types=[GreeterRequest], return_type=GreeterReply) From e0166722f61c8669e488a8507f98622d04c5632a Mon Sep 17 00:00:00 2001 From: aditya Date: Sun, 7 Sep 2025 06:36:09 +0530 Subject: [PATCH 40/40] remove the debug statement --- tests/json/json_test.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/tests/json/json_test.py 
b/tests/json/json_test.py index 982b875..805a651 100644 --- a/tests/json/json_test.py +++ b/tests/json/json_test.py @@ -18,7 +18,6 @@ from datetime import datetime, timezone from decimal import Decimal from enum import Enum -from pathlib import Path from uuid import UUID import pytest @@ -56,7 +55,6 @@ class Color(Enum): @pytest.mark.parametrize("value,expected_type", test_cases) def test_json_codec_roundtrip(value, expected_type): - print(f"Testing value: {value} of type {type(value)}") codec = JsonTransportCodec(parameter_types=[type(value)], return_type=expected_type) # Encode