diff --git a/pyproject.toml b/pyproject.toml index 2615f138..96499e3d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -40,6 +40,7 @@ dependencies = [ "click", "datasets", "deepspeed", + "eval_type_backport", "httpx[http2]", "huggingface-hub", "loguru", diff --git a/src/speculators/__init__.py b/src/speculators/__init__.py index a5f8bf28..529db76b 100644 --- a/src/speculators/__init__.py +++ b/src/speculators/__init__.py @@ -25,9 +25,8 @@ SpeculatorsConfig, TokenProposalConfig, VerifierConfig, - reload_and_populate_configs, ) -from .model import SpeculatorModel, reload_and_populate_models +from .model import SpeculatorModel __all__ = [ "SpeculatorModel", @@ -35,10 +34,4 @@ "SpeculatorsConfig", "TokenProposalConfig", "VerifierConfig", - "reload_and_populate_configs", - "reload_and_populate_models", ] - -# base imports complete, run auto loading for base classes -reload_and_populate_configs() -reload_and_populate_models() diff --git a/src/speculators/config.py b/src/speculators/config.py index 36acea4c..adf6fc55 100644 --- a/src/speculators/config.py +++ b/src/speculators/config.py @@ -32,7 +32,6 @@ "SpeculatorsConfig", "TokenProposalConfig", "VerifierConfig", - "reload_and_populate_configs", ] @@ -50,11 +49,8 @@ class TokenProposalConfig(PydanticClassRegistryMixin): """ @classmethod - def __pydantic_schema_base_type__(cls) -> type["TokenProposalConfig"]: - if cls.__name__ == "TokenProposalConfig": - return cls - - return TokenProposalConfig + def __pydantic_schema_base_name__(cls) -> str: + return "TokenProposalConfig" auto_package: ClassVar[str] = "speculators.proposals" registry_auto_discovery: ClassVar[bool] = True @@ -238,11 +234,8 @@ def from_dict( return cls.model_validate(dict_obj) @classmethod - def __pydantic_schema_base_type__(cls) -> type["SpeculatorModelConfig"]: - if cls.__name__ == "SpeculatorModelConfig": - return cls - - return SpeculatorModelConfig + def __pydantic_schema_base_name__(cls) -> str: + return "SpeculatorModelConfig" # Pydantic configuration model_config = ConfigDict(arbitrary_types_allowed=True, extra="allow") @@ -328,14 +321,3 @@ def to_diff_dict(self) -> dict[str, Any]: or set, along with all Pydantic fields. """ return super().to_diff_dict() - - -def reload_and_populate_configs(): - """ - Automatically populates the registry for all PydanticClassRegistryMixin subclasses - and reloads schemas for all Config classes to ensure their schemas are up-to-date - with the current registry state. - """ - TokenProposalConfig.auto_populate_registry() - SpeculatorsConfig.reload_schema() - SpeculatorModelConfig.auto_populate_registry() diff --git a/src/speculators/convert/eagle/__init__.py b/src/speculators/convert/eagle/__init__.py index 64777b87..9d007b3c 100644 --- a/src/speculators/convert/eagle/__init__.py +++ b/src/speculators/convert/eagle/__init__.py @@ -1,7 +1,8 @@ """ -Eagle checkpoint conversion utilities. +EAGLE v1, EAGLE v2, EAGLE v3, and HASS checkpoint conversion utilities. """ +from speculators.convert.eagle.eagle3_converter import Eagle3Converter from speculators.convert.eagle.eagle_converter import EagleConverter -__all__ = ["EagleConverter"] +__all__ = ["Eagle3Converter", "EagleConverter"] diff --git a/src/speculators/convert/eagle/eagle3_converter.py b/src/speculators/convert/eagle/eagle3_converter.py index 778df0f8..0ba1d72c 100644 --- a/src/speculators/convert/eagle/eagle3_converter.py +++ b/src/speculators/convert/eagle/eagle3_converter.py @@ -18,6 +18,8 @@ from speculators.models.eagle3 import Eagle3Speculator, Eagle3SpeculatorConfig from speculators.proposals.greedy import GreedyTokenProposalConfig +__all__ = ["Eagle3Converter"] + class Eagle3Converter: """ diff --git a/src/speculators/convert/eagle/eagle_converter.py b/src/speculators/convert/eagle/eagle_converter.py index 3d2e4a19..8bf3d377 100644 --- a/src/speculators/convert/eagle/eagle_converter.py +++ b/src/speculators/convert/eagle/eagle_converter.py @@ -19,6 +19,8 @@ from speculators.models.eagle import EagleSpeculator, EagleSpeculatorConfig from speculators.proposals.greedy import GreedyTokenProposalConfig +__all__ = ["EagleConverter"] + class EagleConverter: """ diff --git a/src/speculators/model.py b/src/speculators/model.py index d8f6bca2..3450c473 100644 --- a/src/speculators/model.py +++ b/src/speculators/model.py @@ -14,10 +14,6 @@ Classes: SpeculatorModel: Abstract base class for all speculator models with transformers compatibility, automatic registry support, and speculative generation methods - -Functions: - reload_and_populate_models: Automatically populates the model registry for - discovery and instantiation of registered speculator models """ import os @@ -37,10 +33,12 @@ from transformers.generation.utils import GenerateOutput from speculators.config import SpeculatorModelConfig -from speculators.utils import ClassRegistryMixin +from speculators.utils import RegistryMixin -class SpeculatorModel(ClassRegistryMixin, PreTrainedModel, GenerationMixin): # type: ignore[misc] +class SpeculatorModel( # type: ignore[misc] + RegistryMixin[type["SpeculatorModel"]], PreTrainedModel, GenerationMixin +): """ Abstract base class for all speculator models in the Speculators library. @@ -559,14 +557,3 @@ def generate( raise NotImplementedError( "The generate method for speculator models is not implemented yet." ) - - -def reload_and_populate_models(): - """ - Triggers the automatic discovery and registration of all - SpeculatorModel subclasses found in the speculators.models package - that have been registered with `SpeculatorModel.register(NAME)`. This - enables dynamic model loading and instantiation based on configuration - types without requiring explicit imports. - """ - SpeculatorModel.auto_populate_registry() diff --git a/src/speculators/utils/__init__.py b/src/speculators/utils/__init__.py index ebe8d140..edf0889b 100644 --- a/src/speculators/utils/__init__.py +++ b/src/speculators/utils/__init__.py @@ -1,10 +1,10 @@ from .auto_importer import AutoImporterMixin from .pydantic_utils import PydanticClassRegistryMixin, ReloadableBaseModel -from .registry import ClassRegistryMixin +from .registry import RegistryMixin __all__ = [ "AutoImporterMixin", - "ClassRegistryMixin", "PydanticClassRegistryMixin", + "RegistryMixin", "ReloadableBaseModel", ] diff --git a/src/speculators/utils/auto_importer.py b/src/speculators/utils/auto_importer.py index 3b3240d3..254c2bd6 100644 --- a/src/speculators/utils/auto_importer.py +++ b/src/speculators/utils/auto_importer.py @@ -1,64 +1,56 @@ """ Automatic module importing utilities for dynamic class discovery. -This module provides a mixin class for automatic module importing within a package, +This module provides a mixin class for automatic module importing within packages, enabling dynamic discovery of classes and implementations without explicit imports. -It is particularly useful for auto-registering classes in a registry pattern where -subclasses need to be discoverable at runtime. - -The AutoImporterMixin can be combined with registration mechanisms to create -extensible systems where new implementations are automatically discovered and -registered when they are placed in the correct package structure. - -Classes: - - AutoImporterMixin: A mixin class that provides functionality to automatically - import all modules within a specified package or list of packa +It is designed for registry patterns where subclasses need to be discoverable at +runtime, creating extensible systems where new implementations are automatically +discovered when placed in the correct package structure. """ import importlib import pkgutil import sys -from typing import ClassVar, Optional, Union +from typing import ClassVar, Union __all__ = ["AutoImporterMixin"] class AutoImporterMixin: """ - A mixin class that provides functionality to automatically import all modules - within a specified package or list of packages. - - This mixin is designed to be used with class registration mechanisms to enable - automatic discovery and registration of classes without explicit imports. When - a class inherits from AutoImporterMixin, it can define the package(s) to scan - for modules by setting the `auto_package` class variable. - - Usage Example: - ```python - from speculators.utils import AutoImporterMixin - class MyRegistry(AutoImporterMixin): - auto_package = "my_package.implementations" - - MyRegistry.auto_import_package_modules() - ``` - - :cvar auto_package: The package name or tuple of names to import modules from. - :cvar auto_ignore_modules: Optional tuple of module names to ignore during import. - :cvar auto_imported_modules: List tracking which modules have been imported. + Mixin class for automatic module importing within packages. + + This mixin enables dynamic discovery of classes and implementations by + automatically importing all modules within specified packages. It is designed + for use with class registration mechanisms to enable automatic discovery and + registration of classes when they are placed in the correct package structure. + + Example: + :: + from speculators.utils import AutoImporterMixin + + class MyRegistry(AutoImporterMixin): + auto_package = "my_package.implementations" + + MyRegistry.auto_import_package_modules() + + :cvar auto_package: Package name or tuple of package names to import modules from + :cvar auto_ignore_modules: Module names to ignore during import + :cvar auto_imported_modules: List tracking which modules have been imported """ - auto_package: ClassVar[Optional[Union[str, tuple[str, ...]]]] = None - auto_ignore_modules: ClassVar[Optional[tuple[str, ...]]] = None - auto_imported_modules: ClassVar[Optional[list]] = None + auto_package: ClassVar[Union[str, tuple[str, ...], None]] = None + auto_ignore_modules: ClassVar[Union[tuple[str, ...], None]] = None + auto_imported_modules: ClassVar[Union[list[str], None]] = None @classmethod - def auto_import_package_modules(cls): + def auto_import_package_modules(cls) -> None: """ - Automatically imports all modules within the specified package(s). + Automatically import all modules within the specified package(s). - This method scans the package(s) defined in the `auto_package` class variable - and imports all modules found, tracking them in `auto_imported_modules`. It - skips packages (directories) and any modules listed in `auto_ignore_modules`. + Scans the package(s) defined in `auto_package` and imports all modules found, + tracking them in `auto_imported_modules`. Skips packages and any modules + listed in `auto_ignore_modules`. :raises ValueError: If the `auto_package` class variable is not set """ diff --git a/src/speculators/utils/pydantic_utils.py b/src/speculators/utils/pydantic_utils.py index 01816157..74aed552 100644 --- a/src/speculators/utils/pydantic_utils.py +++ b/src/speculators/utils/pydantic_utils.py @@ -1,113 +1,236 @@ """ -General pydantic utilities for Speculators. +Pydantic utilities for polymorphic model serialization and registry integration. -This module provides integration between Pydantic and the Speculators library, -enabling things like polymorphic serialization and deserialization of Pydantic -models using a discriminator field and registry. - -Classes: - PydanticClassRegistryMixin: A mixin that combines Pydantic models with the - ClassRegistryMixin to support polymorphic model instantiation based on - a discriminator field +Provides integration between Pydantic and the registry system, enabling +polymorphic serialization and deserialization of Pydantic models using +a discriminator field and dynamic class registry. Includes base model classes +with standardized configurations and generic status breakdown models for +structured result organization. """ +from __future__ import annotations + from abc import ABC, abstractmethod -from typing import Any, ClassVar, Optional +from typing import Any, ClassVar, Generic, TypeVar, get_args, get_origin from pydantic import BaseModel, GetCoreSchemaHandler from pydantic_core import CoreSchema, core_schema -from speculators.utils.registry import ClassRegistryMixin +from speculators.utils.registry import RegistryMixin + +__all__ = [ + "BaseModelT", + "PydanticClassRegistryMixin", + "RegisterClassT", + "ReloadableBaseModel", +] -__all__ = ["PydanticClassRegistryMixin", "ReloadableBaseModel"] + +BaseModelT = TypeVar("BaseModelT", bound=BaseModel) +RegisterClassT = TypeVar("RegisterClassT", bound=type[BaseModel]) class ReloadableBaseModel(BaseModel): + """ + Base Pydantic model with schema reloading capabilities. + + Provides dynamic schema rebuilding functionality for models that need to + update their validation schemas at runtime, particularly useful when + working with registry-based polymorphic models where new types are + registered after initial class definition. + """ + @classmethod - def reload_schema(cls): + def reload_schema(cls, dependencies: bool = True): """ - Reloads the schema for the class, ensuring that the registry is populated - and that the schema is up-to-date. + Reload and rebuild the Pydantic model validation schema. + + Forces reconstruction of the model schema and optionally rebuilds + schemas for all dependent models in the reloadable dependency chains. + Essential when new types are registered that affect polymorphic validation. - This method is useful when the registry has been modified or when the - class needs to be re-validated with the latest schema. + :param dependencies: Whether to reload dependent model schemas as well """ cls.model_rebuild(force=True) + if dependencies: + for chain in cls.reloadable_dependency_chains(): + for clazz in chain: + clazz.model_rebuild(force=True) + + @classmethod + def reloadable_dependency_chains( + cls, target: type[ReloadableBaseModel] | None = None + ) -> list[list[type[ReloadableBaseModel]]]: + """ + Find all dependency chains leading to the target model class. + + Uses depth-first search to identify dependency paths between reloadable + models, ensuring proper schema reload ordering to maintain validation + consistency across the polymorphic model hierarchy. + + :param target: Target model class to find chains for. Uses cls if None + :return: List of dependency chains ending at the target class + """ + if target is None: + target = cls + + # Build a map of all reloadable classes to their dependencies + dependencies: dict[ + type[ReloadableBaseModel], list[type[ReloadableBaseModel]] + ] = {} + + for reloadable in cls.reloadable_descendants(ReloadableBaseModel): + deps = [] + for field_deps in reloadable.reloadable_fields().values(): + deps.extend(field_deps) + dependencies[reloadable] = deps + + # Find all dependency chains ending at target using DFS + chains = [] + + def _find_chains( + current: type[ReloadableBaseModel], path: list[type[ReloadableBaseModel]] + ): + if current == target: + chains.append(path) + return + + for dependent in dependencies.get(current, []): + if dependent not in path: # Avoid cycles + _find_chains(dependent, [current] + path) + + for cls_type, deps in dependencies.items(): + if deps and cls_type != target: + _find_chains(cls_type, []) + + return chains + + @classmethod + def reloadable_fields( + cls, + ) -> dict[str, list[type[ReloadableBaseModel]]]: + """ + Identify model fields containing reloadable model types. + + Recursively analyzes field type annotations to find all ReloadableBaseModel + subclasses used within the model schema, enabling dependency tracking for + proper schema reload ordering. + + :return: Mapping of field names to lists of reloadable model types + """ + + def _recursive_resolve_reloadable_types(type_: type | None) -> list[type]: + if type_ is None: + return [] + + if (origin := get_origin(type_)) is None: + return [type_] if issubclass(type_, ReloadableBaseModel) else [] + + resolved = [] + if issubclass(origin, ReloadableBaseModel): + resolved.append(origin) + + for arg in get_args(type_): + resolved.extend(_recursive_resolve_reloadable_types(arg)) + + return resolved + + fields = {} -class PydanticClassRegistryMixin(ReloadableBaseModel, ABC, ClassRegistryMixin): + for name, info in cls.model_fields.items(): + if reloadable_types := _recursive_resolve_reloadable_types(info.annotation): + fields[name] = reloadable_types + + return fields + + @classmethod + def reloadable_descendants( + cls, target: type[ReloadableBaseModel] | None = None + ) -> set[type[ReloadableBaseModel]]: + """ + Find all ReloadableBaseModel descendants of the target class. + + Traverses the inheritance hierarchy to collect all subclasses that inherit + from ReloadableBaseModel, enabling comprehensive dependency analysis for + schema reloading operations. + + :param target: Base class to find descendants for. Uses cls if None + :return: Set of all descendant ReloadableBaseModel classes + """ + if target is None: + target = cls + + descendants: set[type[ReloadableBaseModel]] = set() + stack: list[type[ReloadableBaseModel]] = [target] + + while stack: + current = stack.pop() + for subclass in current.__subclasses__(): + if ( + issubclass(subclass, ReloadableBaseModel) + and subclass is not cls + and subclass not in descendants + ): + descendants.add(subclass) + stack.append(subclass) + + return descendants + + +class PydanticClassRegistryMixin( + ReloadableBaseModel, RegistryMixin[type[BaseModelT]], ABC, Generic[BaseModelT] +): """ - A mixin class that integrates Pydantic models with the ClassRegistryMixin to enable - polymorphic serialization and deserialization based on a discriminator field. - - This mixin allows Pydantic models to be registered in a registry and dynamically - instantiated based on a discriminator field in the input data. - It overrides Pydantic's validation system to correctly instantiate the appropriate - subclass based on the discriminator value and the name of the registered classes. - - The mixin is particularly useful for implementing base registry classes that need to - support multiple implementations, such as different token proposal methods or - speculative decoding algorithms. - - Usage Example: - ```python - from typing import ClassVar - from pydantic import BaseModel, Field - from speculators.utils import PydanticClassRegistryMixin - - class BaseConfig(PydanticClassRegistryMixin): - @classmethod - def __pydantic_schema_base_type__(cls) -> type["BaseConfig"]: - if cls.__name__ == "BaseConfig": - return cls - return BaseConfig - - schema_discriminator: ClassVar[str] = "config_type" - config_type: str = Field(description="The type of configuration") - - @BaseConfig.register("config_a") - class ConfigA(BaseConfig): - config_type: str = "config_a" - value_a: str = Field(description="A value specific to ConfigA") - - @BaseConfig.register("config_b") - class ConfigB(BaseConfig): - config_type: str = "config_b" - value_b: int = Field(description="A value specific to ConfigB") - - BaseConfig.reload_schema() # Ensures the schema is up-to-date with registry - - # Dynamic instantiation based on config_type - config_data = {"config_type": "config_a", "value_a": "test"} - config = BaseConfig.model_validate(config_data) # Returns ConfigA instance - print(config) - dump_data = config.model_dump() # Dumps the data to a dictionary - print(dump_data) # Output: {'config_type': 'config_a', 'value_a': 'test'} - ``` - - :cvar schema_discriminator: The field name used as the discriminator in the JSON - schema. Default is "model_type". - :cvar registry: A dictionary mapping discriminator values to pydantic model classes. + Polymorphic Pydantic model enabling registry-based dynamic type instantiation. + + Integrates Pydantic validation with the registry system for polymorphic + serialization and deserialization using a discriminator field. Automatically + instantiates the correct subclass during validation based on registry mappings. + + Example: + :: + from speculators.utils import PydanticClassRegistryMixin + + class BaseConfig(PydanticClassRegistryMixin["BaseConfig"]): + schema_discriminator: ClassVar[str] = "config_type" + config_type: str = Field(description="Configuration type identifier") + + @classmethod + def __pydantic_schema_base_name__(cls) -> str: + return "BaseConfig" + + @BaseConfig.register("database") + class DatabaseConfig(BaseConfig): + config_type: str = "database" + connection_string: str = Field(description="Database connection string") + + # Dynamic instantiation based on discriminator + config = BaseConfig.model_validate({ + "config_type": "database", + "connection_string": "postgresql://localhost:5432/db" + }) + + :cvar schema_discriminator: Field name for polymorphic type discrimination """ schema_discriminator: ClassVar[str] = "model_type" - registry: ClassVar[Optional[dict[str, BaseModel]]] = None # type: ignore[assignment] @classmethod - def register_decorator( - cls, clazz: type[BaseModel], name: Optional[str] = None - ) -> type[BaseModel]: + def register_decorator( # type: ignore[override] + cls, clazz: RegisterClassT, name: str | list[str] | None = None + ) -> RegisterClassT: """ - Registers a Pydantic model class with the registry. + Register a Pydantic model class with type validation and schema reload. - This method extends the ClassRegistryMixin.register_decorator method by adding - a type check to ensure only Pydantic BaseModel subclasses can be registered. + Validates that the class is a proper Pydantic BaseModel subclass before + registering it in the class registry. Automatically triggers schema + reload to incorporate the new type into polymorphic validation. - :param clazz: The Pydantic model class to register - :param name: Optional name to register the class under. If None, the class name - is used as the registry key. - :return: The registered class. - :raises TypeError: If clazz is not a subclass of Pydantic BaseModel + :param clazz: Pydantic model class to register in the polymorphic hierarchy + :param name: Registry identifier for the class. Uses class name if None + :return: The registered class unchanged for decorator chaining + :raises TypeError: If clazz is not a Pydantic BaseModel subclass """ if not issubclass(clazz, BaseModel): raise TypeError( @@ -115,58 +238,55 @@ def register_decorator( "Pydantic BaseModel" ) - return super().register_decorator(clazz, name=name) + super().register_decorator(clazz, name=name) + cls.reload_schema() + + return clazz @classmethod def __get_pydantic_core_schema__( cls, source_type: Any, handler: GetCoreSchemaHandler ) -> CoreSchema: """ - Customizes the Pydantic schema for polymorphic model validation. + Generate polymorphic validation schema for dynamic type instantiation. - This method is part of Pydantic's validation system and is called during - schema generation. It checks if the source_type matches the base type of the - polymorphic model. If it does, it generates a tagged union schema that allows - for dynamic instantiation of the appropriate subclass based on the discriminator - field. + Creates a tagged union schema that enables Pydantic to automatically + instantiate the correct subclass based on the discriminator field value. + Falls back to base schema generation when no registry is available. - :param source_type: The type for which the schema is being generated - :param handler: Handler for generating core schema - :return: A CoreSchema object with the custom validator if appropriate + :param source_type: Type being processed for schema generation + :param handler: Pydantic core schema generation handler + :return: Tagged union schema for polymorphic validation or base schema """ - if source_type == cls.__pydantic_schema_base_type__(): - if not cls.registry: - return cls.__pydantic_generate_base_schema__(handler) + if ( + source_type is None + or not isinstance(source_type, type) + or source_type.__name__ != cls.__pydantic_schema_base_name__() + ): + return handler(cls) - choices = { - name: handler(model_class) for name, model_class in cls.registry.items() - } + if not cls.registry: + return cls.__pydantic_generate_base_schema__(handler) - return core_schema.tagged_union_schema( - choices=choices, - discriminator=cls.schema_discriminator, - ) + choices = { + name: handler(model_class) for name, model_class in cls.registry.items() + } - return handler(cls) + return core_schema.tagged_union_schema( + choices=choices, + discriminator=cls.schema_discriminator, + ) @classmethod @abstractmethod - def __pydantic_schema_base_type__(cls) -> type[Any]: + def __pydantic_schema_base_name__(cls) -> str: """ - Abstract method that must be implemented by subclasses to define the base type. + Define the name of the base type for polymorphic validation hierarchy. - This method should return the base class type that serves as the root of the - polymorphic hierarchy. The returned type is used to determine when to apply - the custom validation logic for polymorphic instantiation. + Must be implemented by subclasses to specify which type serves as the + root of the polymorphic hierarchy for schema generation and validation. - Example implementation: - ```python - @classmethod - def __pydantic_schema_base_type__(cls) -> type["MyBaseClass"]: - return MyBaseClass - ``` - - :return: The base class type for polymorphic validation + :return: Base class name for the polymorphic model hierarchy """ ... @@ -175,37 +295,52 @@ def __pydantic_generate_base_schema__( cls, handler: GetCoreSchemaHandler ) -> CoreSchema: """ - Generates the base schema for the polymorphic model. - - This method is used by the Pydantic validation system to create the core - schema for the base class. By default, it returns an any_schema which accepts - any valid input, relying on the validator function to perform the actual - validation and model instantiation. + Generate fallback schema for polymorphic models without registry. - Subclasses can override this method to provide a more specific base schema - if needed. + Provides a base schema that accepts any valid input when no registry + is available for polymorphic validation. Used as fallback during + schema generation when the registry has not been populated. - :param handler: Handler for generating core schema - :return: A CoreSchema object representing the base schema + :param handler: Pydantic core schema generation handler + :return: Base CoreSchema that accepts any valid input """ return core_schema.any_schema() @classmethod def auto_populate_registry(cls) -> bool: """ - Ensures that all registered classes in the registry are properly initialized. + Initialize registry with auto-discovery and reload validation schema. - This method is called automatically by Pydantic when the model is instantiated - or validated. It ensures that all classes in the registry are loaded and ready - for use. + Triggers automatic population of the class registry through the parent + RegistryMixin functionality and ensures the Pydantic validation schema + is updated to include all discovered types for polymorphic validation. - This is particularly useful for ensuring that all subclasses are registered - before any validation occurs. - - :return: True if the registry was populated, False if it was already populated - :raises ValueError: If called when registry_auto_discovery is False + :return: True if registry was populated, False if already populated + :raises ValueError: If called when registry_auto_discovery is disabled """ populated = super().auto_populate_registry() cls.reload_schema() return populated + + @classmethod + def registered_classes(cls) -> tuple[type[BaseModelT], ...]: + """ + Get all registered pydantic classes from the registry. + + Automatically triggers auto-discovery if registry_auto_discovery is enabled + to ensure all available implementations are included. + + :return: Tuple of all registered classes including auto-discovered ones + :raises ValueError: If called before any objects have been registered + """ + if cls.registry_auto_discovery: + cls.auto_populate_registry() + + if cls.registry is None: + raise ValueError( + "ClassRegistryMixin.registered_classes() must be called after " + "registering classes with ClassRegistryMixin.register()." + ) + + return tuple(cls.registry.values()) diff --git a/src/speculators/utils/registry.py b/src/speculators/utils/registry.py index 21994b91..05448591 100644 --- a/src/speculators/utils/registry.py +++ b/src/speculators/utils/registry.py @@ -1,189 +1,149 @@ """ -Registry system for classes in the Speculators library. +Registry system for dynamic object registration and discovery. -This module provides a flexible class registration and discovery system used -throughout the Speculators library. It enables automatic registration of classes -and discovery of implementations through class decorators and module imports. - -The registry system is used to track different implementations of token proposal -methods, speculative decoding algorithms, and speculator models, allowing for -dynamic discovery and instantiation based on configuration parameters. - -Classes: - ClassRegistryMixin: Base mixin for creating class registries with decorators - and optional auto-discovery capabilities through registry_auto_discovery flag. - AutoClassRegistryMixin: A backward-compatible version of ClassRegistryMixin with - auto-discovery enabled by default +Provides a flexible object registration system with optional auto-discovery +capabilities through decorators and module imports. Enables dynamic discovery +and instantiation of implementations based on configuration parameters, supporting +both manual registration and automatic package-based discovery for extensible +plugin architectures. """ -from typing import Any, Callable, ClassVar, Optional +from __future__ import annotations + +from typing import Callable, ClassVar, Generic, TypeVar, Union, cast from speculators.utils.auto_importer import AutoImporterMixin -__all__ = ["ClassRegistryMixin"] +__all__ = ["RegisterT", "RegistryMixin", "RegistryObjT"] -class ClassRegistryMixin(AutoImporterMixin): - """ - A mixin class that provides a registration system for tracking class - implementations with optional auto-discovery capabilities. +RegistryObjT = TypeVar("RegistryObjT") +"""Generic type variable for objects managed by the registry system.""" +RegisterT = TypeVar("RegisterT") +"""Generic type variable for the args and return values within the registry.""" - This mixin allows classes to maintain a registry of subclasses that can be - dynamically discovered and instantiated. Classes that inherit from this mixin - can use the @register decorator to add themselves to the registry. - The registry is class-specific, meaning each class that inherits from this mixin - will have its own separate registry of implementations. +class RegistryMixin(Generic[RegistryObjT], AutoImporterMixin): + """ + Generic mixin for creating object registries with optional auto-discovery. - The mixin can also be configured to automatically discover and register classes - from specified packages by setting registry_auto_discovery=True and defining - an auto_package class variable to specify which package(s) should be automatically - imported to discover implementations. + Enables classes to maintain separate registries of objects that can be dynamically + discovered and instantiated through decorators and module imports. Supports both + manual registration via decorators and automatic discovery through package scanning + for extensible plugin architectures. Example: - ```python - class BaseAlgorithm(ClassRegistryMixin): - pass + :: + class BaseAlgorithm(RegistryMixin): + pass - @BaseAlgorithm.register() - class ConcreteAlgorithm(BaseAlgorithm): - pass + @BaseAlgorithm.register() + class ConcreteAlgorithm(BaseAlgorithm): + pass - @BaseAlgorithm.register("custom_name") - class AnotherAlgorithm(BaseAlgorithm): - pass + @BaseAlgorithm.register("custom_name") + class AnotherAlgorithm(BaseAlgorithm): + pass - # Get all registered algorithm implementations - algorithms = BaseAlgorithm.registered_classes() - ``` + # Get all registered implementations + algorithms = BaseAlgorithm.registered_objects() Example with auto-discovery: - ```python - class TokenProposal(ClassRegistryMixin): - registry_auto_discovery = True - auto_package = "speculators.proposals" - - # This will automatically import all modules in the proposals package - # and register any classes decorated with @TokenProposal.register() - proposals = TokenProposal.registered_classes() - ``` - - :cvar registry: A dictionary mapping class names to classes that have been - registered to the extending subclass through the @subclass.register() decorator - :cvar registry_auto_discovery: A flag that enables automatic discovery and import of - modules from the auto_package when set to True. Default is False. - :cvar registry_populated: A flag that tracks whether the registry has been - populated with classes from the specified package(s). + :: + class TokenProposal(RegistryMixin): + registry_auto_discovery = True + auto_package = "mypackage.proposals" + + # Automatically imports and registers decorated objects + proposals = TokenProposal.registered_objects() + + :cvar registry: Dictionary mapping names to registered objects + :cvar registry_auto_discovery: Enable automatic package-based discovery + :cvar registry_populated: Track whether auto-discovery has completed """ - registry: ClassVar[Optional[dict[str, type[Any]]]] = None + registry: ClassVar[Union[dict[str, RegistryObjT], None]] = None # type: ignore[misc] # noqa: UP007 registry_auto_discovery: ClassVar[bool] = False registry_populated: ClassVar[bool] = False @classmethod - def register(cls, name: Optional[str] = None) -> Callable[[type[Any]], type[Any]]: + def register( + cls, name: str | list[str] | None = None + ) -> Callable[[RegisterT], RegisterT]: """ - An invoked class decorator that registers that class with the registry under - either the provided name or the class name if no name is provided. - - Example: - ```python - @ClassRegistryMixin.register() - class ExampleClass: - ... - - @ClassRegistryMixin.register("custom_name") - class AnotherExampleClass: - ... - ``` - - :param name: Optional name to register the class under. If None, the class name - is used as the registry key. - :return: A decorator function that registers the decorated class. - :raises ValueError: If name is provided but is not a string. + Decorator for registering objects with the registry. + + :param name: Optional name(s) to register the object under. + If None, uses the object's __name__ attribute + :return: Decorator function that registers the decorated object + :raises ValueError: If name is not a string, list of strings, or None """ - if name is not None and not isinstance(name, str): - raise ValueError( - "ClassRegistryMixin.register() name must be a string or None. " - f"Got {name}." - ) - return lambda subclass: cls.register_decorator(subclass, name=name) + def _decorator(obj: RegisterT) -> RegisterT: + cls.register_decorator(obj, name=name) + return obj + + return _decorator @classmethod def register_decorator( - cls, clazz: type[Any], name: Optional[str] = None - ) -> type[Any]: - """ - A non-invoked class decorator that registers the class with the registry. - If passed through a lambda, then name can be passed in as well. - Otherwise, the only argument is the decorated class. - - Example: - ```python - @ClassRegistryMixin.register_decorator - class ExampleClass: - ... - ``` - - :param clazz: The class to register - :param name: Optional name to register the class under. If None, the class name - is used as the registry key. - :return: The registered class. - :raises TypeError: If the decorator is used incorrectly or if the class is not - a type. - :raises ValueError: If the class is already registered or if name is provided - but is not a string. + cls, obj: RegisterT, name: str | list[str] | None = None + ) -> RegisterT: """ + Register an object directly with the registry. - if not isinstance(clazz, type): - raise TypeError( - "ClassRegistryMixin.register_decorator must be used as a class " - "decorator and without invocation." - f"Got improper clazz arg {clazz}." - ) + :param obj: The object to register + :param name: Optional name(s) to register the object under. + If None, uses the object's __name__ attribute + :return: The registered object + :raises ValueError: If the object is already registered or name is invalid + """ - if not name: - name = clazz.__name__ - elif not isinstance(name, str): + if name is None: + name = obj.__name__ if hasattr(obj, "__name__") else str(obj) + elif not isinstance(name, (str, list)): raise ValueError( - "ClassRegistryMixin.register_decorator must be used as a class " - "decorator and without invocation. " - f"Got imporoper name arg {name}." + "RegistryMixin.register_decorator name must be a string or " + f"an iterable of strings. Got {name}." ) if cls.registry is None: cls.registry = {} - if name in cls.registry: - raise ValueError( - f"ClassRegistryMixin.register_decorator cannot register a class " - f"{clazz} with the name {name} because it is already registered." - ) + names = [name] if isinstance(name, str) else list(name) + + for register_name in names: + if not isinstance(register_name, str): + raise ValueError( + "RegistryMixin.register_decorator name must be a string or " + f"a list of strings. Got {register_name}." + ) + + if register_name in cls.registry: + raise ValueError( + f"RegistryMixin.register_decorator cannot register an object " + f"{obj} with the name {register_name} because it is already " + "registered." + ) - cls.registry[name] = clazz + cls.registry[register_name] = cast("RegistryObjT", obj) - return clazz + return obj @classmethod def auto_populate_registry(cls) -> bool: """ - Ensures that all modules in the specified auto_package are imported. + Import and register all modules from the auto_package. - This method is called automatically by registered_classes when - registry_auto_discovery==True to ensure that all available implementations are - discovered and registered before returning the list of registered classes. + Automatically called by registered_objects when registry_auto_discovery is True + to ensure all available implementations are discovered. - To enable auto-discovery: - 1. Set registry_auto_discovery = True on the class - 2. Define an auto_package class variable with the package path to import - - :return: True if the registry was populated, False if it was already populated. + :return: True if registry was populated, False if already populated :raises ValueError: If called when registry_auto_discovery is False """ if not cls.registry_auto_discovery: raise ValueError( - "ClassRegistryMixin.auto_populate_registry() cannot be called " + "RegistryMixin.auto_populate_registry() cannot be called " "because registry_auto_discovery is set to False. " "Set registry_auto_discovery to True to enable auto-discovery." ) @@ -197,26 +157,62 @@ def auto_populate_registry(cls) -> bool: return True @classmethod - def registered_classes(cls) -> tuple[type[Any], ...]: + def registered_objects(cls) -> tuple[RegistryObjT, ...]: """ - Returns a tuple of all classes that have been registered with this registry. + Get all registered objects from the registry. - If registry_auto_discovery is True, this method will first call - auto_populate_registry to ensure that all available implementations from - the specified auto_package are discovered and registered before returning - the list. + Automatically triggers auto-discovery if registry_auto_discovery is enabled + to ensure all available implementations are included. - :return: A tuple containing all registered class implementations, including - those discovered through auto-importing when registry_auto_discovery==True. - :raises ValueError: If called before any classes have been registered. + :return: Tuple of all registered objects including auto-discovered ones + :raises ValueError: If called before any objects have been registered """ if cls.registry_auto_discovery: cls.auto_populate_registry() if cls.registry is None: raise ValueError( - "ClassRegistryMixin.registered_classes() must be called after " - "registering classes with ClassRegistryMixin.register()." + "RegistryMixin.registered_objects() must be called after " + "registering objects with RegistryMixin.register()." ) return tuple(cls.registry.values()) + + @classmethod + def is_registered(cls, name: str) -> bool: + """ + Check if an object is registered under the given name. + It matches first by exact name, then by str.lower(). + + :param name: The name to check for registration. + :return: True if the object is registered, False otherwise. + """ + if cls.registry is None: + return False + + return name in cls.registry or name.lower() in [ + key.lower() for key in cls.registry + ] + + @classmethod + def get_registered_object(cls, name: str) -> RegistryObjT | None: + """ + Get a registered object by its name. It matches first by exact name, + then by str.lower(). + + :param name: The name of the registered object. + :return: The registered object if found, None otherwise. + """ + if cls.registry is None: + return None + + if name in cls.registry: + return cls.registry[name] + + lower_key_map = {key.lower(): key for key in cls.registry} + + return ( + cls.registry[lower_key_map[name.lower()]] + if name.lower() in lower_key_map + else None + ) diff --git a/tests/unit/models/test_eagle_config.py b/tests/unit/models/test_eagle_config.py index f89fd7ec..745dfc70 100644 --- a/tests/unit/models/test_eagle_config.py +++ b/tests/unit/models/test_eagle_config.py @@ -324,9 +324,9 @@ def test_eagle_speculator_config_auto_registry(): assert "EagleSpeculatorConfig" in class_names # Verify registry key mapping - assert SpeculatorModelConfig.registry is not None - assert "eagle" in SpeculatorModelConfig.registry - assert SpeculatorModelConfig.registry["eagle"] == EagleSpeculatorConfig + assert SpeculatorModelConfig.registry is not None # type: ignore[misc] + assert "eagle" in SpeculatorModelConfig.registry # type: ignore[misc] + assert SpeculatorModelConfig.registry["eagle"] == EagleSpeculatorConfig # type: ignore[misc] @pytest.mark.smoke diff --git a/tests/unit/models/test_eagle_model.py b/tests/unit/models/test_eagle_model.py index 4a1f73d0..afde0440 100644 --- a/tests/unit/models/test_eagle_model.py +++ b/tests/unit/models/test_eagle_model.py @@ -215,9 +215,9 @@ def test_eagle_speculator_class_attributes(): @pytest.mark.smoke def test_eagle_speculator_registry(): - assert SpeculatorModel.registry is not None - assert "eagle" in SpeculatorModel.registry - assert SpeculatorModel.registry["eagle"] == EagleSpeculator + assert SpeculatorModel.registry is not None # type: ignore[misc] + assert "eagle" in SpeculatorModel.registry # type: ignore[misc] + assert SpeculatorModel.registry["eagle"] == EagleSpeculator # type: ignore[misc] @pytest.mark.smoke diff --git a/tests/unit/test_config.py b/tests/unit/test_config.py index dbe026e9..fd95da39 100644 --- a/tests/unit/test_config.py +++ b/tests/unit/test_config.py @@ -17,7 +17,6 @@ SpeculatorsConfig, TokenProposalConfig, VerifierConfig, - reload_and_populate_configs, ) # ===== TokenProposalConfig Tests ===== @@ -29,10 +28,6 @@ class TokenProposalConfigTest(TokenProposalConfig): test_field: int = 123 -# Ensure the schemas are reloaded to include the test proposal type -reload_and_populate_configs() - - @pytest.mark.smoke def test_token_proposal_config_initialization(): config: TokenProposalConfigTest = TokenProposalConfig( # type: ignore[assignment] @@ -239,10 +234,6 @@ class SpeculatorModelConfigTest(SpeculatorModelConfig): test_field: int = 456 -# Ensure the schemas are reloaded to include the test proposal type -reload_and_populate_configs() - - @pytest.fixture def sample_speculators_config(sample_token_proposal_config, sample_verifier_config): return SpeculatorsConfig( diff --git a/tests/unit/test_model.py b/tests/unit/test_model.py index 7fb8f441..a222cefc 100644 --- a/tests/unit/test_model.py +++ b/tests/unit/test_model.py @@ -17,8 +17,6 @@ SpeculatorModelConfig, SpeculatorsConfig, VerifierConfig, - reload_and_populate_configs, - reload_and_populate_models, ) from speculators.proposals import GreedyTokenProposalConfig @@ -58,11 +56,6 @@ def forward(self, *args, **kwargs): return {"logits": torch.randn(1, 10, 1000)} -# Reload registries to include test classes -reload_and_populate_configs() -reload_and_populate_models() - - @pytest.fixture def speculator_model_test_config(): return SpeculatorModelTestConfig( @@ -96,9 +89,9 @@ def test_speculator_model_class_attributes(): @pytest.mark.smoke def test_speculator_model_registry_contains_test_model(): - assert SpeculatorModel.registry is not None - assert "test_speculator" in SpeculatorModel.registry - assert SpeculatorModel.registry["test_speculator"] == SpeculatorTestModel + assert SpeculatorModel.registry is not None # type: ignore[misc] + assert "test_speculator" in SpeculatorModel.registry # type: ignore[misc] + assert SpeculatorModel.registry["test_speculator"] == SpeculatorTestModel # type: ignore[misc] @pytest.mark.smoke diff --git a/tests/unit/utils/test_auto_importer.py b/tests/unit/utils/test_auto_importer.py index 77640a8d..dcae83c3 100644 --- a/tests/unit/utils/test_auto_importer.py +++ b/tests/unit/utils/test_auto_importer.py @@ -1,196 +1,266 @@ """ -Unit tests for the auto_importer module in the Speculators library. +Unit tests for the auto_importer module. """ +from __future__ import annotations + from unittest import mock import pytest -from speculators.utils.auto_importer import AutoImporterMixin - -# ===== Basic Functionality Tests ===== - - -@pytest.mark.smoke -def test_auto_importer_initialization(): - class TestAutoImporterClass(AutoImporterMixin): - auto_package = "test_package.modules" - - assert AutoImporterMixin.auto_package is None - assert AutoImporterMixin.auto_ignore_modules is None - assert AutoImporterMixin.auto_imported_modules is None - - -@pytest.mark.smoke -def test_auto_importer_subclass_attributes(): - class TestAutoImporterClass(AutoImporterMixin): - auto_package = "test_package.modules" - - assert TestAutoImporterClass.auto_package == "test_package.modules" - assert TestAutoImporterClass.auto_ignore_modules is None - assert TestAutoImporterClass.auto_imported_modules is None - - -@pytest.mark.smoke -def test_no_package_raises_error(): - class TestAutoImporterClass(AutoImporterMixin): ... - - with pytest.raises(ValueError) as exc_info: - TestAutoImporterClass.auto_import_package_modules() +from speculators.utils import AutoImporterMixin + + +class TestAutoImporterMixin: + """Test suite for AutoImporterMixin functionality.""" + + @pytest.fixture( + params=[ + { + "auto_package": "test.package", + "auto_ignore_modules": None, + "modules": [ + ("test.package.module1", False), + ("test.package.module2", False), + ], + "expected_imports": ["test.package.module1", "test.package.module2"], + }, + { + "auto_package": ("test.package1", "test.package2"), + "auto_ignore_modules": None, + "modules": [ + ("test.package1.moduleA", False), + ("test.package2.moduleB", False), + ], + "expected_imports": ["test.package1.moduleA", "test.package2.moduleB"], + }, + { + "auto_package": "test.package", + "auto_ignore_modules": ("test.package.module1",), + "modules": [ + ("test.package.module1", False), + ("test.package.module2", False), + ], + "expected_imports": ["test.package.module2"], + }, + ], + ids=["single_package", "multiple_packages", "ignored_modules"], + ) + def valid_instances(self, request): + """Fixture providing test data for AutoImporterMixin subclasses.""" + config = request.param + + class TestClass(AutoImporterMixin): + auto_package = config["auto_package"] + auto_ignore_modules = config["auto_ignore_modules"] + + return TestClass, config + + @pytest.mark.smoke + def test_class_signatures(self): + """Test AutoImporterMixin class signatures and attributes.""" + assert hasattr(AutoImporterMixin, "auto_package") + assert hasattr(AutoImporterMixin, "auto_ignore_modules") + assert hasattr(AutoImporterMixin, "auto_imported_modules") + assert hasattr(AutoImporterMixin, "auto_import_package_modules") + assert callable(AutoImporterMixin.auto_import_package_modules) + + # Test default class variables + assert AutoImporterMixin.auto_package is None + assert AutoImporterMixin.auto_ignore_modules is None + assert AutoImporterMixin.auto_imported_modules is None + + @pytest.mark.smoke + def test_initialization(self, valid_instances): + """Test AutoImporterMixin subclass initialization.""" + test_class, config = valid_instances + assert issubclass(test_class, AutoImporterMixin) + assert test_class.auto_package == config["auto_package"] + assert test_class.auto_ignore_modules == config["auto_ignore_modules"] + assert test_class.auto_imported_modules is None + + @pytest.mark.sanity + def test_invalid_initialization_missing(self): + """Test AutoImporterMixin with missing auto_package.""" + + class TestClass(AutoImporterMixin): + pass + + with pytest.raises(ValueError, match="auto_package.*must be set"): + TestClass.auto_import_package_modules() + + @pytest.mark.smoke + @mock.patch("importlib.import_module") + @mock.patch("pkgutil.walk_packages") + def test_auto_import_package_modules(self, mock_walk, mock_import, valid_instances): + """Test auto_import_package_modules core functionality.""" + test_class, config = valid_instances + + # Setup mocks based on config + packages = {} + modules = {} + + if isinstance(config["auto_package"], tuple): + for pkg in config["auto_package"]: + pkg_path = pkg.replace(".", "/") + packages[pkg] = MockHelper.create_mock_package(pkg, pkg_path) + else: + pkg = config["auto_package"] + packages[pkg] = MockHelper.create_mock_package(pkg, pkg.replace(".", "/")) + + for module_name, is_pkg in config["modules"]: + if not is_pkg: + modules[module_name] = MockHelper.create_mock_module(module_name) + + mock_import.side_effect = lambda name: {**packages, **modules}.get( + name, mock.MagicMock() + ) + + def walk_side_effect(path, prefix): + return [ + (None, module_name, is_pkg) + for module_name, is_pkg in config["modules"] + if module_name.startswith(prefix) + ] + + mock_walk.side_effect = walk_side_effect + + # Execute + test_class.auto_import_package_modules() + + # Verify + assert test_class.auto_imported_modules == config["expected_imports"] + + # Verify package imports + if isinstance(config["auto_package"], tuple): + for pkg in config["auto_package"]: + mock_import.assert_any_call(pkg) + else: + mock_import.assert_any_call(config["auto_package"]) + + # Verify expected module imports + for expected_module in config["expected_imports"]: + mock_import.assert_any_call(expected_module) + + @pytest.mark.sanity + @mock.patch("importlib.import_module") + @mock.patch("pkgutil.walk_packages") + def test_auto_import_package_modules_invalid(self, mock_walk, mock_import): + """Test auto_import_package_modules with invalid configurations.""" + + class TestClass(AutoImporterMixin): + auto_package = "test.package" + + # Test import error handling + mock_import.side_effect = ImportError("Module not found") + + with pytest.raises(ImportError): + TestClass.auto_import_package_modules() + + @pytest.mark.sanity + @mock.patch("importlib.import_module") + @mock.patch("pkgutil.walk_packages") + def test_skip_packages(self, mock_walk, mock_import): + """Test that packages (is_pkg=True) are skipped.""" + + class TestClass(AutoImporterMixin): + auto_package = "test.package" + + # Setup mocks + mock_package = MockHelper.create_mock_package("test.package", "test/package") + mock_module = MockHelper.create_mock_module("test.package.module") + + mock_import.side_effect = lambda name: { + "test.package": mock_package, + "test.package.module": mock_module, + }[name] + + mock_walk.return_value = [ + (None, "test.package.subpackage", True), + (None, "test.package.module", False), + ] - assert "auto_package" in str(exc_info.value) - assert "must be set" in str(exc_info.value) + # Execute + TestClass.auto_import_package_modules() + + # Verify + assert TestClass.auto_imported_modules == ["test.package.module"] + mock_import.assert_any_call("test.package.module") + # subpackage should not be imported + with pytest.raises(AssertionError): + mock_import.assert_any_call("test.package.subpackage") + + @pytest.mark.sanity + @mock.patch("sys.modules", {"test.package.existing": mock.MagicMock()}) + @mock.patch("importlib.import_module") + @mock.patch("pkgutil.walk_packages") + def test_skip_already_imported_modules(self, mock_walk, mock_import): + """Test that modules already in sys.modules are tracked but not re-imported.""" + + class TestClass(AutoImporterMixin): + auto_package = "test.package" + + # Setup mocks + mock_package = MockHelper.create_mock_package("test.package", "test/package") + mock_import.side_effect = lambda name: { + "test.package": mock_package, + }.get(name, mock.MagicMock()) + + mock_walk.return_value = [ + (None, "test.package.existing", False), + ] + # Execute + TestClass.auto_import_package_modules() -# ===== Module Import Tests ===== + # Verify + assert TestClass.auto_imported_modules == ["test.package.existing"] + @pytest.mark.sanity + @mock.patch("importlib.import_module") + @mock.patch("pkgutil.walk_packages") + def test_prevent_duplicate_module_imports(self, mock_walk, mock_import): + """Test that modules already in auto_imported_modules are not re-imported.""" -@pytest.mark.smoke -def test_single_package_import(): - class TestAutoImporterClass(AutoImporterMixin): - auto_package = "test_package.modules" + class TestClass(AutoImporterMixin): + auto_package = "test.package" - with ( - mock.patch("pkgutil.walk_packages") as mock_walk, - mock.patch("importlib.import_module") as mock_import, - ): - # Create a mock package with the necessary attributes - mock_package = mock.MagicMock() - mock_package.__path__ = ["test_package/modules"] - mock_package.__name__ = "test_package.modules" + # Setup mocks + mock_package = MockHelper.create_mock_package("test.package", "test/package") + mock_module = MockHelper.create_mock_module("test.package.module") - def import_module(name: str): - if name == "test_package.modules": - return mock_package - elif name == "test_package.modules.module1": - module = mock.MagicMock() - module.__name__ = "test_package.modules.module1" - return module - elif name == "test_package.modules.module2": - module = mock.MagicMock() - module.__name__ = "test_package.modules.module2" - return module - else: - raise ImportError(f"No module named {name}") + mock_import.side_effect = lambda name: { + "test.package": mock_package, + "test.package.module": mock_module, + }[name] - def walk_packages(package_path, package_name): - if package_name == "test_package.modules.": - return [ - (None, "test_package.modules.module1", False), - (None, "test_package.modules.module2", False), - ] - else: - raise ValueError(f"Unknown package: {package_name}") + mock_walk.return_value = [ + (None, "test.package.module", False), + (None, "test.package.module", False), + ] - mock_walk.side_effect = walk_packages - mock_import.side_effect = import_module - TestAutoImporterClass.auto_import_package_modules() + # Execute + TestClass.auto_import_package_modules() - mock_import.assert_any_call("test_package.modules") - assert TestAutoImporterClass.auto_imported_modules == [ - "test_package.modules.module1", - "test_package.modules.module2", - ] + # Verify + assert TestClass.auto_imported_modules == ["test.package.module"] + assert mock_import.call_count == 2 # Package + module (not duplicate) -@pytest.mark.sanity -def test_multiple_package_import(): - class TestAutoImporterClass(AutoImporterMixin): - auto_package = ("test_package.modules1", "test_package.modules2") - - with ( - mock.patch("pkgutil.walk_packages") as mock_walk, - mock.patch("importlib.import_module") as mock_import, - ): - # Create a mock package with the necessary attributes - mock_package1 = mock.MagicMock() - mock_package1.__path__ = ["test_package/modules1"] - mock_package1.__name__ = "test_package.modules1" - - mock_package2 = mock.MagicMock() - mock_package2.__path__ = ["test_package/modules2"] - mock_package2.__name__ = "test_package.modules2" - - def import_module(name: str): - if name == "test_package.modules1": - return mock_package1 - elif name == "test_package.modules2": - return mock_package2 - elif name == "test_package.modules1.moduleA": - module = mock.MagicMock() - module.__name__ = "test_package.modules1.moduleA" - return module - elif name == "test_package.modules2.moduleB": - module = mock.MagicMock() - module.__name__ = "test_package.modules2.moduleB" - return module - else: - raise ImportError(f"No module named {name}") - - def walk_packages(package_path, package_name): - if package_name == "test_package.modules1.": - return [ - (None, "test_package.modules1.moduleA", False), - ] - elif package_name == "test_package.modules2.": - return [ - (None, "test_package.modules2.moduleB", False), - ] - else: - raise ValueError(f"Unknown package: {package_name}") - - mock_walk.side_effect = walk_packages - mock_import.side_effect = import_module - TestAutoImporterClass.auto_import_package_modules() - - assert TestAutoImporterClass.auto_imported_modules == [ - "test_package.modules1.moduleA", - "test_package.modules2.moduleB", - ] +class MockHelper: + """Helper class to create consistent mock objects for testing.""" + @staticmethod + def create_mock_package(name: str, path: str): + """Create a mock package with required attributes.""" + package = mock.MagicMock() + package.__name__ = name + package.__path__ = [path] + return package -@pytest.mark.sanity -def test_ignore_modules(): - class TestAutoImporterClass(AutoImporterMixin): - auto_package = "test_package.modules" - auto_ignore_modules = ("test_package.modules.module1",) - - with ( - mock.patch("pkgutil.walk_packages") as mock_walk, - mock.patch("importlib.import_module") as mock_import, - ): - # Create a mock package with the necessary attributes - mock_package = mock.MagicMock() - mock_package.__path__ = ["test_package/modules"] - mock_package.__name__ = "test_package.modules" - - def import_module(name: str): - if name == "test_package.modules": - return mock_package - elif name == "test_package.modules.module1": - module = mock.MagicMock() - module.__name__ = "test_package.modules.module1" - return module - elif name == "test_package.modules.module2": - module = mock.MagicMock() - module.__name__ = "test_package.modules.module2" - return module - else: - raise ImportError(f"No module named {name}") - - def walk_packages(package_path, package_name): - if package_name == "test_package.modules.": - return [ - (None, "test_package.modules.module1", False), - (None, "test_package.modules.module2", False), - ] - else: - raise ValueError(f"Unknown package: {package_name}") - - mock_walk.side_effect = walk_packages - mock_import.side_effect = import_module - TestAutoImporterClass.auto_import_package_modules() - - assert TestAutoImporterClass.auto_imported_modules == [ - "test_package.modules.module2", - ] + @staticmethod + def create_mock_module(name: str): + """Create a mock module with required attributes.""" + module = mock.MagicMock() + module.__name__ = name + return module diff --git a/tests/unit/utils/test_pydantic_utils.py b/tests/unit/utils/test_pydantic_utils.py index de05dadc..2c8a6e57 100644 --- a/tests/unit/utils/test_pydantic_utils.py +++ b/tests/unit/utils/test_pydantic_utils.py @@ -1,245 +1,626 @@ """ -Unit tests for the pydantic_utils module in the Speculators library. +Unit tests for the pydantic_utils module. """ -from typing import ClassVar +from __future__ import annotations + +from typing import ClassVar, TypeVar from unittest import mock import pytest -from pydantic import BaseModel +from pydantic import BaseModel, Field, ValidationError from speculators.utils import PydanticClassRegistryMixin, ReloadableBaseModel - -# ===== ReloadableBaseModel Tests ===== - - -@pytest.mark.smoke -def test_reloadable_base_model_initialization(): - class TestModel(ReloadableBaseModel): - name: str - - model = TestModel(name="test") - assert model.name == "test" - - -@pytest.mark.smoke -def test_reloadable_base_model_reload_schema(): - class TestModel(ReloadableBaseModel): - name: str - - model = TestModel(name="test") - assert model.name == "test" - - # Mock the model_rebuild method to simulate schema reload - with mock.patch.object(TestModel, "model_rebuild") as mock_rebuild: - TestModel.reload_schema() - mock_rebuild.assert_called_once() - - -# ===== PydanticClassRegistryMixin Tests ===== - - -@pytest.mark.smoke -def test_pydantic_class_registry_subclass_init(): - class TestBaseModel(PydanticClassRegistryMixin): - schema_discriminator: ClassVar[str] = "test_type" - test_type: str - - @classmethod - def __pydantic_schema_base_type__(cls) -> type["TestBaseModel"]: - return cls - - assert TestBaseModel.registry is None - assert TestBaseModel.schema_discriminator == "test_type" - - -@pytest.mark.smoke -def test_pydantic_class_registry_subclass_missing_base_type(): - class InvalidBaseModel(PydanticClassRegistryMixin): - schema_discriminator: ClassVar[str] = "test_type" - test_type: str - - with pytest.raises(TypeError): - InvalidBaseModel(test_type="test") # type: ignore[abstract] - - -@pytest.mark.sanity -def test_pydantic_class_registry_decorator(): - class TestBaseModel(PydanticClassRegistryMixin): - schema_discriminator: ClassVar[str] = "test_type" - test_type: str - - @classmethod - def __pydantic_schema_base_type__(cls) -> type["TestBaseModel"]: - if cls.__name__ == "TestBaseModel": - return cls - return TestBaseModel - - @TestBaseModel.register() - class TestSubModel(TestBaseModel): - test_type: str = "TestSubModel" - value: str - - assert TestBaseModel.registry is not None - assert "TestSubModel" in TestBaseModel.registry - assert TestBaseModel.registry["TestSubModel"] is TestSubModel - - -@pytest.mark.sanity -def test_pydantic_class_registry_decorator_with_name(): - class TestBaseModel(PydanticClassRegistryMixin): - schema_discriminator: ClassVar[str] = "test_type" - test_type: str - - @classmethod - def __pydantic_schema_base_type__(cls) -> type["TestBaseModel"]: - if cls.__name__ == "TestBaseModel": - return cls - return TestBaseModel - - @TestBaseModel.register("custom_name") - class TestSubModel(TestBaseModel): - test_type: str = "custom_name" - value: str - - assert TestBaseModel.registry is not None - assert "custom_name" in TestBaseModel.registry - assert TestBaseModel.registry["custom_name"] is TestSubModel - - -@pytest.mark.smoke -def test_pydantic_class_registry_decorator_invalid_type(): - class TestBaseModel(PydanticClassRegistryMixin): - schema_discriminator: ClassVar[str] = "test_type" - test_type: str - - @classmethod - def __pydantic_schema_base_type__(cls) -> type["TestBaseModel"]: - if cls.__name__ == "TestBaseModel": - return cls - return TestBaseModel - - class RegularClass: - pass - - with pytest.raises(TypeError) as exc_info: - TestBaseModel.register_decorator(RegularClass) # type: ignore[arg-type] - - assert "not a subclass of Pydantic BaseModel" in str(exc_info.value) +from speculators.utils.pydantic_utils import BaseModelT, RegisterClassT @pytest.mark.smoke -def test_pydantic_class_registry_subclass_marshalling(): - class TestBaseModel(PydanticClassRegistryMixin): - schema_discriminator: ClassVar[str] = "test_type" - test_type: str - - @classmethod - def __pydantic_schema_base_type__(cls) -> type["TestBaseModel"]: - if cls.__name__ == "TestBaseModel": - return cls - return TestBaseModel - - @TestBaseModel.register("test_sub") - class TestSubModel(TestBaseModel): - test_type: str = "test_sub" - value: str - - TestBaseModel.reload_schema() - - # Test direct construction of subclass - sub_instance = TestSubModel(value="test_value") - assert isinstance(sub_instance, TestSubModel) - assert sub_instance.test_type == "test_sub" - assert sub_instance.value == "test_value" - - # Test serialization with model_dump - dump_data = sub_instance.model_dump() - assert isinstance(dump_data, dict) - assert dump_data["test_type"] == "test_sub" - assert dump_data["value"] == "test_value" - - # Test deserialization via model_validate - recreated = TestSubModel.model_validate(dump_data) - assert isinstance(recreated, TestSubModel) - assert recreated.test_type == "test_sub" - assert recreated.value == "test_value" - - # Test polymorphic deserialization via base class - recreated = TestBaseModel.model_validate(dump_data) # type: ignore[assignment] - assert isinstance(recreated, TestSubModel) - assert recreated.test_type == "test_sub" - assert recreated.value == "test_value" +def test_base_model_t(): + """Test that BaseModelT is configured correctly as a TypeVar.""" + assert isinstance(BaseModelT, type(TypeVar("test"))) + assert BaseModelT.__name__ == "BaseModelT" + assert BaseModelT.__bound__ is BaseModel + assert BaseModelT.__constraints__ == () @pytest.mark.smoke -def test_pydantic_class_registry_parent_class_marshalling(): - class TestBaseModel(PydanticClassRegistryMixin): - schema_discriminator: ClassVar[str] = "test_type" - test_type: str - - @classmethod - def __pydantic_schema_base_type__(cls) -> type["TestBaseModel"]: - if cls.__name__ == "TestBaseModel": - return cls - return TestBaseModel - - @classmethod - def __pydantic_generate_base_schema__(cls, handler): - return handler(cls) - - @TestBaseModel.register("sub_a") - class TestSubModelA(TestBaseModel): - test_type: str = "sub_a" - value_a: str - - @TestBaseModel.register("sub_b") - class TestSubModelB(TestBaseModel): - test_type: str = "sub_b" - value_b: int - - class ContainerModel(BaseModel): - name: str - model: TestBaseModel - models: list[TestBaseModel] - - sub_a = TestSubModelA(value_a="test") - sub_b = TestSubModelB(value_b=123) - - container = ContainerModel(name="container", model=sub_a, models=[sub_a, sub_b]) - assert isinstance(container.model, TestSubModelA) - assert container.model.test_type == "sub_a" - assert container.model.value_a == "test" - assert isinstance(container.models[0], TestSubModelA) - assert isinstance(container.models[1], TestSubModelB) - assert container.models[0].test_type == "sub_a" - assert container.models[1].test_type == "sub_b" - assert container.models[0].value_a == "test" - assert container.models[1].value_b == 123 - - # Test serialization with model_dump - dump_data = container.model_dump() - assert isinstance(dump_data, dict) - assert dump_data["name"] == "container" - assert dump_data["model"]["test_type"] == "sub_a" - assert dump_data["model"]["value_a"] == "test" - assert len(dump_data["models"]) == 2 - assert dump_data["models"][0]["test_type"] == "sub_a" - assert dump_data["models"][0]["value_a"] == "test" - assert dump_data["models"][1]["test_type"] == "sub_b" - assert dump_data["models"][1]["value_b"] == 123 - - # Test deserialization via model_validate - recreated = ContainerModel.model_validate(dump_data) - assert isinstance(recreated, ContainerModel) - assert recreated.name == "container" - assert isinstance(recreated.model, TestSubModelA) - assert recreated.model.test_type == "sub_a" - assert recreated.model.value_a == "test" - assert len(recreated.models) == 2 - assert isinstance(recreated.models[0], TestSubModelA) - assert isinstance(recreated.models[1], TestSubModelB) - assert recreated.models[0].test_type == "sub_a" - assert recreated.models[1].test_type == "sub_b" - assert recreated.models[0].value_a == "test" - assert recreated.models[1].value_b == 123 +def test_register_class_t(): + """Test that RegisterClassT is configured correctly as a TypeVar.""" + assert isinstance(RegisterClassT, type(TypeVar("test"))) + assert RegisterClassT.__name__ == "RegisterClassT" + assert RegisterClassT.__bound__ is not None + assert RegisterClassT.__constraints__ == () + + +class TestReloadableBaseModel: + """Test suite for ReloadableBaseModel.""" + + @pytest.fixture( + params=[ + {"name": "test_value"}, + {"name": "hello_world"}, + {"name": "another_test"}, + ], + ids=["basic_string", "multi_word", "underscore"], + ) + def valid_instances(self, request) -> tuple[ReloadableBaseModel, dict[str, str]]: + """Fixture providing test data for ReloadableBaseModel.""" + + class TestModel(ReloadableBaseModel): + name: str + + constructor_args = request.param + instance = TestModel(**constructor_args) + return instance, constructor_args + + @pytest.mark.smoke + def test_class_signatures(self): + """Test ReloadableBaseModel inheritance and class variables.""" + assert issubclass(ReloadableBaseModel, BaseModel) + assert hasattr(ReloadableBaseModel, "model_config") + assert hasattr(ReloadableBaseModel, "reload_schema") + + @pytest.mark.smoke + def test_initialization(self, valid_instances): + """Test ReloadableBaseModel initialization.""" + instance, constructor_args = valid_instances + assert isinstance(instance, ReloadableBaseModel) + assert instance.name == constructor_args["name"] # type: ignore[attr-defined] + + @pytest.mark.sanity + @pytest.mark.parametrize( + ("field", "value"), + [ + ("name", None), + ("name", 123), + ("name", []), + ], + ) + def test_invalid_initialization_values(self, field, value): + """Test ReloadableBaseModel with invalid field values.""" + + class TestModel(ReloadableBaseModel): + name: str + + data = {field: value} + with pytest.raises(ValidationError): + TestModel(**data) + + @pytest.mark.sanity + def test_invalid_initialization_missing(self): + """Test ReloadableBaseModel initialization without required field.""" + + class TestModel(ReloadableBaseModel): + name: str + + with pytest.raises(ValidationError): + TestModel() # type: ignore[call-arg] + + @pytest.mark.smoke + def test_reload_schema(self): + """Test ReloadableBaseModel.reload_schema method.""" + + class TestModel(ReloadableBaseModel): + name: str + + # Mock the model_rebuild method to simulate schema reload + with mock.patch.object(TestModel, "model_rebuild") as mock_rebuild: + TestModel.reload_schema() + mock_rebuild.assert_called_once_with(force=True) + + @pytest.mark.sanity + def test_marshalling(self, valid_instances): + """Test ReloadableBaseModel serialization and deserialization.""" + instance, constructor_args = valid_instances + data_dict = instance.model_dump() + assert isinstance(data_dict, dict) + assert data_dict["name"] == constructor_args["name"] + + recreated = instance.__class__.model_validate(data_dict) + assert isinstance(recreated, instance.__class__) + assert recreated.name == constructor_args["name"] + + +class TestPydanticClassRegistryMixin: + """Test suite for PydanticClassRegistryMixin.""" + + @pytest.fixture( + params=[ + {"test_type": "test_sub", "value": "test_value"}, + {"test_type": "test_sub", "value": "hello_world"}, + ], + ids=["basic_value", "multi_word"], + ) + def valid_instances( + self, request + ) -> tuple[PydanticClassRegistryMixin, dict, type, type]: + """Fixture providing test data for PydanticClassRegistryMixin.""" + + class TestBaseModel(PydanticClassRegistryMixin): + schema_discriminator: ClassVar[str] = "test_type" + test_type: str + + @classmethod + def __pydantic_schema_base_name__(cls) -> str: + return "TestBaseModel" + + @TestBaseModel.register("test_sub") + class TestSubModel(TestBaseModel): + test_type: str = "test_sub" + value: str + + TestBaseModel.reload_schema() + + constructor_args = request.param + instance = TestSubModel(value=constructor_args["value"]) + return instance, constructor_args, TestBaseModel, TestSubModel + + @pytest.mark.smoke + def test_class_signatures(self): + """Test PydanticClassRegistryMixin inheritance and class variables.""" + assert issubclass(PydanticClassRegistryMixin, ReloadableBaseModel) + assert hasattr(PydanticClassRegistryMixin, "schema_discriminator") + assert PydanticClassRegistryMixin.schema_discriminator == "model_type" + assert hasattr(PydanticClassRegistryMixin, "register_decorator") + assert hasattr(PydanticClassRegistryMixin, "__get_pydantic_core_schema__") + assert hasattr(PydanticClassRegistryMixin, "__pydantic_generate_base_schema__") + assert hasattr(PydanticClassRegistryMixin, "auto_populate_registry") + assert hasattr(PydanticClassRegistryMixin, "registered_classes") + + @pytest.mark.smoke + def test_initialization(self, valid_instances): + """Test PydanticClassRegistryMixin initialization.""" + instance, constructor_args, base_class, sub_class = valid_instances + assert isinstance(instance, sub_class) + assert isinstance(instance, base_class) + assert instance.test_type == constructor_args["test_type"] + assert instance.value == constructor_args["value"] + + @pytest.mark.sanity + @pytest.mark.parametrize( + ("field", "value"), + [ + ("test_type", None), + ("test_type", 123), + ("value", None), + ], + ) + def test_invalid_initialization_values(self, field, value): + """Test PydanticClassRegistryMixin with invalid field values.""" + + class TestBaseModel(PydanticClassRegistryMixin): + schema_discriminator: ClassVar[str] = "test_type" + test_type: str + + @classmethod + def __pydantic_schema_base_name__(cls) -> str: + return "TestBaseModel" + + @TestBaseModel.register("test_sub") + class TestSubModel(TestBaseModel): + test_type: str = "test_sub" + value: str + + data = {field: value} + if field == "test_type": + data["value"] = "test" + else: + data["test_type"] = "test_sub" + + with pytest.raises(ValidationError): + TestSubModel(**data) + + @pytest.mark.sanity + def test_invalid_initialization_missing(self): + """Test PydanticClassRegistryMixin initialization without required field.""" + + class TestBaseModel(PydanticClassRegistryMixin): + schema_discriminator: ClassVar[str] = "test_type" + test_type: str + + @classmethod + def __pydantic_schema_base_name__(cls) -> str: + return "TestBaseModel" + + @TestBaseModel.register("test_sub") + class TestSubModel(TestBaseModel): + test_type: str = "test_sub" + value: str + + with pytest.raises(ValidationError): + TestSubModel() # type: ignore[call-arg] + + @pytest.mark.smoke + def test_register_decorator(self): + """Test PydanticClassRegistryMixin.register_decorator method.""" + + class TestBaseModel(PydanticClassRegistryMixin): + schema_discriminator: ClassVar[str] = "test_type" + test_type: str + + @classmethod + def __pydantic_schema_base_name__(cls) -> str: + return "TestBaseModel" + + @TestBaseModel.register() + class TestSubModel(TestBaseModel): + test_type: str = "TestSubModel" + value: str + + assert TestBaseModel.registry is not None # type: ignore[misc] + assert "TestSubModel" in TestBaseModel.registry # type: ignore[misc] + assert TestBaseModel.registry["TestSubModel"] is TestSubModel # type: ignore[misc] + + @pytest.mark.sanity + def test_register_decorator_with_name(self): + """Test PydanticClassRegistryMixin.register_decorator with custom name.""" + + class TestBaseModel(PydanticClassRegistryMixin): + schema_discriminator: ClassVar[str] = "test_type" + test_type: str + + @classmethod + def __pydantic_schema_base_name__(cls) -> str: + return "TestBaseModel" + + @TestBaseModel.register("custom_name") + class TestSubModel(TestBaseModel): + test_type: str = "custom_name" + value: str + + assert TestBaseModel.registry is not None # type: ignore[misc] + assert "custom_name" in TestBaseModel.registry # type: ignore[misc] + assert TestBaseModel.registry["custom_name"] is TestSubModel # type: ignore[misc] + + @pytest.mark.sanity + def test_register_decorator_invalid_type(self): + """Test PydanticClassRegistryMixin.register_decorator with invalid type.""" + + class TestBaseModel(PydanticClassRegistryMixin): + schema_discriminator: ClassVar[str] = "test_type" + test_type: str + + @classmethod + def __pydantic_schema_base_name__(cls) -> str: + return "TestBaseModel" + + class RegularClass: + pass + + with pytest.raises(TypeError) as exc_info: + TestBaseModel.register_decorator(RegularClass) # type: ignore[type-var] + + assert "not a subclass of Pydantic BaseModel" in str(exc_info.value) + + @pytest.mark.smoke + def test_auto_populate_registry(self): + """Test PydanticClassRegistryMixin.auto_populate_registry method.""" + + class TestBaseModel(PydanticClassRegistryMixin): + schema_discriminator: ClassVar[str] = "test_type" + test_type: str + registry_auto_discovery: ClassVar[bool] = True + + @classmethod + def __pydantic_schema_base_name__(cls) -> str: + return "TestBaseModel" + + with ( + mock.patch.object(TestBaseModel, "reload_schema") as mock_reload, + mock.patch( + "speculators.utils.registry.RegistryMixin.auto_populate_registry", + return_value=True, + ), + ): + result = TestBaseModel.auto_populate_registry() + assert result is True + mock_reload.assert_called_once() + + @pytest.mark.smoke + def test_registered_classes(self): + """Test PydanticClassRegistryMixin.registered_classes method.""" + + class TestBaseModel(PydanticClassRegistryMixin): + schema_discriminator: ClassVar[str] = "test_type" + test_type: str + registry_auto_discovery: ClassVar[bool] = False + + @classmethod + def __pydantic_schema_base_name__(cls) -> str: + return "TestBaseModel" + + @TestBaseModel.register("test_sub_a") + class TestSubModelA(TestBaseModel): + test_type: str = "test_sub_a" + value_a: str + + @TestBaseModel.register("test_sub_b") + class TestSubModelB(TestBaseModel): + test_type: str = "test_sub_b" + value_b: int + + # Test normal case with registered classes + registered = TestBaseModel.registered_classes() + assert isinstance(registered, tuple) + assert len(registered) == 2 + assert TestSubModelA in registered + assert TestSubModelB in registered + + @pytest.mark.sanity + def test_registered_classes_with_auto_discovery(self): + """Test PydanticClassRegistryMixin.registered_classes with auto discovery.""" + + class TestBaseModel(PydanticClassRegistryMixin): + schema_discriminator: ClassVar[str] = "test_type" + test_type: str + registry_auto_discovery: ClassVar[bool] = True + + @classmethod + def __pydantic_schema_base_name__(cls) -> str: + return "TestBaseModel" + + with mock.patch.object( + TestBaseModel, "auto_populate_registry" + ) as mock_auto_populate: + # Mock the registry to simulate registered classes + TestBaseModel.registry = {"test_class": type("TestClass", (), {})} # type: ignore[misc] + mock_auto_populate.return_value = False + + registered = TestBaseModel.registered_classes() + mock_auto_populate.assert_called_once() + assert isinstance(registered, tuple) + assert len(registered) == 1 + + @pytest.mark.sanity + def test_registered_classes_no_registry(self): + """Test PydanticClassRegistryMixin.registered_classes with no registry.""" + + class TestBaseModel(PydanticClassRegistryMixin): + schema_discriminator: ClassVar[str] = "test_type" + test_type: str + + @classmethod + def __pydantic_schema_base_name__(cls) -> str: + return "TestBaseModel" + + # Ensure registry is None + TestBaseModel.registry = None # type: ignore[misc] + + with pytest.raises(ValueError) as exc_info: + TestBaseModel.registered_classes() + + assert "must be called after registering classes" in str(exc_info.value) + + @pytest.mark.sanity + def test_marshalling(self, valid_instances): + """Test PydanticClassRegistryMixin serialization and deserialization.""" + instance, constructor_args, base_class, sub_class = valid_instances + + # Test serialization with model_dump + dump_data = instance.model_dump() + assert isinstance(dump_data, dict) + assert dump_data["test_type"] == constructor_args["test_type"] + assert dump_data["value"] == constructor_args["value"] + + # Test deserialization via subclass + recreated = sub_class.model_validate(dump_data) + assert isinstance(recreated, sub_class) + assert recreated.test_type == constructor_args["test_type"] + assert recreated.value == constructor_args["value"] + + # Test polymorphic deserialization via base class + recreated_base = base_class.model_validate(dump_data) # type: ignore[assignment] + assert isinstance(recreated_base, sub_class) + assert recreated_base.test_type == constructor_args["test_type"] + assert recreated_base.value == constructor_args["value"] + + @pytest.mark.regression + def test_polymorphic_container_marshalling(self): + """Test PydanticClassRegistryMixin in container models.""" + + class TestBaseModel(PydanticClassRegistryMixin): + schema_discriminator: ClassVar[str] = "test_type" + test_type: str + + @classmethod + def __pydantic_schema_base_name__(cls) -> str: + return "TestBaseModel" + + @classmethod + def __pydantic_generate_base_schema__(cls, handler): + return handler(cls) + + @TestBaseModel.register("sub_a") + class TestSubModelA(TestBaseModel): + test_type: str = "sub_a" + value_a: str + + @TestBaseModel.register("sub_b") + class TestSubModelB(TestBaseModel): + test_type: str = "sub_b" + value_b: int + + class ContainerModel(BaseModel): + name: str + model: TestBaseModel + models: list[TestBaseModel] + + sub_a = TestSubModelA(value_a="test") + sub_b = TestSubModelB(value_b=123) + + container = ContainerModel(name="container", model=sub_a, models=[sub_a, sub_b]) + + # Verify container construction + assert isinstance(container.model, TestSubModelA) + assert container.model.test_type == "sub_a" + assert container.model.value_a == "test" + assert len(container.models) == 2 + assert isinstance(container.models[0], TestSubModelA) + assert isinstance(container.models[1], TestSubModelB) + + # Test serialization + dump_data = container.model_dump() + assert isinstance(dump_data, dict) + assert dump_data["name"] == "container" + assert dump_data["model"]["test_type"] == "sub_a" + assert dump_data["model"]["value_a"] == "test" + assert len(dump_data["models"]) == 2 + assert dump_data["models"][0]["test_type"] == "sub_a" + assert dump_data["models"][1]["test_type"] == "sub_b" + + # Test deserialization + recreated = ContainerModel.model_validate(dump_data) + assert isinstance(recreated, ContainerModel) + assert recreated.name == "container" + assert isinstance(recreated.model, TestSubModelA) + assert len(recreated.models) == 2 + assert isinstance(recreated.models[0], TestSubModelA) + assert isinstance(recreated.models[1], TestSubModelB) + + @pytest.mark.smoke + def test_register_preserves_pydantic_metadata(self): # noqa: C901 + """Test that registered Pydantic classes retain docs, types, and methods.""" + + class TestBaseModel(PydanticClassRegistryMixin): + schema_discriminator: ClassVar[str] = "model_type" + model_type: str + + @classmethod + def __pydantic_schema_base_name__(cls) -> str: + return "TestBaseModel" + + @TestBaseModel.register("documented_model") + class DocumentedModel(TestBaseModel): + """This is a documented Pydantic model with methods and type hints.""" + + model_type: str = "documented_model" + value: int = Field(description="An integer value for the model") + + def get_value(self) -> int: + """Get the stored value. + + :return: The stored integer value + """ + return self.value + + def set_value(self, new_value: int) -> None: + """Set a new value. + + :param new_value: The new integer value to set + """ + self.value = new_value + + @classmethod + def from_string(cls, value_str: str) -> DocumentedModel: + """Create instance from string. + + :param value_str: String representation of value + :return: New DocumentedModel instance + """ + return cls(value=int(value_str)) + + @staticmethod + def validate_value(value: int) -> bool: + """Validate that a value is positive. + + :param value: Value to validate + :return: True if positive, False otherwise + """ + return value > 0 + + def model_post_init(self, __context) -> None: + """Post-initialization processing. + + :param __context: Validation context + """ + if self.value < 0: + raise ValueError("Value must be non-negative") + + # Check that the class was registered + assert TestBaseModel.is_registered("documented_model") + registered_class = TestBaseModel.get_registered_object("documented_model") + assert registered_class is DocumentedModel + + # Check that the class retains its documentation + assert registered_class.__doc__ is not None + assert "documented Pydantic model with methods" in registered_class.__doc__ + + # Check that methods retain their documentation + assert registered_class.get_value.__doc__ is not None + assert "Get the stored value" in registered_class.get_value.__doc__ + assert registered_class.set_value.__doc__ is not None + assert "Set a new value" in registered_class.set_value.__doc__ + assert registered_class.from_string.__doc__ is not None + assert "Create instance from string" in registered_class.from_string.__doc__ + assert registered_class.validate_value.__doc__ is not None + assert ( + "Validate that a value is positive" + in registered_class.validate_value.__doc__ + ) + assert registered_class.model_post_init.__doc__ is not None + assert ( + "Post-initialization processing" in registered_class.model_post_init.__doc__ + ) + + # Check that methods are callable and work correctly + instance = DocumentedModel(value=42) + assert isinstance(instance, DocumentedModel) + assert instance.get_value() == 42 + instance.set_value(100) + assert instance.get_value() == 100 + assert instance.model_type == "documented_model" + + # Check class methods work + instance2 = DocumentedModel.from_string("123") + assert instance2.get_value() == 123 + assert instance2.model_type == "documented_model" + + # Check static methods work + assert DocumentedModel.validate_value(10) is True + assert DocumentedModel.validate_value(-5) is False + + # Check that Pydantic functionality is preserved + data_dict = instance.model_dump() + assert data_dict["value"] == 100 + assert data_dict["model_type"] == "documented_model" + + recreated = DocumentedModel.model_validate(data_dict) + assert isinstance(recreated, DocumentedModel) + assert recreated.value == 100 + assert recreated.model_type == "documented_model" + + # Test field validation + with pytest.raises(ValidationError): + DocumentedModel(value="not_an_int") # type: ignore[arg-type] + + # Test post_init validation + with pytest.raises(ValueError, match="Value must be non-negative"): + DocumentedModel(value=-10) + + # Check that Pydantic field metadata is preserved + value_field = DocumentedModel.model_fields["value"] + assert value_field.description == "An integer value for the model" + + # Check that type annotations are preserved (if accessible) + import inspect + + if hasattr(inspect, "get_annotations"): + # Python 3.10+ + try: + annotations = inspect.get_annotations(DocumentedModel.get_value) + return_ann = annotations.get("return") + assert return_ann is int or return_ann == "int" + except (AttributeError, NameError): + # Fallback for older Python or missing annotations + pass + + # Check that the class name is preserved + assert DocumentedModel.__name__ == "DocumentedModel" + assert DocumentedModel.__qualname__.endswith("DocumentedModel") + + # Verify that the class is still properly integrated with the registry system + all_registered = TestBaseModel.registered_classes() + assert DocumentedModel in all_registered + + # Test that the registered class is the same as the original + assert registered_class is DocumentedModel diff --git a/tests/unit/utils/test_registry.py b/tests/unit/utils/test_registry.py index e653c83e..393ef4ba 100644 --- a/tests/unit/utils/test_registry.py +++ b/tests/unit/utils/test_registry.py @@ -1,310 +1,589 @@ """ -Unit tests for the registry module in the Speculators library. +Unit tests for the registry module. """ +from __future__ import annotations + +import inspect +from typing import TypeVar from unittest import mock import pytest -from speculators.utils.registry import ClassRegistryMixin - -# ===== ClassRegistryMixin Tests ===== - - -@pytest.mark.smoke -def test_class_registry_initialization(): - class TestRegistryClass(ClassRegistryMixin): - pass - - assert TestRegistryClass.registry is None - - -@pytest.mark.smoke -def test_register_with_name(): - class TestRegistryClass(ClassRegistryMixin): - pass - - @TestRegistryClass.register("custom_name") - class TestClass: - pass - - assert TestRegistryClass.registry is not None - assert "custom_name" in TestRegistryClass.registry - assert TestRegistryClass.registry["custom_name"] is TestClass - - -@pytest.mark.smoke -def test_register_without_name(): - class TestRegistryClass(ClassRegistryMixin): - pass - - @TestRegistryClass.register() - class TestClass: - pass - - assert TestRegistryClass.registry is not None - assert "TestClass" in TestRegistryClass.registry - assert TestRegistryClass.registry["TestClass"] is TestClass - - -@pytest.mark.smoke -def test_register_decorator_direct(): - class TestRegistryClass(ClassRegistryMixin): - pass +from speculators.utils import RegistryMixin +from speculators.utils.registry import RegisterT, RegistryObjT + + +def test_registry_obj_type(): + """Test that RegistryObjT is configured correctly as a TypeVar.""" + assert isinstance(RegistryObjT, type(TypeVar("test"))) + assert RegistryObjT.__name__ == "RegistryObjT" + assert RegistryObjT.__bound__ is None + assert RegistryObjT.__constraints__ == () + + +def test_registered_type(): + """Test that RegisterT is configured correctly as a TypeVar.""" + assert isinstance(RegisterT, type(TypeVar("test"))) + assert RegisterT.__name__ == "RegisterT" + assert RegisterT.__bound__ is None + assert RegisterT.__constraints__ == () + + +class TestRegistryMixin: + """Test suite for RegistryMixin class.""" + + @pytest.fixture( + params=[ + {"registry_auto_discovery": False, "auto_package": None}, + {"registry_auto_discovery": True, "auto_package": "test.package"}, + ], + ids=["manual_registry", "auto_discovery"], + ) + def valid_instances(self, request): + """Fixture providing test data for RegistryMixin subclasses.""" + config = request.param + + class TestRegistryClass(RegistryMixin): + registry_auto_discovery = config["registry_auto_discovery"] + if config["auto_package"]: + auto_package = config["auto_package"] + + return TestRegistryClass, config + + @pytest.mark.smoke + def test_class_signatures(self): + """Test RegistryMixin inheritance and exposed methods.""" + assert hasattr(RegistryMixin, "registry") + assert hasattr(RegistryMixin, "registry_auto_discovery") + assert hasattr(RegistryMixin, "registry_populated") + assert hasattr(RegistryMixin, "register") + assert hasattr(RegistryMixin, "register_decorator") + assert hasattr(RegistryMixin, "auto_populate_registry") + assert hasattr(RegistryMixin, "registered_objects") + assert hasattr(RegistryMixin, "is_registered") + assert hasattr(RegistryMixin, "get_registered_object") + + @pytest.mark.smoke + def test_initialization(self, valid_instances): + """Test RegistryMixin initialization.""" + registry_class, config = valid_instances + + assert registry_class.registry is None + assert ( + registry_class.registry_auto_discovery == config["registry_auto_discovery"] + ) + assert registry_class.registry_populated is False + + @pytest.mark.sanity + def test_invalid_initialization_missing(self): + """Test RegistryMixin with missing auto_package when auto_discovery enabled.""" + + class TestRegistryClass(RegistryMixin): + registry_auto_discovery = True + + with pytest.raises(ValueError, match="auto_package.*must be set"): + TestRegistryClass.auto_import_package_modules() + + @pytest.mark.smoke + @pytest.mark.parametrize( + ("name", "expected_key"), + [ + ("custom_name", "custom_name"), + (["name1", "name2"], ["name1", "name2"]), + (None, "TestClass"), + ], + ) + def test_register(self, valid_instances, name, expected_key): + """Test register method with various name configurations.""" + registry_class, _ = valid_instances + + @registry_class.register(name) + class TestClass: + pass - @TestRegistryClass.register_decorator - class TestClass: - pass + assert registry_class.registry is not None + if isinstance(expected_key, list): + for key in expected_key: + assert key in registry_class.registry + assert registry_class.registry[key] is TestClass + else: + assert expected_key in registry_class.registry + assert registry_class.registry[expected_key] is TestClass + + @pytest.mark.sanity + @pytest.mark.parametrize( + "invalid_name", + [123, 42.5, True, {"key": "value"}], + ) + def test_register_invalid(self, valid_instances, invalid_name): + """Test register method with invalid name types.""" + registry_class, _ = valid_instances + + # The register method returns a decorator, so we need to apply it to test + # validation + decorator = registry_class.register(invalid_name) + + class TestClass: + pass - assert TestRegistryClass.registry is not None - assert "TestClass" in TestRegistryClass.registry - assert TestRegistryClass.registry["TestClass"] is TestClass + with pytest.raises( + ValueError, match="name must be a string or an iterable of strings" + ): + decorator(TestClass) + + @pytest.mark.smoke + @pytest.mark.parametrize( + ("name", "expected_key"), + [ + ("custom_name", "custom_name"), + (["name1", "name2"], ["name1", "name2"]), + (None, "TestClass"), + ], + ) + def test_register_decorator(self, valid_instances, name, expected_key): + """Test register_decorator method with various name configurations.""" + registry_class, _ = valid_instances + + class TestClass: + pass + registry_class.register_decorator(TestClass, name=name) + + assert registry_class.registry is not None + if isinstance(expected_key, list): + for key in expected_key: + assert key in registry_class.registry + assert registry_class.registry[key] is TestClass + else: + assert expected_key in registry_class.registry + assert registry_class.registry[expected_key] is TestClass + + @pytest.mark.sanity + @pytest.mark.parametrize( + "invalid_name", + [123, 42.5, True, {"key": "value"}], + ) + def test_register_decorator_invalid(self, valid_instances, invalid_name): + """Test register_decorator with invalid name types.""" + registry_class, _ = valid_instances + + class TestClass: + pass -@pytest.mark.sanity -def test_register_invalid_name_type(): - class TestRegistryClass(ClassRegistryMixin): - pass + with pytest.raises( + ValueError, match="name must be a string or an iterable of strings" + ): + registry_class.register_decorator(TestClass, name=invalid_name) + + @pytest.mark.smoke + def test_auto_populate_registry(self): + """Test auto_populate_registry method with valid configuration.""" + + class TestAutoRegistry(RegistryMixin): + registry_auto_discovery = True + auto_package = "test.package" + + with mock.patch.object( + TestAutoRegistry, "auto_import_package_modules" + ) as mock_import: + result = TestAutoRegistry.auto_populate_registry() + assert result is True + mock_import.assert_called_once() + assert TestAutoRegistry.registry_populated is True + + # Second call should return False + result = TestAutoRegistry.auto_populate_registry() + assert result is False + mock_import.assert_called_once() + + @pytest.mark.sanity + def test_auto_populate_registry_invalid(self): + """Test auto_populate_registry when auto-discovery is disabled.""" + + class TestDisabledRegistry(RegistryMixin): + registry_auto_discovery = False + + with pytest.raises(ValueError, match="registry_auto_discovery is set to False"): + TestDisabledRegistry.auto_populate_registry() + + @pytest.mark.smoke + def test_registered_objects(self, valid_instances): + """Test registered_objects method with manual registration.""" + registry_class, config = valid_instances + + @registry_class.register("class1") + class TestClass1: + pass - with pytest.raises(ValueError) as exc_info: - TestRegistryClass.register(123) # type: ignore[arg-type] + @registry_class.register("class2") + class TestClass2: + pass - assert "name must be a string or None" in str(exc_info.value) + if config["registry_auto_discovery"]: + with mock.patch.object(registry_class, "auto_import_package_modules"): + objects = registry_class.registered_objects() + else: + objects = registry_class.registered_objects() + assert isinstance(objects, tuple) + assert len(objects) == 2 + assert TestClass1 in objects + assert TestClass2 in objects -@pytest.mark.sanity -def test_register_decorator_invalid_class(): - class TestRegistryClass(ClassRegistryMixin): - pass + @pytest.mark.sanity + def test_registered_objects_invalid(self): + """Test registered_objects when no objects are registered.""" - with pytest.raises(TypeError) as exc_info: - TestRegistryClass.register_decorator("not_a_class") # type: ignore[arg-type] + class TestRegistryClass(RegistryMixin): + pass - assert "must be used as a class decorator" in str(exc_info.value) + with pytest.raises( + ValueError, match="must be called after registering objects" + ): + TestRegistryClass.registered_objects() + + @pytest.mark.smoke + @pytest.mark.parametrize( + ("register_name", "check_name", "expected"), + [ + ("test_name", "test_name", True), + ("TestName", "testname", True), + ("UPPERCASE", "uppercase", True), + ("test_name", "nonexistent", False), + ], + ) + def test_is_registered(self, valid_instances, register_name, check_name, expected): + """Test is_registered with various name combinations.""" + registry_class, _ = valid_instances + + @registry_class.register(register_name) + class TestClass: + pass + result = registry_class.is_registered(check_name) + assert result == expected + + @pytest.mark.smoke + @pytest.mark.parametrize( + ("register_name", "lookup_name"), + [ + ("test_name", "test_name"), + ("TestName", "testname"), + ("UPPERCASE", "uppercase"), + ], + ) + def test_get_registered_object(self, valid_instances, register_name, lookup_name): + """Test get_registered_object with valid names.""" + registry_class, _ = valid_instances + + @registry_class.register(register_name) + class TestClass: + pass -@pytest.mark.sanity -def test_register_decorator_invalid_name(): - class TestRegistryClass(ClassRegistryMixin): - pass + result = registry_class.get_registered_object(lookup_name) + assert result is TestClass - class TestClass: - pass + @pytest.mark.sanity + @pytest.mark.parametrize( + "lookup_name", + ["nonexistent", "wrong_name", "DIFFERENT_CASE"], + ) + def test_get_registered_object_invalid(self, valid_instances, lookup_name): + """Test get_registered_object with invalid names.""" + registry_class, _ = valid_instances - with pytest.raises(ValueError) as exc_info: - TestRegistryClass.register_decorator(TestClass, name=123) # type: ignore[arg-type] + @registry_class.register("valid_name") + class TestClass: + pass - assert "must be used as a class decorator" in str(exc_info.value) + result = registry_class.get_registered_object(lookup_name) + assert result is None + @pytest.mark.regression + def test_multiple_registries_isolation(self): + """Test that different registry classes maintain separate registries.""" -@pytest.mark.sanity -def test_register_duplicate_name(): - class TestRegistryClass(ClassRegistryMixin): - pass + class Registry1(RegistryMixin): + pass - @TestRegistryClass.register("test_name") - class TestClass1: - pass + class Registry2(RegistryMixin): + pass - with pytest.raises(ValueError) as exc_info: + @Registry1.register() + class TestClass1: + pass - @TestRegistryClass.register("test_name") + @Registry2.register() class TestClass2: pass - assert "already registered" in str(exc_info.value) - - -@pytest.mark.sanity -def test_registered_classes_empty(): - class TestRegistryClass(ClassRegistryMixin): - pass - - with pytest.raises(ValueError) as exc_info: - TestRegistryClass.registered_classes() - - assert "must be called after registering classes" in str(exc_info.value) - - -@pytest.mark.sanity -def test_registered_classes(): - class TestRegistryClass(ClassRegistryMixin): - pass - - @TestRegistryClass.register() - class TestClass1: - pass - - @TestRegistryClass.register("custom_name") - class TestClass2: - pass + assert Registry1.registry is not None # type: ignore[misc] + assert Registry2.registry is not None # type: ignore[misc] + assert Registry1.registry != Registry2.registry # type: ignore[misc] + assert "TestClass1" in Registry1.registry # type: ignore[misc] + assert "TestClass2" in Registry2.registry # type: ignore[misc] + assert "TestClass1" not in Registry2.registry # type: ignore[misc] + assert "TestClass2" not in Registry1.registry # type: ignore[misc] + + @pytest.mark.smoke + def test_auto_discovery_initialization(self): + """Test initialization of auto-discovery enabled registry.""" + + class TestAutoRegistry(RegistryMixin): + registry_auto_discovery = True + auto_package = "test_package.modules" + + assert TestAutoRegistry.registry is None # type: ignore[misc] + assert TestAutoRegistry.registry_populated is False + assert TestAutoRegistry.auto_package == "test_package.modules" + assert TestAutoRegistry.registry_auto_discovery is True + + @pytest.mark.smoke + def test_auto_discovery_registered_objects(self): + """Test automatic population during registered_objects call.""" + + class TestAutoRegistry(RegistryMixin): + registry_auto_discovery = True + auto_package = "test_package.modules" + + with mock.patch.object( + TestAutoRegistry, "auto_populate_registry" + ) as mock_populate: + TestAutoRegistry.registry = {"class1": "obj1", "class2": "obj2"} # type: ignore[misc] + objects = TestAutoRegistry.registered_objects() + mock_populate.assert_called_once() + assert objects == ("obj1", "obj2") + + @pytest.mark.sanity + def test_register_duplicate_registration(self, valid_instances): + """Test register method with duplicate names.""" + registry_class, _ = valid_instances + + @registry_class.register("duplicate_name") + class TestClass1: + pass - registered = TestRegistryClass.registered_classes() - assert isinstance(registered, tuple) - assert len(registered) == 2 - assert TestClass1 in registered - assert TestClass2 in registered + with pytest.raises(ValueError, match="already registered"): + @registry_class.register("duplicate_name") + class TestClass2: + pass -@pytest.mark.regression -def test_multiple_registries_isolation(): - class Registry1(ClassRegistryMixin): - pass + @pytest.mark.sanity + def test_register_decorator_duplicate_registration(self, valid_instances): + """Test register_decorator with duplicate names.""" + registry_class, _ = valid_instances - class Registry2(ClassRegistryMixin): - pass + class TestClass1: + pass - @Registry1.register() - class TestClass1: - pass + class TestClass2: + pass - @Registry2.register() - class TestClass2: - pass + registry_class.register_decorator(TestClass1, name="duplicate_name") + with pytest.raises(ValueError, match="already registered"): + registry_class.register_decorator(TestClass2, name="duplicate_name") - assert Registry1.registry is not None - assert Registry2.registry is not None - assert Registry1.registry != Registry2.registry - assert "TestClass1" in Registry1.registry - assert "TestClass2" in Registry2.registry - assert "TestClass1" not in Registry2.registry - assert "TestClass2" not in Registry1.registry + @pytest.mark.sanity + def test_register_decorator_invalid_list_element(self, valid_instances): + """Test register_decorator with invalid elements in name list.""" + registry_class, _ = valid_instances + class TestClass: + pass -# ===== Auto-Discovery Tests ===== + with pytest.raises( + ValueError, match="name must be a string or a list of strings" + ): + registry_class.register_decorator(TestClass, name=["valid", 123]) + @pytest.mark.sanity + def test_register_decorator_empty_string_name(self, valid_instances): + """Test register_decorator with empty string name.""" + registry_class, _ = valid_instances -@pytest.mark.smoke -def test_auto_discovery_registry_initialization(): - class TestAutoRegistry(ClassRegistryMixin): - registry_auto_discovery = True - auto_package = "test_package.modules" + class TestClass: + pass - assert TestAutoRegistry.registry is None - assert TestAutoRegistry.registry_populated is False - assert TestAutoRegistry.auto_package == "test_package.modules" - assert TestAutoRegistry.registry_auto_discovery is True + registry_class.register_decorator(TestClass, name="") + assert "" in registry_class.registry + assert registry_class.registry[""] is TestClass + @pytest.mark.sanity + def test_register_decorator_none_in_list(self, valid_instances): + """Test register_decorator with None in name list.""" + registry_class, _ = valid_instances -@pytest.mark.smoke -def test_auto_populate_registry(): - class TestAutoRegistry(ClassRegistryMixin): - registry_auto_discovery = True - auto_package = "test_package.modules" + class TestClass: + pass - with mock.patch.object( - TestAutoRegistry, "auto_import_package_modules" - ) as mock_import: - TestAutoRegistry.auto_populate_registry() - mock_import.assert_called_once() - assert TestAutoRegistry.registry_populated is True + with pytest.raises( + ValueError, match="name must be a string or a list of strings" + ): + registry_class.register_decorator(TestClass, name=["valid", None]) + + @pytest.mark.smoke + def test_is_registered_empty_registry(self, valid_instances): + """Test is_registered with empty registry.""" + registry_class, _ = valid_instances + + result = registry_class.is_registered("any_name") + assert result is False + + @pytest.mark.smoke + def test_get_registered_object_empty_registry(self, valid_instances): + """Test get_registered_object with empty registry.""" + registry_class, _ = valid_instances + + result = registry_class.get_registered_object("any_name") + assert result is None + + @pytest.mark.regression + def test_auto_registry_integration(self): + """Test complete auto-discovery workflow with mocked imports.""" + + class TestAutoRegistry(RegistryMixin): + registry_auto_discovery = True + auto_package = "test_package.modules" + + with ( + mock.patch("pkgutil.walk_packages") as mock_walk, + mock.patch("importlib.import_module") as mock_import, + ): + mock_package = mock.MagicMock() + mock_package.__path__ = ["test_package/modules"] + mock_package.__name__ = "test_package.modules" + + def import_module(name: str): + if name == "test_package.modules": + return mock_package + elif name == "test_package.modules.module1": + module = mock.MagicMock() + module.__name__ = "test_package.modules.module1" + + class Module1Class: + pass + + TestAutoRegistry.register_decorator(Module1Class, "Module1Class") + return module + else: + raise ImportError(f"No module named {name}") + + def walk_packages(package_path, package_name): + if package_name == "test_package.modules.": + return [(None, "test_package.modules.module1", False)] + else: + raise ValueError(f"Unknown package: {package_name}") + + mock_walk.side_effect = walk_packages + mock_import.side_effect = import_module + + objects = TestAutoRegistry.registered_objects() + assert len(objects) == 1 + assert TestAutoRegistry.registry_populated is True + assert TestAutoRegistry.registry is not None # type: ignore[misc] + assert "Module1Class" in TestAutoRegistry.registry # type: ignore[misc] + + @pytest.mark.smoke + def test_register_preserves_class_metadata(self): + """Test that registered classes retain docs, types, and methods.""" + + class TestRegistry(RegistryMixin): + pass - # Second call should not trigger another import since already populated - TestAutoRegistry.auto_populate_registry() - mock_import.assert_called_once() - - -@pytest.mark.sanity -def test_auto_populate_registry_disabled(): - class TestDisabledAutoRegistry(ClassRegistryMixin): - # registry_auto_discovery is False by default - auto_package = "test_package.modules" - - with pytest.raises(ValueError) as exc_info: - TestDisabledAutoRegistry.auto_populate_registry() - - assert "registry_auto_discovery is set to False" in str(exc_info.value) - - -@pytest.mark.sanity -def test_auto_registered_classes(): - class TestAutoRegistry(ClassRegistryMixin): - registry_auto_discovery = True - auto_package = "test_package.modules" - - with mock.patch.object(TestAutoRegistry, "auto_populate_registry") as mock_populate: - # Mock the registry content - TestAutoRegistry.registry = {"Class1": "class1", "Class2": "class2"} # type: ignore[dict-item] - classes = TestAutoRegistry.registered_classes() - mock_populate.assert_called_once() - assert classes == ("class1", "class2") - - -@pytest.mark.regression -def test_auto_registry_integration(): - class TestAutoRegistry(ClassRegistryMixin): - registry_auto_discovery = True - auto_package = "test_package.modules" - - with ( - mock.patch("pkgutil.walk_packages") as mock_walk, - mock.patch("importlib.import_module") as mock_import, - ): - # Create a mock package with the necessary attributes - mock_package = mock.MagicMock() - mock_package.__path__ = ["test_package/modules"] - mock_package.__name__ = "test_package.modules" - - def import_module(name: str): - if name == "test_package.modules": - return mock_package - elif name == "test_package.modules.module1": - module = mock.MagicMock() - module.__name__ = "test_package.modules.module1" - - class Module1Class: - pass - - TestAutoRegistry.register_decorator(Module1Class, "Module1Class") - return module - else: - raise ImportError(f"No module named {name}") - - def walk_packages(package_path, package_name): - if package_name == "test_package.modules.": - return [(None, "test_package.modules.module1", False)] - else: - raise ValueError(f"Unknown package: {package_name}") - - mock_walk.side_effect = walk_packages - mock_import.side_effect = import_module - - classes = TestAutoRegistry.registered_classes() - assert len(classes) == 1 - assert TestAutoRegistry.registry_populated is True - assert TestAutoRegistry.registry is not None - assert "Module1Class" in TestAutoRegistry.registry - - -@pytest.mark.regression -def test_auto_registry_with_multiple_packages(): - class TestMultiPackageRegistry(ClassRegistryMixin): - registry_auto_discovery = True - auto_package = ("package1", "package2") - - with mock.patch.object( - TestMultiPackageRegistry, "auto_import_package_modules" - ) as mock_import: - # Mock the registry to avoid ValueError when getting registered_classes - TestMultiPackageRegistry.registry = {} - TestMultiPackageRegistry.registered_classes() - mock_import.assert_called_once() - assert TestMultiPackageRegistry.registry_populated is True - - -@pytest.mark.regression -def test_auto_registry_no_package(): - class TestNoPackageRegistry(ClassRegistryMixin): - registry_auto_discovery = True - # No auto_package defined - - with mock.patch.object( - TestNoPackageRegistry, - "auto_import_package_modules", - side_effect=ValueError("auto_package must be set"), - ) as mock_import: - with pytest.raises(ValueError) as exc_info: - TestNoPackageRegistry.auto_populate_registry() - - mock_import.assert_called_once() - assert "auto_package must be set" in str(exc_info.value) + @TestRegistry.register("documented_class") + class DocumentedClass: + """This is a documented class with methods and type hints.""" + + def __init__(self, value: int) -> None: + """Initialize with a value. + + :param value: An integer value + """ + self.value = value + + def get_value(self) -> int: + """Get the stored value. + + :return: The stored integer value + """ + return self.value + + def set_value(self, new_value: int) -> None: + """Set a new value. + + :param new_value: The new integer value to set + """ + self.value = new_value + + @classmethod + def from_string(cls, value_str: str) -> DocumentedClass: + """Create instance from string. + + :param value_str: String representation of value + :return: New DocumentedClass instance + """ + return cls(int(value_str)) + + @staticmethod + def validate_value(value: int) -> bool: + """Validate that a value is positive. + + :param value: Value to validate + :return: True if positive, False otherwise + """ + return value > 0 + + # Check that the class was registered + assert TestRegistry.is_registered("documented_class") + registered_class = TestRegistry.get_registered_object("documented_class") + assert registered_class is DocumentedClass + + # Check that the class retains its documentation + assert registered_class.__doc__ is not None + assert "documented class with methods" in registered_class.__doc__ + assert registered_class.__init__.__doc__ is not None + assert "Initialize with a value" in registered_class.__init__.__doc__ + assert registered_class.get_value.__doc__ is not None + assert "Get the stored value" in registered_class.get_value.__doc__ + assert registered_class.set_value.__doc__ is not None + assert "Set a new value" in registered_class.set_value.__doc__ + assert registered_class.from_string.__doc__ is not None + assert "Create instance from string" in registered_class.from_string.__doc__ + assert registered_class.validate_value.__doc__ is not None + assert ( + "Validate that a value is positive" + in registered_class.validate_value.__doc__ + ) + + # Check that methods are callable and work correctly + instance = registered_class(42) + assert instance.get_value() == 42 + instance.set_value(100) + assert instance.get_value() == 100 + instance2 = registered_class.from_string("123") + assert instance2.get_value() == 123 + assert registered_class.validate_value(10) is True + assert registered_class.validate_value(-5) is False + + # Check that type annotations are preserved (if accessible) + if hasattr(inspect, "get_annotations"): + # Python 3.10+ + try: + annotations = inspect.get_annotations(registered_class.__init__) + assert "value" in annotations + assert annotations["value"] in (int, "int") + return_ann = annotations.get("return") + assert ( + return_ann == "None" + or return_ann is None + or return_ann is type(None) + ) + except (AttributeError, NameError): + # Fallback for older Python or missing annotations + pass + + # Check that the class name is preserved + assert registered_class.__name__ == "DocumentedClass" + assert registered_class.__qualname__.endswith("DocumentedClass")