From 137d0854326cd52917dd602b75b6a3a68d94a0a7 Mon Sep 17 00:00:00 2001 From: Mark Kurtz Date: Fri, 29 Aug 2025 07:17:32 -0400 Subject: [PATCH 1/8] Initial change to start main PR for converters refactor Signed-off-by: Mark Kurtz --- src/speculators/convert/eagle/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/speculators/convert/eagle/__init__.py b/src/speculators/convert/eagle/__init__.py index 64777b87..591498a7 100644 --- a/src/speculators/convert/eagle/__init__.py +++ b/src/speculators/convert/eagle/__init__.py @@ -1,5 +1,5 @@ """ -Eagle checkpoint conversion utilities. +Eagle-1 and Eagle-3 checkpoint conversion utilities. """ from speculators.convert.eagle.eagle_converter import EagleConverter From a7e62c9732be3c7becee46fee6f9e9dd38a7494e Mon Sep 17 00:00:00 2001 From: Mark Kurtz Date: Fri, 29 Aug 2025 17:20:09 -0400 Subject: [PATCH 2/8] Update base utility files for enablement of the class based converters Signed-off-by: Mark Kurtz --- pyproject.toml | 1 + src/speculators/model.py | 6 +- src/speculators/utils/__init__.py | 4 +- src/speculators/utils/auto_importer.py | 74 +- src/speculators/utils/pydantic_utils.py | 251 +++---- src/speculators/utils/registry.py | 286 ++++---- tests/unit/models/test_eagle_config.py | 6 +- tests/unit/models/test_eagle_model.py | 6 +- tests/unit/test_model.py | 6 +- tests/unit/utils/test_auto_importer.py | 423 +++++++----- tests/unit/utils/test_pydantic_utils.py | 868 +++++++++++++++++------- tests/unit/utils/test_registry.py | 803 +++++++++++++++------- 12 files changed, 1743 insertions(+), 991 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 2615f138..96499e3d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -40,6 +40,7 @@ dependencies = [ "click", "datasets", "deepspeed", + "eval_type_backport", "httpx[http2]", "huggingface-hub", "loguru", diff --git a/src/speculators/model.py b/src/speculators/model.py index d8f6bca2..ebe79554 100644 --- a/src/speculators/model.py +++ b/src/speculators/model.py @@ -37,10 +37,12 @@ from transformers.generation.utils import GenerateOutput from speculators.config import SpeculatorModelConfig -from speculators.utils import ClassRegistryMixin +from speculators.utils import RegistryMixin -class SpeculatorModel(ClassRegistryMixin, PreTrainedModel, GenerationMixin): # type: ignore[misc] +class SpeculatorModel( # type: ignore[misc] + RegistryMixin[type["SpeculatorModel"]], PreTrainedModel, GenerationMixin +): """ Abstract base class for all speculator models in the Speculators library. diff --git a/src/speculators/utils/__init__.py b/src/speculators/utils/__init__.py index ebe8d140..edf0889b 100644 --- a/src/speculators/utils/__init__.py +++ b/src/speculators/utils/__init__.py @@ -1,10 +1,10 @@ from .auto_importer import AutoImporterMixin from .pydantic_utils import PydanticClassRegistryMixin, ReloadableBaseModel -from .registry import ClassRegistryMixin +from .registry import RegistryMixin __all__ = [ "AutoImporterMixin", - "ClassRegistryMixin", "PydanticClassRegistryMixin", + "RegistryMixin", "ReloadableBaseModel", ] diff --git a/src/speculators/utils/auto_importer.py b/src/speculators/utils/auto_importer.py index 3b3240d3..88e11c85 100644 --- a/src/speculators/utils/auto_importer.py +++ b/src/speculators/utils/auto_importer.py @@ -1,64 +1,58 @@ """ Automatic module importing utilities for dynamic class discovery. -This module provides a mixin class for automatic module importing within a package, +This module provides a mixin class for automatic module importing within packages, enabling dynamic discovery of classes and implementations without explicit imports. -It is particularly useful for auto-registering classes in a registry pattern where -subclasses need to be discoverable at runtime. - -The AutoImporterMixin can be combined with registration mechanisms to create -extensible systems where new implementations are automatically discovered and -registered when they are placed in the correct package structure. - -Classes: - - AutoImporterMixin: A mixin class that provides functionality to automatically - import all modules within a specified package or list of packa +It is designed for registry patterns where subclasses need to be discoverable at +runtime, creating extensible systems where new implementations are automatically +discovered when placed in the correct package structure. """ +from __future__ import annotations + import importlib import pkgutil import sys -from typing import ClassVar, Optional, Union +from typing import ClassVar __all__ = ["AutoImporterMixin"] class AutoImporterMixin: """ - A mixin class that provides functionality to automatically import all modules - within a specified package or list of packages. - - This mixin is designed to be used with class registration mechanisms to enable - automatic discovery and registration of classes without explicit imports. When - a class inherits from AutoImporterMixin, it can define the package(s) to scan - for modules by setting the `auto_package` class variable. - - Usage Example: - ```python - from speculators.utils import AutoImporterMixin - class MyRegistry(AutoImporterMixin): - auto_package = "my_package.implementations" - - MyRegistry.auto_import_package_modules() - ``` - - :cvar auto_package: The package name or tuple of names to import modules from. - :cvar auto_ignore_modules: Optional tuple of module names to ignore during import. - :cvar auto_imported_modules: List tracking which modules have been imported. + Mixin class for automatic module importing within packages. + + This mixin enables dynamic discovery of classes and implementations by + automatically importing all modules within specified packages. It is designed + for use with class registration mechanisms to enable automatic discovery and + registration of classes when they are placed in the correct package structure. + + Example: + :: + from speculators.utils import AutoImporterMixin + + class MyRegistry(AutoImporterMixin): + auto_package = "my_package.implementations" + + MyRegistry.auto_import_package_modules() + + :cvar auto_package: Package name or tuple of package names to import modules from + :cvar auto_ignore_modules: Module names to ignore during import + :cvar auto_imported_modules: List tracking which modules have been imported """ - auto_package: ClassVar[Optional[Union[str, tuple[str, ...]]]] = None - auto_ignore_modules: ClassVar[Optional[tuple[str, ...]]] = None - auto_imported_modules: ClassVar[Optional[list]] = None + auto_package: ClassVar[str | tuple[str, ...] | None] = None + auto_ignore_modules: ClassVar[tuple[str, ...] | None] = None + auto_imported_modules: ClassVar[list[str] | None] = None @classmethod - def auto_import_package_modules(cls): + def auto_import_package_modules(cls) -> None: """ - Automatically imports all modules within the specified package(s). + Automatically import all modules within the specified package(s). - This method scans the package(s) defined in the `auto_package` class variable - and imports all modules found, tracking them in `auto_imported_modules`. It - skips packages (directories) and any modules listed in `auto_ignore_modules`. + Scans the package(s) defined in `auto_package` and imports all modules found, + tracking them in `auto_imported_modules`. Skips packages and any modules + listed in `auto_ignore_modules`. :raises ValueError: If the `auto_package` class variable is not set """ diff --git a/src/speculators/utils/pydantic_utils.py b/src/speculators/utils/pydantic_utils.py index 01816157..6d074796 100644 --- a/src/speculators/utils/pydantic_utils.py +++ b/src/speculators/utils/pydantic_utils.py @@ -1,113 +1,110 @@ """ -General pydantic utilities for Speculators. +Pydantic utilities for polymorphic model serialization and registry integration. -This module provides integration between Pydantic and the Speculators library, -enabling things like polymorphic serialization and deserialization of Pydantic -models using a discriminator field and registry. - -Classes: - PydanticClassRegistryMixin: A mixin that combines Pydantic models with the - ClassRegistryMixin to support polymorphic model instantiation based on - a discriminator field +Provides integration between Pydantic and the registry system, enabling +polymorphic serialization and deserialization of Pydantic models using +a discriminator field and dynamic class registry. Includes base model classes +with standardized configurations and generic status breakdown models for +structured result organization. """ +from __future__ import annotations + from abc import ABC, abstractmethod -from typing import Any, ClassVar, Optional +from typing import Any, ClassVar, Generic, TypeVar from pydantic import BaseModel, GetCoreSchemaHandler from pydantic_core import CoreSchema, core_schema -from speculators.utils.registry import ClassRegistryMixin +from speculators.utils.registry import RegistryMixin + +__all__ = [ + "BaseModelT", + "PydanticClassRegistryMixin", + "RegisterClassT", + "ReloadableBaseModel", +] + -__all__ = ["PydanticClassRegistryMixin", "ReloadableBaseModel"] +BaseModelT = TypeVar("BaseModelT", bound=BaseModel) +RegisterClassT = TypeVar("RegisterClassT", bound=type[BaseModel]) class ReloadableBaseModel(BaseModel): + """ + Base Pydantic model with schema reloading capabilities. + + Provides dynamic schema rebuilding functionality for models that need to + update their validation schemas at runtime, particularly useful when + working with registry-based polymorphic models where new types are + registered after initial class definition. + """ + @classmethod - def reload_schema(cls): + def reload_schema(cls) -> None: """ - Reloads the schema for the class, ensuring that the registry is populated - and that the schema is up-to-date. + Reload the class schema with updated registry information. - This method is useful when the registry has been modified or when the - class needs to be re-validated with the latest schema. + Forces a complete rebuild of the Pydantic model schema to incorporate + any changes made to associated registries or validation rules. """ cls.model_rebuild(force=True) -class PydanticClassRegistryMixin(ReloadableBaseModel, ABC, ClassRegistryMixin): +class PydanticClassRegistryMixin( + ReloadableBaseModel, RegistryMixin[type[BaseModelT]], ABC, Generic[BaseModelT] +): """ - A mixin class that integrates Pydantic models with the ClassRegistryMixin to enable - polymorphic serialization and deserialization based on a discriminator field. - - This mixin allows Pydantic models to be registered in a registry and dynamically - instantiated based on a discriminator field in the input data. - It overrides Pydantic's validation system to correctly instantiate the appropriate - subclass based on the discriminator value and the name of the registered classes. - - The mixin is particularly useful for implementing base registry classes that need to - support multiple implementations, such as different token proposal methods or - speculative decoding algorithms. - - Usage Example: - ```python - from typing import ClassVar - from pydantic import BaseModel, Field - from speculators.utils import PydanticClassRegistryMixin - - class BaseConfig(PydanticClassRegistryMixin): - @classmethod - def __pydantic_schema_base_type__(cls) -> type["BaseConfig"]: - if cls.__name__ == "BaseConfig": - return cls - return BaseConfig - - schema_discriminator: ClassVar[str] = "config_type" - config_type: str = Field(description="The type of configuration") - - @BaseConfig.register("config_a") - class ConfigA(BaseConfig): - config_type: str = "config_a" - value_a: str = Field(description="A value specific to ConfigA") - - @BaseConfig.register("config_b") - class ConfigB(BaseConfig): - config_type: str = "config_b" - value_b: int = Field(description="A value specific to ConfigB") - - BaseConfig.reload_schema() # Ensures the schema is up-to-date with registry - - # Dynamic instantiation based on config_type - config_data = {"config_type": "config_a", "value_a": "test"} - config = BaseConfig.model_validate(config_data) # Returns ConfigA instance - print(config) - dump_data = config.model_dump() # Dumps the data to a dictionary - print(dump_data) # Output: {'config_type': 'config_a', 'value_a': 'test'} - ``` - - :cvar schema_discriminator: The field name used as the discriminator in the JSON - schema. Default is "model_type". - :cvar registry: A dictionary mapping discriminator values to pydantic model classes. + Polymorphic Pydantic model mixin enabling registry-based dynamic instantiation. + + Integrates Pydantic validation with the registry system to enable polymorphic + serialization and deserialization based on a discriminator field. Automatically + instantiates the correct subclass during validation based on registry mappings, + providing a foundation for extensible plugin-style architectures. + + Example: + :: + from speculators.utils import PydanticClassRegistryMixin + + class BaseConfig(PydanticClassRegistryMixin["BaseConfig"]): + schema_discriminator: ClassVar[str] = "config_type" + config_type: str = Field(description="Configuration type identifier") + + @classmethod + def __pydantic_schema_base_type__(cls) -> type["BaseConfig"]: + return BaseConfig + + @BaseConfig.register("database") + class DatabaseConfig(BaseConfig): + config_type: str = "database" + connection_string: str = Field(description="Database connection string") + + # Dynamic instantiation based on discriminator + config = BaseConfig.model_validate({ + "config_type": "database", + "connection_string": "postgresql://localhost:5432/db" + }) + + :cvar schema_discriminator: Field name used for polymorphic type discrimination """ schema_discriminator: ClassVar[str] = "model_type" - registry: ClassVar[Optional[dict[str, BaseModel]]] = None # type: ignore[assignment] @classmethod - def register_decorator( - cls, clazz: type[BaseModel], name: Optional[str] = None - ) -> type[BaseModel]: + def register_decorator( # type: ignore[override] + cls, clazz: RegisterClassT, name: str | list[str] | None = None + ) -> RegisterClassT: """ - Registers a Pydantic model class with the registry. + Register a Pydantic model class with type validation and schema reload. - This method extends the ClassRegistryMixin.register_decorator method by adding - a type check to ensure only Pydantic BaseModel subclasses can be registered. + Validates that the class is a proper Pydantic BaseModel subclass before + registering it in the class registry. Automatically triggers schema + reload to incorporate the new type into polymorphic validation. - :param clazz: The Pydantic model class to register - :param name: Optional name to register the class under. If None, the class name - is used as the registry key. - :return: The registered class. - :raises TypeError: If clazz is not a subclass of Pydantic BaseModel + :param clazz: Pydantic model class to register in the polymorphic hierarchy + :param name: Registry identifier for the class. Uses class name if None + :return: The registered class unchanged for decorator chaining + :raises TypeError: If clazz is not a Pydantic BaseModel subclass """ if not issubclass(clazz, BaseModel): raise TypeError( @@ -115,24 +112,25 @@ def register_decorator( "Pydantic BaseModel" ) - return super().register_decorator(clazz, name=name) + super().register_decorator(clazz, name=name) + cls.reload_schema() + + return clazz @classmethod def __get_pydantic_core_schema__( cls, source_type: Any, handler: GetCoreSchemaHandler ) -> CoreSchema: """ - Customizes the Pydantic schema for polymorphic model validation. + Generate polymorphic validation schema for dynamic type instantiation. - This method is part of Pydantic's validation system and is called during - schema generation. It checks if the source_type matches the base type of the - polymorphic model. If it does, it generates a tagged union schema that allows - for dynamic instantiation of the appropriate subclass based on the discriminator - field. + Creates a tagged union schema that enables Pydantic to automatically + instantiate the correct subclass based on the discriminator field value. + Falls back to base schema generation when no registry is available. - :param source_type: The type for which the schema is being generated - :param handler: Handler for generating core schema - :return: A CoreSchema object with the custom validator if appropriate + :param source_type: Type being processed for schema generation + :param handler: Pydantic core schema generation handler + :return: Tagged union schema for polymorphic validation or base schema """ if source_type == cls.__pydantic_schema_base_type__(): if not cls.registry: @@ -151,22 +149,14 @@ def __get_pydantic_core_schema__( @classmethod @abstractmethod - def __pydantic_schema_base_type__(cls) -> type[Any]: + def __pydantic_schema_base_type__(cls) -> type[BaseModelT]: """ - Abstract method that must be implemented by subclasses to define the base type. + Define the base type for polymorphic validation hierarchy. - This method should return the base class type that serves as the root of the - polymorphic hierarchy. The returned type is used to determine when to apply - the custom validation logic for polymorphic instantiation. + Must be implemented by subclasses to specify which type serves as the + root of the polymorphic hierarchy for schema generation and validation. - Example implementation: - ```python - @classmethod - def __pydantic_schema_base_type__(cls) -> type["MyBaseClass"]: - return MyBaseClass - ``` - - :return: The base class type for polymorphic validation + :return: Base class type for the polymorphic model hierarchy """ ... @@ -175,37 +165,52 @@ def __pydantic_generate_base_schema__( cls, handler: GetCoreSchemaHandler ) -> CoreSchema: """ - Generates the base schema for the polymorphic model. - - This method is used by the Pydantic validation system to create the core - schema for the base class. By default, it returns an any_schema which accepts - any valid input, relying on the validator function to perform the actual - validation and model instantiation. + Generate fallback schema for polymorphic models without registry. - Subclasses can override this method to provide a more specific base schema - if needed. + Provides a base schema that accepts any valid input when no registry + is available for polymorphic validation. Used as fallback during + schema generation when the registry has not been populated. - :param handler: Handler for generating core schema - :return: A CoreSchema object representing the base schema + :param handler: Pydantic core schema generation handler + :return: Base CoreSchema that accepts any valid input """ return core_schema.any_schema() @classmethod def auto_populate_registry(cls) -> bool: """ - Ensures that all registered classes in the registry are properly initialized. + Initialize registry with auto-discovery and reload validation schema. - This method is called automatically by Pydantic when the model is instantiated - or validated. It ensures that all classes in the registry are loaded and ready - for use. + Triggers automatic population of the class registry through the parent + RegistryMixin functionality and ensures the Pydantic validation schema + is updated to include all discovered types for polymorphic validation. - This is particularly useful for ensuring that all subclasses are registered - before any validation occurs. - - :return: True if the registry was populated, False if it was already populated - :raises ValueError: If called when registry_auto_discovery is False + :return: True if registry was populated, False if already populated + :raises ValueError: If called when registry_auto_discovery is disabled """ populated = super().auto_populate_registry() cls.reload_schema() return populated + + @classmethod + def registered_classes(cls) -> tuple[type[BaseModelT], ...]: + """ + Get all registered pydantic classes from the registry. + + Automatically triggers auto-discovery if registry_auto_discovery is enabled + to ensure all available implementations are included. + + :return: Tuple of all registered classes including auto-discovered ones + :raises ValueError: If called before any objects have been registered + """ + if cls.registry_auto_discovery: + cls.auto_populate_registry() + + if cls.registry is None: + raise ValueError( + "ClassRegistryMixin.registered_classes() must be called after " + "registering classes with ClassRegistryMixin.register()." + ) + + return tuple(cls.registry.values()) diff --git a/src/speculators/utils/registry.py b/src/speculators/utils/registry.py index 21994b91..34b74549 100644 --- a/src/speculators/utils/registry.py +++ b/src/speculators/utils/registry.py @@ -1,189 +1,149 @@ """ -Registry system for classes in the Speculators library. +Registry system for dynamic object registration and discovery. -This module provides a flexible class registration and discovery system used -throughout the Speculators library. It enables automatic registration of classes -and discovery of implementations through class decorators and module imports. - -The registry system is used to track different implementations of token proposal -methods, speculative decoding algorithms, and speculator models, allowing for -dynamic discovery and instantiation based on configuration parameters. - -Classes: - ClassRegistryMixin: Base mixin for creating class registries with decorators - and optional auto-discovery capabilities through registry_auto_discovery flag. - AutoClassRegistryMixin: A backward-compatible version of ClassRegistryMixin with - auto-discovery enabled by default +Provides a flexible object registration system with optional auto-discovery +capabilities through decorators and module imports. Enables dynamic discovery +and instantiation of implementations based on configuration parameters, supporting +both manual registration and automatic package-based discovery for extensible +plugin architectures. """ -from typing import Any, Callable, ClassVar, Optional +from __future__ import annotations + +from typing import Callable, ClassVar, Generic, TypeVar, cast from speculators.utils.auto_importer import AutoImporterMixin -__all__ = ["ClassRegistryMixin"] +__all__ = ["RegisterT", "RegistryMixin", "RegistryObjT"] -class ClassRegistryMixin(AutoImporterMixin): - """ - A mixin class that provides a registration system for tracking class - implementations with optional auto-discovery capabilities. +RegistryObjT = TypeVar("RegistryObjT") +"""Generic type variable for objects managed by the registry system.""" +RegisterT = TypeVar("RegisterT") +"""Generic type variable for the args and return values within the registry.""" - This mixin allows classes to maintain a registry of subclasses that can be - dynamically discovered and instantiated. Classes that inherit from this mixin - can use the @register decorator to add themselves to the registry. - The registry is class-specific, meaning each class that inherits from this mixin - will have its own separate registry of implementations. +class RegistryMixin(Generic[RegistryObjT], AutoImporterMixin): + """ + Generic mixin for creating object registries with optional auto-discovery. - The mixin can also be configured to automatically discover and register classes - from specified packages by setting registry_auto_discovery=True and defining - an auto_package class variable to specify which package(s) should be automatically - imported to discover implementations. + Enables classes to maintain separate registries of objects that can be dynamically + discovered and instantiated through decorators and module imports. Supports both + manual registration via decorators and automatic discovery through package scanning + for extensible plugin architectures. Example: - ```python - class BaseAlgorithm(ClassRegistryMixin): - pass + :: + class BaseAlgorithm(RegistryMixin): + pass - @BaseAlgorithm.register() - class ConcreteAlgorithm(BaseAlgorithm): - pass + @BaseAlgorithm.register() + class ConcreteAlgorithm(BaseAlgorithm): + pass - @BaseAlgorithm.register("custom_name") - class AnotherAlgorithm(BaseAlgorithm): - pass + @BaseAlgorithm.register("custom_name") + class AnotherAlgorithm(BaseAlgorithm): + pass - # Get all registered algorithm implementations - algorithms = BaseAlgorithm.registered_classes() - ``` + # Get all registered implementations + algorithms = BaseAlgorithm.registered_objects() Example with auto-discovery: - ```python - class TokenProposal(ClassRegistryMixin): - registry_auto_discovery = True - auto_package = "speculators.proposals" - - # This will automatically import all modules in the proposals package - # and register any classes decorated with @TokenProposal.register() - proposals = TokenProposal.registered_classes() - ``` - - :cvar registry: A dictionary mapping class names to classes that have been - registered to the extending subclass through the @subclass.register() decorator - :cvar registry_auto_discovery: A flag that enables automatic discovery and import of - modules from the auto_package when set to True. Default is False. - :cvar registry_populated: A flag that tracks whether the registry has been - populated with classes from the specified package(s). + :: + class TokenProposal(RegistryMixin): + registry_auto_discovery = True + auto_package = "mypackage.proposals" + + # Automatically imports and registers decorated objects + proposals = TokenProposal.registered_objects() + + :cvar registry: Dictionary mapping names to registered objects + :cvar registry_auto_discovery: Enable automatic package-based discovery + :cvar registry_populated: Track whether auto-discovery has completed """ - registry: ClassVar[Optional[dict[str, type[Any]]]] = None + registry: ClassVar[dict[str, RegistryObjT] | None] = None # type: ignore[misc] registry_auto_discovery: ClassVar[bool] = False registry_populated: ClassVar[bool] = False @classmethod - def register(cls, name: Optional[str] = None) -> Callable[[type[Any]], type[Any]]: + def register( + cls, name: str | list[str] | None = None + ) -> Callable[[RegisterT], RegisterT]: """ - An invoked class decorator that registers that class with the registry under - either the provided name or the class name if no name is provided. - - Example: - ```python - @ClassRegistryMixin.register() - class ExampleClass: - ... - - @ClassRegistryMixin.register("custom_name") - class AnotherExampleClass: - ... - ``` - - :param name: Optional name to register the class under. If None, the class name - is used as the registry key. - :return: A decorator function that registers the decorated class. - :raises ValueError: If name is provided but is not a string. + Decorator for registering objects with the registry. + + :param name: Optional name(s) to register the object under. + If None, uses the object's __name__ attribute + :return: Decorator function that registers the decorated object + :raises ValueError: If name is not a string, list of strings, or None """ - if name is not None and not isinstance(name, str): - raise ValueError( - "ClassRegistryMixin.register() name must be a string or None. " - f"Got {name}." - ) - return lambda subclass: cls.register_decorator(subclass, name=name) + def _decorator(obj: RegisterT) -> RegisterT: + cls.register_decorator(obj, name=name) + return obj + + return _decorator @classmethod def register_decorator( - cls, clazz: type[Any], name: Optional[str] = None - ) -> type[Any]: - """ - A non-invoked class decorator that registers the class with the registry. - If passed through a lambda, then name can be passed in as well. - Otherwise, the only argument is the decorated class. - - Example: - ```python - @ClassRegistryMixin.register_decorator - class ExampleClass: - ... - ``` - - :param clazz: The class to register - :param name: Optional name to register the class under. If None, the class name - is used as the registry key. - :return: The registered class. - :raises TypeError: If the decorator is used incorrectly or if the class is not - a type. - :raises ValueError: If the class is already registered or if name is provided - but is not a string. + cls, obj: RegisterT, name: str | list[str] | None = None + ) -> RegisterT: """ + Register an object directly with the registry. - if not isinstance(clazz, type): - raise TypeError( - "ClassRegistryMixin.register_decorator must be used as a class " - "decorator and without invocation." - f"Got improper clazz arg {clazz}." - ) + :param obj: The object to register + :param name: Optional name(s) to register the object under. + If None, uses the object's __name__ attribute + :return: The registered object + :raises ValueError: If the object is already registered or name is invalid + """ - if not name: - name = clazz.__name__ - elif not isinstance(name, str): + if name is None: + name = obj.__name__ if hasattr(obj, "__name__") else str(obj) + elif not isinstance(name, (str, list)): raise ValueError( - "ClassRegistryMixin.register_decorator must be used as a class " - "decorator and without invocation. " - f"Got imporoper name arg {name}." + "RegistryMixin.register_decorator name must be a string or " + f"an iterable of strings. Got {name}." ) if cls.registry is None: cls.registry = {} - if name in cls.registry: - raise ValueError( - f"ClassRegistryMixin.register_decorator cannot register a class " - f"{clazz} with the name {name} because it is already registered." - ) + names = [name] if isinstance(name, str) else list(name) + + for register_name in names: + if not isinstance(register_name, str): + raise ValueError( + "RegistryMixin.register_decorator name must be a string or " + f"a list of strings. Got {register_name}." + ) + + if register_name in cls.registry: + raise ValueError( + f"RegistryMixin.register_decorator cannot register an object " + f"{obj} with the name {register_name} because it is already " + "registered." + ) - cls.registry[name] = clazz + cls.registry[register_name] = cast("RegistryObjT", obj) - return clazz + return obj @classmethod def auto_populate_registry(cls) -> bool: """ - Ensures that all modules in the specified auto_package are imported. + Import and register all modules from the auto_package. - This method is called automatically by registered_classes when - registry_auto_discovery==True to ensure that all available implementations are - discovered and registered before returning the list of registered classes. + Automatically called by registered_objects when registry_auto_discovery is True + to ensure all available implementations are discovered. - To enable auto-discovery: - 1. Set registry_auto_discovery = True on the class - 2. Define an auto_package class variable with the package path to import - - :return: True if the registry was populated, False if it was already populated. + :return: True if registry was populated, False if already populated :raises ValueError: If called when registry_auto_discovery is False """ if not cls.registry_auto_discovery: raise ValueError( - "ClassRegistryMixin.auto_populate_registry() cannot be called " + "RegistryMixin.auto_populate_registry() cannot be called " "because registry_auto_discovery is set to False. " "Set registry_auto_discovery to True to enable auto-discovery." ) @@ -197,26 +157,62 @@ def auto_populate_registry(cls) -> bool: return True @classmethod - def registered_classes(cls) -> tuple[type[Any], ...]: + def registered_objects(cls) -> tuple[RegistryObjT, ...]: """ - Returns a tuple of all classes that have been registered with this registry. + Get all registered objects from the registry. - If registry_auto_discovery is True, this method will first call - auto_populate_registry to ensure that all available implementations from - the specified auto_package are discovered and registered before returning - the list. + Automatically triggers auto-discovery if registry_auto_discovery is enabled + to ensure all available implementations are included. - :return: A tuple containing all registered class implementations, including - those discovered through auto-importing when registry_auto_discovery==True. - :raises ValueError: If called before any classes have been registered. + :return: Tuple of all registered objects including auto-discovered ones + :raises ValueError: If called before any objects have been registered """ if cls.registry_auto_discovery: cls.auto_populate_registry() if cls.registry is None: raise ValueError( - "ClassRegistryMixin.registered_classes() must be called after " - "registering classes with ClassRegistryMixin.register()." + "RegistryMixin.registered_objects() must be called after " + "registering objects with RegistryMixin.register()." ) return tuple(cls.registry.values()) + + @classmethod + def is_registered(cls, name: str) -> bool: + """ + Check if an object is registered under the given name. + It matches first by exact name, then by str.lower(). + + :param name: The name to check for registration. + :return: True if the object is registered, False otherwise. + """ + if cls.registry is None: + return False + + return name in cls.registry or name.lower() in [ + key.lower() for key in cls.registry + ] + + @classmethod + def get_registered_object(cls, name: str) -> RegistryObjT | None: + """ + Get a registered object by its name. It matches first by exact name, + then by str.lower(). + + :param name: The name of the registered object. + :return: The registered object if found, None otherwise. + """ + if cls.registry is None: + return None + + if name in cls.registry: + return cls.registry[name] + + lower_key_map = {key.lower(): key for key in cls.registry} + + return ( + cls.registry[lower_key_map[name.lower()]] + if name.lower() in lower_key_map + else None + ) diff --git a/tests/unit/models/test_eagle_config.py b/tests/unit/models/test_eagle_config.py index f89fd7ec..745dfc70 100644 --- a/tests/unit/models/test_eagle_config.py +++ b/tests/unit/models/test_eagle_config.py @@ -324,9 +324,9 @@ def test_eagle_speculator_config_auto_registry(): assert "EagleSpeculatorConfig" in class_names # Verify registry key mapping - assert SpeculatorModelConfig.registry is not None - assert "eagle" in SpeculatorModelConfig.registry - assert SpeculatorModelConfig.registry["eagle"] == EagleSpeculatorConfig + assert SpeculatorModelConfig.registry is not None # type: ignore[misc] + assert "eagle" in SpeculatorModelConfig.registry # type: ignore[misc] + assert SpeculatorModelConfig.registry["eagle"] == EagleSpeculatorConfig # type: ignore[misc] @pytest.mark.smoke diff --git a/tests/unit/models/test_eagle_model.py b/tests/unit/models/test_eagle_model.py index 4a1f73d0..afde0440 100644 --- a/tests/unit/models/test_eagle_model.py +++ b/tests/unit/models/test_eagle_model.py @@ -215,9 +215,9 @@ def test_eagle_speculator_class_attributes(): @pytest.mark.smoke def test_eagle_speculator_registry(): - assert SpeculatorModel.registry is not None - assert "eagle" in SpeculatorModel.registry - assert SpeculatorModel.registry["eagle"] == EagleSpeculator + assert SpeculatorModel.registry is not None # type: ignore[misc] + assert "eagle" in SpeculatorModel.registry # type: ignore[misc] + assert SpeculatorModel.registry["eagle"] == EagleSpeculator # type: ignore[misc] @pytest.mark.smoke diff --git a/tests/unit/test_model.py b/tests/unit/test_model.py index 7fb8f441..acca4bca 100644 --- a/tests/unit/test_model.py +++ b/tests/unit/test_model.py @@ -96,9 +96,9 @@ def test_speculator_model_class_attributes(): @pytest.mark.smoke def test_speculator_model_registry_contains_test_model(): - assert SpeculatorModel.registry is not None - assert "test_speculator" in SpeculatorModel.registry - assert SpeculatorModel.registry["test_speculator"] == SpeculatorTestModel + assert SpeculatorModel.registry is not None # type: ignore[misc] + assert "test_speculator" in SpeculatorModel.registry # type: ignore[misc] + assert SpeculatorModel.registry["test_speculator"] == SpeculatorTestModel # type: ignore[misc] @pytest.mark.smoke diff --git a/tests/unit/utils/test_auto_importer.py b/tests/unit/utils/test_auto_importer.py index 77640a8d..ee3d68f9 100644 --- a/tests/unit/utils/test_auto_importer.py +++ b/tests/unit/utils/test_auto_importer.py @@ -1,196 +1,269 @@ """ -Unit tests for the auto_importer module in the Speculators library. +Unit tests for the auto_importer module. """ +from __future__ import annotations + from unittest import mock import pytest -from speculators.utils.auto_importer import AutoImporterMixin - -# ===== Basic Functionality Tests ===== - - -@pytest.mark.smoke -def test_auto_importer_initialization(): - class TestAutoImporterClass(AutoImporterMixin): - auto_package = "test_package.modules" - - assert AutoImporterMixin.auto_package is None - assert AutoImporterMixin.auto_ignore_modules is None - assert AutoImporterMixin.auto_imported_modules is None - - -@pytest.mark.smoke -def test_auto_importer_subclass_attributes(): - class TestAutoImporterClass(AutoImporterMixin): - auto_package = "test_package.modules" - - assert TestAutoImporterClass.auto_package == "test_package.modules" - assert TestAutoImporterClass.auto_ignore_modules is None - assert TestAutoImporterClass.auto_imported_modules is None - - -@pytest.mark.smoke -def test_no_package_raises_error(): - class TestAutoImporterClass(AutoImporterMixin): ... - - with pytest.raises(ValueError) as exc_info: - TestAutoImporterClass.auto_import_package_modules() +from speculators.utils import AutoImporterMixin + + +class TestAutoImporterMixin: + """Test suite for AutoImporterMixin functionality.""" + + @pytest.fixture( + params=[ + { + "auto_package": "test.package", + "auto_ignore_modules": None, + "modules": [ + ("test.package.module1", False), + ("test.package.module2", False), + ], + "expected_imports": ["test.package.module1", "test.package.module2"], + }, + { + "auto_package": ("test.package1", "test.package2"), + "auto_ignore_modules": None, + "modules": [ + ("test.package1.moduleA", False), + ("test.package2.moduleB", False), + ], + "expected_imports": ["test.package1.moduleA", "test.package2.moduleB"], + }, + { + "auto_package": "test.package", + "auto_ignore_modules": ("test.package.module1",), + "modules": [ + ("test.package.module1", False), + ("test.package.module2", False), + ], + "expected_imports": ["test.package.module2"], + }, + ], + ids=["single_package", "multiple_packages", "ignored_modules"], + ) + def valid_instances(self, request): + """Fixture providing test data for AutoImporterMixin subclasses.""" + config = request.param + + class TestClass(AutoImporterMixin): + auto_package = config["auto_package"] + auto_ignore_modules = config["auto_ignore_modules"] + + return TestClass, config + + @pytest.mark.smoke + def test_class_signatures(self): + """Test AutoImporterMixin class signatures and attributes.""" + assert hasattr(AutoImporterMixin, "auto_package") + assert hasattr(AutoImporterMixin, "auto_ignore_modules") + assert hasattr(AutoImporterMixin, "auto_imported_modules") + assert hasattr(AutoImporterMixin, "auto_import_package_modules") + assert callable(AutoImporterMixin.auto_import_package_modules) + + # Test default class variables + assert AutoImporterMixin.auto_package is None + assert AutoImporterMixin.auto_ignore_modules is None + assert AutoImporterMixin.auto_imported_modules is None + + @pytest.mark.smoke + def test_initialization(self, valid_instances): + """Test AutoImporterMixin subclass initialization.""" + test_class, config = valid_instances + assert issubclass(test_class, AutoImporterMixin) + assert test_class.auto_package == config["auto_package"] + assert test_class.auto_ignore_modules == config["auto_ignore_modules"] + assert test_class.auto_imported_modules is None + + @pytest.mark.sanity + def test_invalid_initialization_missing(self): + """Test AutoImporterMixin with missing auto_package.""" + + class TestClass(AutoImporterMixin): + pass + + with pytest.raises(ValueError, match="auto_package.*must be set"): + TestClass.auto_import_package_modules() + + @pytest.mark.smoke + @mock.patch("importlib.import_module") + @mock.patch("pkgutil.walk_packages") + def test_auto_import_package_modules(self, mock_walk, mock_import, valid_instances): + """Test auto_import_package_modules core functionality.""" + test_class, config = valid_instances + + # Setup mocks based on config + packages = {} + modules = {} + + if isinstance(config["auto_package"], tuple): + for pkg in config["auto_package"]: + pkg_path = pkg.replace(".", "/") + packages[pkg] = MockHelper.create_mock_package(pkg, pkg_path) + else: + pkg = config["auto_package"] + packages[pkg] = MockHelper.create_mock_package(pkg, pkg.replace(".", "/")) + + for module_name, is_pkg in config["modules"]: + if not is_pkg: + modules[module_name] = MockHelper.create_mock_module(module_name) + + mock_import.side_effect = lambda name: {**packages, **modules}.get( + name, mock.MagicMock() + ) + + def walk_side_effect(path, prefix): + return [ + (None, module_name, is_pkg) + for module_name, is_pkg in config["modules"] + if module_name.startswith(prefix) + ] + + mock_walk.side_effect = walk_side_effect + + # Execute + test_class.auto_import_package_modules() + + # Verify + assert test_class.auto_imported_modules == config["expected_imports"] + + # Verify package imports + if isinstance(config["auto_package"], tuple): + for pkg in config["auto_package"]: + mock_import.assert_any_call(pkg) + else: + mock_import.assert_any_call(config["auto_package"]) + + # Verify expected module imports + for expected_module in config["expected_imports"]: + mock_import.assert_any_call(expected_module) + + @pytest.mark.sanity + @mock.patch("importlib.import_module") + @mock.patch("pkgutil.walk_packages") + def test_auto_import_package_modules_invalid(self, mock_walk, mock_import): + """Test auto_import_package_modules with invalid configurations.""" + + class TestClass(AutoImporterMixin): + auto_package = "test.package" + + # Test import error handling + mock_import.side_effect = ImportError("Module not found") + + with pytest.raises(ImportError): + TestClass.auto_import_package_modules() + + @pytest.mark.sanity + @mock.patch("importlib.import_module") + @mock.patch("pkgutil.walk_packages") + def test_skip_packages(self, mock_walk, mock_import): + """Test that packages (is_pkg=True) are skipped.""" + + class TestClass(AutoImporterMixin): + auto_package = "test.package" + + # Setup mocks + mock_package = MockHelper.create_mock_package("test.package", "test/package") + mock_module = MockHelper.create_mock_module("test.package.module") + + mock_import.side_effect = lambda name: { + "test.package": mock_package, + "test.package.module": mock_module, + }[name] + + mock_walk.return_value = [ + (None, "test.package.subpackage", True), + (None, "test.package.module", False), + ] - assert "auto_package" in str(exc_info.value) - assert "must be set" in str(exc_info.value) + # Execute + TestClass.auto_import_package_modules() + + # Verify + assert TestClass.auto_imported_modules == ["test.package.module"] + mock_import.assert_any_call("test.package.module") + # subpackage should not be imported + with pytest.raises(AssertionError): + mock_import.assert_any_call("test.package.subpackage") + + @pytest.mark.sanity + @mock.patch("sys.modules", {"test.package.existing": mock.MagicMock()}) + @mock.patch("importlib.import_module") + @mock.patch("pkgutil.walk_packages") + def test_skip_already_imported_modules(self, mock_walk, mock_import): + """Test that modules already in sys.modules are tracked but not re-imported.""" + + class TestClass(AutoImporterMixin): + auto_package = "test.package" + + # Setup mocks + mock_package = MockHelper.create_mock_package("test.package", "test/package") + mock_import.side_effect = lambda name: { + "test.package": mock_package, + }.get(name, mock.MagicMock()) + + mock_walk.return_value = [ + (None, "test.package.existing", False), + ] + # Execute + TestClass.auto_import_package_modules() -# ===== Module Import Tests ===== + # Verify + assert TestClass.auto_imported_modules == ["test.package.existing"] + mock_import.assert_called_once_with("test.package") + with pytest.raises(AssertionError): + mock_import.assert_any_call("test.package.existing") + @pytest.mark.sanity + @mock.patch("importlib.import_module") + @mock.patch("pkgutil.walk_packages") + def test_prevent_duplicate_module_imports(self, mock_walk, mock_import): + """Test that modules already in auto_imported_modules are not re-imported.""" -@pytest.mark.smoke -def test_single_package_import(): - class TestAutoImporterClass(AutoImporterMixin): - auto_package = "test_package.modules" + class TestClass(AutoImporterMixin): + auto_package = "test.package" - with ( - mock.patch("pkgutil.walk_packages") as mock_walk, - mock.patch("importlib.import_module") as mock_import, - ): - # Create a mock package with the necessary attributes - mock_package = mock.MagicMock() - mock_package.__path__ = ["test_package/modules"] - mock_package.__name__ = "test_package.modules" + # Setup mocks + mock_package = MockHelper.create_mock_package("test.package", "test/package") + mock_module = MockHelper.create_mock_module("test.package.module") - def import_module(name: str): - if name == "test_package.modules": - return mock_package - elif name == "test_package.modules.module1": - module = mock.MagicMock() - module.__name__ = "test_package.modules.module1" - return module - elif name == "test_package.modules.module2": - module = mock.MagicMock() - module.__name__ = "test_package.modules.module2" - return module - else: - raise ImportError(f"No module named {name}") + mock_import.side_effect = lambda name: { + "test.package": mock_package, + "test.package.module": mock_module, + }[name] - def walk_packages(package_path, package_name): - if package_name == "test_package.modules.": - return [ - (None, "test_package.modules.module1", False), - (None, "test_package.modules.module2", False), - ] - else: - raise ValueError(f"Unknown package: {package_name}") + mock_walk.return_value = [ + (None, "test.package.module", False), + (None, "test.package.module", False), + ] - mock_walk.side_effect = walk_packages - mock_import.side_effect = import_module - TestAutoImporterClass.auto_import_package_modules() + # Execute + TestClass.auto_import_package_modules() - mock_import.assert_any_call("test_package.modules") - assert TestAutoImporterClass.auto_imported_modules == [ - "test_package.modules.module1", - "test_package.modules.module2", - ] + # Verify + assert TestClass.auto_imported_modules == ["test.package.module"] + assert mock_import.call_count == 2 # Package + module (not duplicate) -@pytest.mark.sanity -def test_multiple_package_import(): - class TestAutoImporterClass(AutoImporterMixin): - auto_package = ("test_package.modules1", "test_package.modules2") - - with ( - mock.patch("pkgutil.walk_packages") as mock_walk, - mock.patch("importlib.import_module") as mock_import, - ): - # Create a mock package with the necessary attributes - mock_package1 = mock.MagicMock() - mock_package1.__path__ = ["test_package/modules1"] - mock_package1.__name__ = "test_package.modules1" - - mock_package2 = mock.MagicMock() - mock_package2.__path__ = ["test_package/modules2"] - mock_package2.__name__ = "test_package.modules2" - - def import_module(name: str): - if name == "test_package.modules1": - return mock_package1 - elif name == "test_package.modules2": - return mock_package2 - elif name == "test_package.modules1.moduleA": - module = mock.MagicMock() - module.__name__ = "test_package.modules1.moduleA" - return module - elif name == "test_package.modules2.moduleB": - module = mock.MagicMock() - module.__name__ = "test_package.modules2.moduleB" - return module - else: - raise ImportError(f"No module named {name}") - - def walk_packages(package_path, package_name): - if package_name == "test_package.modules1.": - return [ - (None, "test_package.modules1.moduleA", False), - ] - elif package_name == "test_package.modules2.": - return [ - (None, "test_package.modules2.moduleB", False), - ] - else: - raise ValueError(f"Unknown package: {package_name}") - - mock_walk.side_effect = walk_packages - mock_import.side_effect = import_module - TestAutoImporterClass.auto_import_package_modules() - - assert TestAutoImporterClass.auto_imported_modules == [ - "test_package.modules1.moduleA", - "test_package.modules2.moduleB", - ] +class MockHelper: + """Helper class to create consistent mock objects for testing.""" + @staticmethod + def create_mock_package(name: str, path: str): + """Create a mock package with required attributes.""" + package = mock.MagicMock() + package.__name__ = name + package.__path__ = [path] + return package -@pytest.mark.sanity -def test_ignore_modules(): - class TestAutoImporterClass(AutoImporterMixin): - auto_package = "test_package.modules" - auto_ignore_modules = ("test_package.modules.module1",) - - with ( - mock.patch("pkgutil.walk_packages") as mock_walk, - mock.patch("importlib.import_module") as mock_import, - ): - # Create a mock package with the necessary attributes - mock_package = mock.MagicMock() - mock_package.__path__ = ["test_package/modules"] - mock_package.__name__ = "test_package.modules" - - def import_module(name: str): - if name == "test_package.modules": - return mock_package - elif name == "test_package.modules.module1": - module = mock.MagicMock() - module.__name__ = "test_package.modules.module1" - return module - elif name == "test_package.modules.module2": - module = mock.MagicMock() - module.__name__ = "test_package.modules.module2" - return module - else: - raise ImportError(f"No module named {name}") - - def walk_packages(package_path, package_name): - if package_name == "test_package.modules.": - return [ - (None, "test_package.modules.module1", False), - (None, "test_package.modules.module2", False), - ] - else: - raise ValueError(f"Unknown package: {package_name}") - - mock_walk.side_effect = walk_packages - mock_import.side_effect = import_module - TestAutoImporterClass.auto_import_package_modules() - - assert TestAutoImporterClass.auto_imported_modules == [ - "test_package.modules.module2", - ] + @staticmethod + def create_mock_module(name: str): + """Create a mock module with required attributes.""" + module = mock.MagicMock() + module.__name__ = name + return module diff --git a/tests/unit/utils/test_pydantic_utils.py b/tests/unit/utils/test_pydantic_utils.py index de05dadc..7d22f352 100644 --- a/tests/unit/utils/test_pydantic_utils.py +++ b/tests/unit/utils/test_pydantic_utils.py @@ -1,245 +1,651 @@ """ -Unit tests for the pydantic_utils module in the Speculators library. +Unit tests for the pydantic_utils module. """ -from typing import ClassVar +from __future__ import annotations + +from typing import ClassVar, TypeVar from unittest import mock import pytest -from pydantic import BaseModel +from pydantic import BaseModel, Field, ValidationError from speculators.utils import PydanticClassRegistryMixin, ReloadableBaseModel - -# ===== ReloadableBaseModel Tests ===== - - -@pytest.mark.smoke -def test_reloadable_base_model_initialization(): - class TestModel(ReloadableBaseModel): - name: str - - model = TestModel(name="test") - assert model.name == "test" - - -@pytest.mark.smoke -def test_reloadable_base_model_reload_schema(): - class TestModel(ReloadableBaseModel): - name: str - - model = TestModel(name="test") - assert model.name == "test" - - # Mock the model_rebuild method to simulate schema reload - with mock.patch.object(TestModel, "model_rebuild") as mock_rebuild: - TestModel.reload_schema() - mock_rebuild.assert_called_once() - - -# ===== PydanticClassRegistryMixin Tests ===== - - -@pytest.mark.smoke -def test_pydantic_class_registry_subclass_init(): - class TestBaseModel(PydanticClassRegistryMixin): - schema_discriminator: ClassVar[str] = "test_type" - test_type: str - - @classmethod - def __pydantic_schema_base_type__(cls) -> type["TestBaseModel"]: - return cls - - assert TestBaseModel.registry is None - assert TestBaseModel.schema_discriminator == "test_type" - - -@pytest.mark.smoke -def test_pydantic_class_registry_subclass_missing_base_type(): - class InvalidBaseModel(PydanticClassRegistryMixin): - schema_discriminator: ClassVar[str] = "test_type" - test_type: str - - with pytest.raises(TypeError): - InvalidBaseModel(test_type="test") # type: ignore[abstract] - - -@pytest.mark.sanity -def test_pydantic_class_registry_decorator(): - class TestBaseModel(PydanticClassRegistryMixin): - schema_discriminator: ClassVar[str] = "test_type" - test_type: str - - @classmethod - def __pydantic_schema_base_type__(cls) -> type["TestBaseModel"]: - if cls.__name__ == "TestBaseModel": - return cls - return TestBaseModel - - @TestBaseModel.register() - class TestSubModel(TestBaseModel): - test_type: str = "TestSubModel" - value: str - - assert TestBaseModel.registry is not None - assert "TestSubModel" in TestBaseModel.registry - assert TestBaseModel.registry["TestSubModel"] is TestSubModel - - -@pytest.mark.sanity -def test_pydantic_class_registry_decorator_with_name(): - class TestBaseModel(PydanticClassRegistryMixin): - schema_discriminator: ClassVar[str] = "test_type" - test_type: str - - @classmethod - def __pydantic_schema_base_type__(cls) -> type["TestBaseModel"]: - if cls.__name__ == "TestBaseModel": - return cls - return TestBaseModel - - @TestBaseModel.register("custom_name") - class TestSubModel(TestBaseModel): - test_type: str = "custom_name" - value: str - - assert TestBaseModel.registry is not None - assert "custom_name" in TestBaseModel.registry - assert TestBaseModel.registry["custom_name"] is TestSubModel - - -@pytest.mark.smoke -def test_pydantic_class_registry_decorator_invalid_type(): - class TestBaseModel(PydanticClassRegistryMixin): - schema_discriminator: ClassVar[str] = "test_type" - test_type: str - - @classmethod - def __pydantic_schema_base_type__(cls) -> type["TestBaseModel"]: - if cls.__name__ == "TestBaseModel": - return cls - return TestBaseModel - - class RegularClass: - pass - - with pytest.raises(TypeError) as exc_info: - TestBaseModel.register_decorator(RegularClass) # type: ignore[arg-type] - - assert "not a subclass of Pydantic BaseModel" in str(exc_info.value) +from speculators.utils.pydantic_utils import BaseModelT, RegisterClassT @pytest.mark.smoke -def test_pydantic_class_registry_subclass_marshalling(): - class TestBaseModel(PydanticClassRegistryMixin): - schema_discriminator: ClassVar[str] = "test_type" - test_type: str - - @classmethod - def __pydantic_schema_base_type__(cls) -> type["TestBaseModel"]: - if cls.__name__ == "TestBaseModel": - return cls - return TestBaseModel - - @TestBaseModel.register("test_sub") - class TestSubModel(TestBaseModel): - test_type: str = "test_sub" - value: str - - TestBaseModel.reload_schema() - - # Test direct construction of subclass - sub_instance = TestSubModel(value="test_value") - assert isinstance(sub_instance, TestSubModel) - assert sub_instance.test_type == "test_sub" - assert sub_instance.value == "test_value" - - # Test serialization with model_dump - dump_data = sub_instance.model_dump() - assert isinstance(dump_data, dict) - assert dump_data["test_type"] == "test_sub" - assert dump_data["value"] == "test_value" - - # Test deserialization via model_validate - recreated = TestSubModel.model_validate(dump_data) - assert isinstance(recreated, TestSubModel) - assert recreated.test_type == "test_sub" - assert recreated.value == "test_value" - - # Test polymorphic deserialization via base class - recreated = TestBaseModel.model_validate(dump_data) # type: ignore[assignment] - assert isinstance(recreated, TestSubModel) - assert recreated.test_type == "test_sub" - assert recreated.value == "test_value" +def test_base_model_t(): + """Test that BaseModelT is configured correctly as a TypeVar.""" + assert isinstance(BaseModelT, type(TypeVar("test"))) + assert BaseModelT.__name__ == "BaseModelT" + assert BaseModelT.__bound__ is BaseModel + assert BaseModelT.__constraints__ == () @pytest.mark.smoke -def test_pydantic_class_registry_parent_class_marshalling(): - class TestBaseModel(PydanticClassRegistryMixin): - schema_discriminator: ClassVar[str] = "test_type" - test_type: str - - @classmethod - def __pydantic_schema_base_type__(cls) -> type["TestBaseModel"]: - if cls.__name__ == "TestBaseModel": - return cls - return TestBaseModel - - @classmethod - def __pydantic_generate_base_schema__(cls, handler): - return handler(cls) - - @TestBaseModel.register("sub_a") - class TestSubModelA(TestBaseModel): - test_type: str = "sub_a" - value_a: str - - @TestBaseModel.register("sub_b") - class TestSubModelB(TestBaseModel): - test_type: str = "sub_b" - value_b: int - - class ContainerModel(BaseModel): - name: str - model: TestBaseModel - models: list[TestBaseModel] - - sub_a = TestSubModelA(value_a="test") - sub_b = TestSubModelB(value_b=123) - - container = ContainerModel(name="container", model=sub_a, models=[sub_a, sub_b]) - assert isinstance(container.model, TestSubModelA) - assert container.model.test_type == "sub_a" - assert container.model.value_a == "test" - assert isinstance(container.models[0], TestSubModelA) - assert isinstance(container.models[1], TestSubModelB) - assert container.models[0].test_type == "sub_a" - assert container.models[1].test_type == "sub_b" - assert container.models[0].value_a == "test" - assert container.models[1].value_b == 123 - - # Test serialization with model_dump - dump_data = container.model_dump() - assert isinstance(dump_data, dict) - assert dump_data["name"] == "container" - assert dump_data["model"]["test_type"] == "sub_a" - assert dump_data["model"]["value_a"] == "test" - assert len(dump_data["models"]) == 2 - assert dump_data["models"][0]["test_type"] == "sub_a" - assert dump_data["models"][0]["value_a"] == "test" - assert dump_data["models"][1]["test_type"] == "sub_b" - assert dump_data["models"][1]["value_b"] == 123 - - # Test deserialization via model_validate - recreated = ContainerModel.model_validate(dump_data) - assert isinstance(recreated, ContainerModel) - assert recreated.name == "container" - assert isinstance(recreated.model, TestSubModelA) - assert recreated.model.test_type == "sub_a" - assert recreated.model.value_a == "test" - assert len(recreated.models) == 2 - assert isinstance(recreated.models[0], TestSubModelA) - assert isinstance(recreated.models[1], TestSubModelB) - assert recreated.models[0].test_type == "sub_a" - assert recreated.models[1].test_type == "sub_b" - assert recreated.models[0].value_a == "test" - assert recreated.models[1].value_b == 123 +def test_register_class_t(): + """Test that RegisterClassT is configured correctly as a TypeVar.""" + assert isinstance(RegisterClassT, type(TypeVar("test"))) + assert RegisterClassT.__name__ == "RegisterClassT" + assert RegisterClassT.__bound__ is not None + assert RegisterClassT.__constraints__ == () + + +class TestReloadableBaseModel: + """Test suite for ReloadableBaseModel.""" + + @pytest.fixture( + params=[ + {"name": "test_value"}, + {"name": "hello_world"}, + {"name": "another_test"}, + ], + ids=["basic_string", "multi_word", "underscore"], + ) + def valid_instances(self, request) -> tuple[ReloadableBaseModel, dict[str, str]]: + """Fixture providing test data for ReloadableBaseModel.""" + + class TestModel(ReloadableBaseModel): + name: str + + constructor_args = request.param + instance = TestModel(**constructor_args) + return instance, constructor_args + + @pytest.mark.smoke + def test_class_signatures(self): + """Test ReloadableBaseModel inheritance and class variables.""" + assert issubclass(ReloadableBaseModel, BaseModel) + assert hasattr(ReloadableBaseModel, "model_config") + assert hasattr(ReloadableBaseModel, "reload_schema") + + @pytest.mark.smoke + def test_initialization(self, valid_instances): + """Test ReloadableBaseModel initialization.""" + instance, constructor_args = valid_instances + assert isinstance(instance, ReloadableBaseModel) + assert instance.name == constructor_args["name"] # type: ignore[attr-defined] + + @pytest.mark.sanity + @pytest.mark.parametrize( + ("field", "value"), + [ + ("name", None), + ("name", 123), + ("name", []), + ], + ) + def test_invalid_initialization_values(self, field, value): + """Test ReloadableBaseModel with invalid field values.""" + + class TestModel(ReloadableBaseModel): + name: str + + data = {field: value} + with pytest.raises(ValidationError): + TestModel(**data) + + @pytest.mark.sanity + def test_invalid_initialization_missing(self): + """Test ReloadableBaseModel initialization without required field.""" + + class TestModel(ReloadableBaseModel): + name: str + + with pytest.raises(ValidationError): + TestModel() # type: ignore[call-arg] + + @pytest.mark.smoke + def test_reload_schema(self): + """Test ReloadableBaseModel.reload_schema method.""" + + class TestModel(ReloadableBaseModel): + name: str + + # Mock the model_rebuild method to simulate schema reload + with mock.patch.object(TestModel, "model_rebuild") as mock_rebuild: + TestModel.reload_schema() + mock_rebuild.assert_called_once_with(force=True) + + @pytest.mark.sanity + def test_marshalling(self, valid_instances): + """Test ReloadableBaseModel serialization and deserialization.""" + instance, constructor_args = valid_instances + data_dict = instance.model_dump() + assert isinstance(data_dict, dict) + assert data_dict["name"] == constructor_args["name"] + + recreated = instance.__class__.model_validate(data_dict) + assert isinstance(recreated, instance.__class__) + assert recreated.name == constructor_args["name"] + + +class TestPydanticClassRegistryMixin: + """Test suite for PydanticClassRegistryMixin.""" + + @pytest.fixture( + params=[ + {"test_type": "test_sub", "value": "test_value"}, + {"test_type": "test_sub", "value": "hello_world"}, + ], + ids=["basic_value", "multi_word"], + ) + def valid_instances( + self, request + ) -> tuple[PydanticClassRegistryMixin, dict, type, type]: + """Fixture providing test data for PydanticClassRegistryMixin.""" + + class TestBaseModel(PydanticClassRegistryMixin): + schema_discriminator: ClassVar[str] = "test_type" + test_type: str + + @classmethod + def __pydantic_schema_base_type__(cls) -> type[TestBaseModel]: + if cls.__name__ == "TestBaseModel": + return cls + return TestBaseModel + + @TestBaseModel.register("test_sub") + class TestSubModel(TestBaseModel): + test_type: str = "test_sub" + value: str + + TestBaseModel.reload_schema() + + constructor_args = request.param + instance = TestSubModel(value=constructor_args["value"]) + return instance, constructor_args, TestBaseModel, TestSubModel + + @pytest.mark.smoke + def test_class_signatures(self): + """Test PydanticClassRegistryMixin inheritance and class variables.""" + assert issubclass(PydanticClassRegistryMixin, ReloadableBaseModel) + assert hasattr(PydanticClassRegistryMixin, "schema_discriminator") + assert PydanticClassRegistryMixin.schema_discriminator == "model_type" + assert hasattr(PydanticClassRegistryMixin, "register_decorator") + assert hasattr(PydanticClassRegistryMixin, "__get_pydantic_core_schema__") + assert hasattr(PydanticClassRegistryMixin, "__pydantic_generate_base_schema__") + assert hasattr(PydanticClassRegistryMixin, "auto_populate_registry") + assert hasattr(PydanticClassRegistryMixin, "registered_classes") + + @pytest.mark.smoke + def test_initialization(self, valid_instances): + """Test PydanticClassRegistryMixin initialization.""" + instance, constructor_args, base_class, sub_class = valid_instances + assert isinstance(instance, sub_class) + assert isinstance(instance, base_class) + assert instance.test_type == constructor_args["test_type"] + assert instance.value == constructor_args["value"] + + @pytest.mark.sanity + @pytest.mark.parametrize( + ("field", "value"), + [ + ("test_type", None), + ("test_type", 123), + ("value", None), + ], + ) + def test_invalid_initialization_values(self, field, value): + """Test PydanticClassRegistryMixin with invalid field values.""" + + class TestBaseModel(PydanticClassRegistryMixin): + schema_discriminator: ClassVar[str] = "test_type" + test_type: str + + @classmethod + def __pydantic_schema_base_type__(cls) -> type[TestBaseModel]: + if cls.__name__ == "TestBaseModel": + return cls + return TestBaseModel + + @TestBaseModel.register("test_sub") + class TestSubModel(TestBaseModel): + test_type: str = "test_sub" + value: str + + data = {field: value} + if field == "test_type": + data["value"] = "test" + else: + data["test_type"] = "test_sub" + + with pytest.raises(ValidationError): + TestSubModel(**data) + + @pytest.mark.sanity + def test_invalid_initialization_missing(self): + """Test PydanticClassRegistryMixin initialization without required field.""" + + class TestBaseModel(PydanticClassRegistryMixin): + schema_discriminator: ClassVar[str] = "test_type" + test_type: str + + @classmethod + def __pydantic_schema_base_type__(cls) -> type[TestBaseModel]: + if cls.__name__ == "TestBaseModel": + return cls + return TestBaseModel + + @TestBaseModel.register("test_sub") + class TestSubModel(TestBaseModel): + test_type: str = "test_sub" + value: str + + with pytest.raises(ValidationError): + TestSubModel() # type: ignore[call-arg] + + @pytest.mark.smoke + def test_register_decorator(self): + """Test PydanticClassRegistryMixin.register_decorator method.""" + + class TestBaseModel(PydanticClassRegistryMixin): + schema_discriminator: ClassVar[str] = "test_type" + test_type: str + + @classmethod + def __pydantic_schema_base_type__(cls) -> type[TestBaseModel]: + if cls.__name__ == "TestBaseModel": + return cls + return TestBaseModel + + @TestBaseModel.register() + class TestSubModel(TestBaseModel): + test_type: str = "TestSubModel" + value: str + + assert TestBaseModel.registry is not None # type: ignore[misc] + assert "TestSubModel" in TestBaseModel.registry # type: ignore[misc] + assert TestBaseModel.registry["TestSubModel"] is TestSubModel # type: ignore[misc] + + @pytest.mark.sanity + def test_register_decorator_with_name(self): + """Test PydanticClassRegistryMixin.register_decorator with custom name.""" + + class TestBaseModel(PydanticClassRegistryMixin): + schema_discriminator: ClassVar[str] = "test_type" + test_type: str + + @classmethod + def __pydantic_schema_base_type__(cls) -> type[TestBaseModel]: + if cls.__name__ == "TestBaseModel": + return cls + return TestBaseModel + + @TestBaseModel.register("custom_name") + class TestSubModel(TestBaseModel): + test_type: str = "custom_name" + value: str + + assert TestBaseModel.registry is not None # type: ignore[misc] + assert "custom_name" in TestBaseModel.registry # type: ignore[misc] + assert TestBaseModel.registry["custom_name"] is TestSubModel # type: ignore[misc] + + @pytest.mark.sanity + def test_register_decorator_invalid_type(self): + """Test PydanticClassRegistryMixin.register_decorator with invalid type.""" + + class TestBaseModel(PydanticClassRegistryMixin): + schema_discriminator: ClassVar[str] = "test_type" + test_type: str + + @classmethod + def __pydantic_schema_base_type__(cls) -> type[TestBaseModel]: + if cls.__name__ == "TestBaseModel": + return cls + return TestBaseModel + + class RegularClass: + pass + + with pytest.raises(TypeError) as exc_info: + TestBaseModel.register_decorator(RegularClass) # type: ignore[type-var] + + assert "not a subclass of Pydantic BaseModel" in str(exc_info.value) + + @pytest.mark.smoke + def test_auto_populate_registry(self): + """Test PydanticClassRegistryMixin.auto_populate_registry method.""" + + class TestBaseModel(PydanticClassRegistryMixin): + schema_discriminator: ClassVar[str] = "test_type" + test_type: str + registry_auto_discovery: ClassVar[bool] = True + + @classmethod + def __pydantic_schema_base_type__(cls) -> type[TestBaseModel]: + if cls.__name__ == "TestBaseModel": + return cls + return TestBaseModel + + with ( + mock.patch.object(TestBaseModel, "reload_schema") as mock_reload, + mock.patch( + "speculators.utils.registry.RegistryMixin.auto_populate_registry", + return_value=True, + ), + ): + result = TestBaseModel.auto_populate_registry() + assert result is True + mock_reload.assert_called_once() + + @pytest.mark.smoke + def test_registered_classes(self): + """Test PydanticClassRegistryMixin.registered_classes method.""" + + class TestBaseModel(PydanticClassRegistryMixin): + schema_discriminator: ClassVar[str] = "test_type" + test_type: str + registry_auto_discovery: ClassVar[bool] = False + + @classmethod + def __pydantic_schema_base_type__(cls) -> type[TestBaseModel]: + if cls.__name__ == "TestBaseModel": + return cls + return TestBaseModel + + @TestBaseModel.register("test_sub_a") + class TestSubModelA(TestBaseModel): + test_type: str = "test_sub_a" + value_a: str + + @TestBaseModel.register("test_sub_b") + class TestSubModelB(TestBaseModel): + test_type: str = "test_sub_b" + value_b: int + + # Test normal case with registered classes + registered = TestBaseModel.registered_classes() + assert isinstance(registered, tuple) + assert len(registered) == 2 + assert TestSubModelA in registered + assert TestSubModelB in registered + + @pytest.mark.sanity + def test_registered_classes_with_auto_discovery(self): + """Test PydanticClassRegistryMixin.registered_classes with auto discovery.""" + + class TestBaseModel(PydanticClassRegistryMixin): + schema_discriminator: ClassVar[str] = "test_type" + test_type: str + registry_auto_discovery: ClassVar[bool] = True + + @classmethod + def __pydantic_schema_base_type__(cls) -> type[TestBaseModel]: + if cls.__name__ == "TestBaseModel": + return cls + return TestBaseModel + + with mock.patch.object( + TestBaseModel, "auto_populate_registry" + ) as mock_auto_populate: + # Mock the registry to simulate registered classes + TestBaseModel.registry = {"test_class": type("TestClass", (), {})} # type: ignore[misc] + mock_auto_populate.return_value = False + + registered = TestBaseModel.registered_classes() + mock_auto_populate.assert_called_once() + assert isinstance(registered, tuple) + assert len(registered) == 1 + + @pytest.mark.sanity + def test_registered_classes_no_registry(self): + """Test PydanticClassRegistryMixin.registered_classes with no registry.""" + + class TestBaseModel(PydanticClassRegistryMixin): + schema_discriminator: ClassVar[str] = "test_type" + test_type: str + + @classmethod + def __pydantic_schema_base_type__(cls) -> type[TestBaseModel]: + if cls.__name__ == "TestBaseModel": + return cls + return TestBaseModel + + # Ensure registry is None + TestBaseModel.registry = None # type: ignore[misc] + + with pytest.raises(ValueError) as exc_info: + TestBaseModel.registered_classes() + + assert "must be called after registering classes" in str(exc_info.value) + + @pytest.mark.sanity + def test_marshalling(self, valid_instances): + """Test PydanticClassRegistryMixin serialization and deserialization.""" + instance, constructor_args, base_class, sub_class = valid_instances + + # Test serialization with model_dump + dump_data = instance.model_dump() + assert isinstance(dump_data, dict) + assert dump_data["test_type"] == constructor_args["test_type"] + assert dump_data["value"] == constructor_args["value"] + + # Test deserialization via subclass + recreated = sub_class.model_validate(dump_data) + assert isinstance(recreated, sub_class) + assert recreated.test_type == constructor_args["test_type"] + assert recreated.value == constructor_args["value"] + + # Test polymorphic deserialization via base class + recreated_base = base_class.model_validate(dump_data) # type: ignore[assignment] + assert isinstance(recreated_base, sub_class) + assert recreated_base.test_type == constructor_args["test_type"] + assert recreated_base.value == constructor_args["value"] + + @pytest.mark.regression + def test_polymorphic_container_marshalling(self): + """Test PydanticClassRegistryMixin in container models.""" + + class TestBaseModel(PydanticClassRegistryMixin): + schema_discriminator: ClassVar[str] = "test_type" + test_type: str + + @classmethod + def __pydantic_schema_base_type__(cls) -> type[TestBaseModel]: + if cls.__name__ == "TestBaseModel": + return cls + return TestBaseModel + + @classmethod + def __pydantic_generate_base_schema__(cls, handler): + return handler(cls) + + @TestBaseModel.register("sub_a") + class TestSubModelA(TestBaseModel): + test_type: str = "sub_a" + value_a: str + + @TestBaseModel.register("sub_b") + class TestSubModelB(TestBaseModel): + test_type: str = "sub_b" + value_b: int + + class ContainerModel(BaseModel): + name: str + model: TestBaseModel + models: list[TestBaseModel] + + sub_a = TestSubModelA(value_a="test") + sub_b = TestSubModelB(value_b=123) + + container = ContainerModel(name="container", model=sub_a, models=[sub_a, sub_b]) + + # Verify container construction + assert isinstance(container.model, TestSubModelA) + assert container.model.test_type == "sub_a" + assert container.model.value_a == "test" + assert len(container.models) == 2 + assert isinstance(container.models[0], TestSubModelA) + assert isinstance(container.models[1], TestSubModelB) + + # Test serialization + dump_data = container.model_dump() + assert isinstance(dump_data, dict) + assert dump_data["name"] == "container" + assert dump_data["model"]["test_type"] == "sub_a" + assert dump_data["model"]["value_a"] == "test" + assert len(dump_data["models"]) == 2 + assert dump_data["models"][0]["test_type"] == "sub_a" + assert dump_data["models"][1]["test_type"] == "sub_b" + + # Test deserialization + recreated = ContainerModel.model_validate(dump_data) + assert isinstance(recreated, ContainerModel) + assert recreated.name == "container" + assert isinstance(recreated.model, TestSubModelA) + assert len(recreated.models) == 2 + assert isinstance(recreated.models[0], TestSubModelA) + assert isinstance(recreated.models[1], TestSubModelB) + + @pytest.mark.smoke + def test_register_preserves_pydantic_metadata(self): # noqa: C901 + """Test that registered Pydantic classes retain docs, types, and methods.""" + + class TestBaseModel(PydanticClassRegistryMixin): + schema_discriminator: ClassVar[str] = "model_type" + model_type: str + + @classmethod + def __pydantic_schema_base_type__(cls) -> type[TestBaseModel]: + if cls.__name__ == "TestBaseModel": + return cls + + return TestBaseModel + + @TestBaseModel.register("documented_model") + class DocumentedModel(TestBaseModel): + """This is a documented Pydantic model with methods and type hints.""" + + model_type: str = "documented_model" + value: int = Field(description="An integer value for the model") + + def get_value(self) -> int: + """Get the stored value. + + :return: The stored integer value + """ + return self.value + + def set_value(self, new_value: int) -> None: + """Set a new value. + + :param new_value: The new integer value to set + """ + self.value = new_value + + @classmethod + def from_string(cls, value_str: str) -> DocumentedModel: + """Create instance from string. + + :param value_str: String representation of value + :return: New DocumentedModel instance + """ + return cls(value=int(value_str)) + + @staticmethod + def validate_value(value: int) -> bool: + """Validate that a value is positive. + + :param value: Value to validate + :return: True if positive, False otherwise + """ + return value > 0 + + def model_post_init(self, __context) -> None: + """Post-initialization processing. + + :param __context: Validation context + """ + if self.value < 0: + raise ValueError("Value must be non-negative") + + # Check that the class was registered + assert TestBaseModel.is_registered("documented_model") + registered_class = TestBaseModel.get_registered_object("documented_model") + assert registered_class is DocumentedModel + + # Check that the class retains its documentation + assert registered_class.__doc__ is not None + assert "documented Pydantic model with methods" in registered_class.__doc__ + + # Check that methods retain their documentation + assert registered_class.get_value.__doc__ is not None + assert "Get the stored value" in registered_class.get_value.__doc__ + assert registered_class.set_value.__doc__ is not None + assert "Set a new value" in registered_class.set_value.__doc__ + assert registered_class.from_string.__doc__ is not None + assert "Create instance from string" in registered_class.from_string.__doc__ + assert registered_class.validate_value.__doc__ is not None + assert ( + "Validate that a value is positive" + in registered_class.validate_value.__doc__ + ) + assert registered_class.model_post_init.__doc__ is not None + assert ( + "Post-initialization processing" in registered_class.model_post_init.__doc__ + ) + + # Check that methods are callable and work correctly + instance = DocumentedModel(value=42) + assert isinstance(instance, DocumentedModel) + assert instance.get_value() == 42 + instance.set_value(100) + assert instance.get_value() == 100 + assert instance.model_type == "documented_model" + + # Check class methods work + instance2 = DocumentedModel.from_string("123") + assert instance2.get_value() == 123 + assert instance2.model_type == "documented_model" + + # Check static methods work + assert DocumentedModel.validate_value(10) is True + assert DocumentedModel.validate_value(-5) is False + + # Check that Pydantic functionality is preserved + data_dict = instance.model_dump() + assert data_dict["value"] == 100 + assert data_dict["model_type"] == "documented_model" + + recreated = DocumentedModel.model_validate(data_dict) + assert isinstance(recreated, DocumentedModel) + assert recreated.value == 100 + assert recreated.model_type == "documented_model" + + # Test field validation + with pytest.raises(ValidationError): + DocumentedModel(value="not_an_int") # type: ignore[arg-type] + + # Test post_init validation + with pytest.raises(ValueError, match="Value must be non-negative"): + DocumentedModel(value=-10) + + # Check that Pydantic field metadata is preserved + value_field = DocumentedModel.model_fields["value"] + assert value_field.description == "An integer value for the model" + + # Check that type annotations are preserved (if accessible) + import inspect + + if hasattr(inspect, "get_annotations"): + # Python 3.10+ + try: + annotations = inspect.get_annotations(DocumentedModel.get_value) + return_ann = annotations.get("return") + assert return_ann is int or return_ann == "int" + except (AttributeError, NameError): + # Fallback for older Python or missing annotations + pass + + # Check that the class name is preserved + assert DocumentedModel.__name__ == "DocumentedModel" + assert DocumentedModel.__qualname__.endswith("DocumentedModel") + + # Verify that the class is still properly integrated with the registry system + all_registered = TestBaseModel.registered_classes() + assert DocumentedModel in all_registered + + # Test that the registered class is the same as the original + assert registered_class is DocumentedModel diff --git a/tests/unit/utils/test_registry.py b/tests/unit/utils/test_registry.py index e653c83e..bf78fcf6 100644 --- a/tests/unit/utils/test_registry.py +++ b/tests/unit/utils/test_registry.py @@ -1,310 +1,585 @@ """ -Unit tests for the registry module in the Speculators library. +Unit tests for the registry module. """ +from __future__ import annotations + +import inspect +from typing import TypeVar from unittest import mock import pytest -from speculators.utils.registry import ClassRegistryMixin - -# ===== ClassRegistryMixin Tests ===== - - -@pytest.mark.smoke -def test_class_registry_initialization(): - class TestRegistryClass(ClassRegistryMixin): - pass - - assert TestRegistryClass.registry is None - - -@pytest.mark.smoke -def test_register_with_name(): - class TestRegistryClass(ClassRegistryMixin): - pass - - @TestRegistryClass.register("custom_name") - class TestClass: - pass - - assert TestRegistryClass.registry is not None - assert "custom_name" in TestRegistryClass.registry - assert TestRegistryClass.registry["custom_name"] is TestClass - - -@pytest.mark.smoke -def test_register_without_name(): - class TestRegistryClass(ClassRegistryMixin): - pass - - @TestRegistryClass.register() - class TestClass: - pass - - assert TestRegistryClass.registry is not None - assert "TestClass" in TestRegistryClass.registry - assert TestRegistryClass.registry["TestClass"] is TestClass - - -@pytest.mark.smoke -def test_register_decorator_direct(): - class TestRegistryClass(ClassRegistryMixin): - pass +from speculators.utils import RegistryMixin +from speculators.utils.registry import RegisterT, RegistryObjT + + +def test_registry_obj_type(): + """Test that RegistryObjT is configured correctly as a TypeVar.""" + assert isinstance(RegistryObjT, type(TypeVar("test"))) + assert RegistryObjT.__name__ == "RegistryObjT" + assert RegistryObjT.__bound__ is None + assert RegistryObjT.__constraints__ == () + + +def test_registered_type(): + """Test that RegisterT is configured correctly as a TypeVar.""" + assert isinstance(RegisterT, type(TypeVar("test"))) + assert RegisterT.__name__ == "RegisterT" + assert RegisterT.__bound__ is None + assert RegisterT.__constraints__ == () + + +class TestRegistryMixin: + """Test suite for RegistryMixin class.""" + + @pytest.fixture( + params=[ + {"registry_auto_discovery": False, "auto_package": None}, + {"registry_auto_discovery": True, "auto_package": "test.package"}, + ], + ids=["manual_registry", "auto_discovery"], + ) + def valid_instances(self, request): + """Fixture providing test data for RegistryMixin subclasses.""" + config = request.param + + class TestRegistryClass(RegistryMixin): + registry_auto_discovery = config["registry_auto_discovery"] + if config["auto_package"]: + auto_package = config["auto_package"] + + return TestRegistryClass, config + + @pytest.mark.smoke + def test_class_signatures(self): + """Test RegistryMixin inheritance and exposed methods.""" + assert hasattr(RegistryMixin, "registry") + assert hasattr(RegistryMixin, "registry_auto_discovery") + assert hasattr(RegistryMixin, "registry_populated") + assert hasattr(RegistryMixin, "register") + assert hasattr(RegistryMixin, "register_decorator") + assert hasattr(RegistryMixin, "auto_populate_registry") + assert hasattr(RegistryMixin, "registered_objects") + assert hasattr(RegistryMixin, "is_registered") + assert hasattr(RegistryMixin, "get_registered_object") + + @pytest.mark.smoke + def test_initialization(self, valid_instances): + """Test RegistryMixin initialization.""" + registry_class, config = valid_instances + + assert registry_class.registry is None + assert ( + registry_class.registry_auto_discovery == config["registry_auto_discovery"] + ) + assert registry_class.registry_populated is False + + @pytest.mark.sanity + def test_invalid_initialization_missing(self): + """Test RegistryMixin with missing auto_package when auto_discovery enabled.""" + + class TestRegistryClass(RegistryMixin): + registry_auto_discovery = True + + with pytest.raises(ValueError, match="auto_package.*must be set"): + TestRegistryClass.auto_import_package_modules() + + @pytest.mark.smoke + @pytest.mark.parametrize( + ("name", "expected_key"), + [ + ("custom_name", "custom_name"), + (["name1", "name2"], ["name1", "name2"]), + (None, "TestClass"), + ], + ) + def test_register(self, valid_instances, name, expected_key): + """Test register method with various name configurations.""" + registry_class, _ = valid_instances + + @registry_class.register(name) + class TestClass: + pass - @TestRegistryClass.register_decorator - class TestClass: - pass + assert registry_class.registry is not None + if isinstance(expected_key, list): + for key in expected_key: + assert key in registry_class.registry + assert registry_class.registry[key] is TestClass + else: + assert expected_key in registry_class.registry + assert registry_class.registry[expected_key] is TestClass + + @pytest.mark.sanity + @pytest.mark.parametrize( + "invalid_name", + [123, 42.5, True, {"key": "value"}], + ) + def test_register_invalid(self, valid_instances, invalid_name): + """Test register method with invalid name types.""" + registry_class, _ = valid_instances + + # The register method returns a decorator, so we need to apply it to test + # validation + decorator = registry_class.register(invalid_name) + + class TestClass: + pass - assert TestRegistryClass.registry is not None - assert "TestClass" in TestRegistryClass.registry - assert TestRegistryClass.registry["TestClass"] is TestClass + with pytest.raises( + ValueError, match="name must be a string or an iterable of strings" + ): + decorator(TestClass) + + @pytest.mark.smoke + @pytest.mark.parametrize( + ("name", "expected_key"), + [ + ("custom_name", "custom_name"), + (["name1", "name2"], ["name1", "name2"]), + (None, "TestClass"), + ], + ) + def test_register_decorator(self, valid_instances, name, expected_key): + """Test register_decorator method with various name configurations.""" + registry_class, _ = valid_instances + + class TestClass: + pass + registry_class.register_decorator(TestClass, name=name) + + assert registry_class.registry is not None + if isinstance(expected_key, list): + for key in expected_key: + assert key in registry_class.registry + assert registry_class.registry[key] is TestClass + else: + assert expected_key in registry_class.registry + assert registry_class.registry[expected_key] is TestClass + + @pytest.mark.sanity + @pytest.mark.parametrize( + "invalid_name", + [123, 42.5, True, {"key": "value"}], + ) + def test_register_decorator_invalid(self, valid_instances, invalid_name): + """Test register_decorator with invalid name types.""" + registry_class, _ = valid_instances + + class TestClass: + pass -@pytest.mark.sanity -def test_register_invalid_name_type(): - class TestRegistryClass(ClassRegistryMixin): - pass + with pytest.raises( + ValueError, match="name must be a string or an iterable of strings" + ): + registry_class.register_decorator(TestClass, name=invalid_name) + + @pytest.mark.smoke + def test_auto_populate_registry(self): + """Test auto_populate_registry method with valid configuration.""" + + class TestAutoRegistry(RegistryMixin): + registry_auto_discovery = True + auto_package = "test.package" + + with mock.patch.object( + TestAutoRegistry, "auto_import_package_modules" + ) as mock_import: + result = TestAutoRegistry.auto_populate_registry() + assert result is True + mock_import.assert_called_once() + assert TestAutoRegistry.registry_populated is True + + # Second call should return False + result = TestAutoRegistry.auto_populate_registry() + assert result is False + mock_import.assert_called_once() + + @pytest.mark.sanity + def test_auto_populate_registry_invalid(self): + """Test auto_populate_registry when auto-discovery is disabled.""" + + class TestDisabledRegistry(RegistryMixin): + registry_auto_discovery = False + + with pytest.raises(ValueError, match="registry_auto_discovery is set to False"): + TestDisabledRegistry.auto_populate_registry() + + @pytest.mark.smoke + def test_registered_objects(self, valid_instances): + """Test registered_objects method with manual registration.""" + registry_class, config = valid_instances + + @registry_class.register("class1") + class TestClass1: + pass - with pytest.raises(ValueError) as exc_info: - TestRegistryClass.register(123) # type: ignore[arg-type] + @registry_class.register("class2") + class TestClass2: + pass - assert "name must be a string or None" in str(exc_info.value) + if config["registry_auto_discovery"]: + with mock.patch.object(registry_class, "auto_import_package_modules"): + objects = registry_class.registered_objects() + else: + objects = registry_class.registered_objects() + assert isinstance(objects, tuple) + assert len(objects) == 2 + assert TestClass1 in objects + assert TestClass2 in objects -@pytest.mark.sanity -def test_register_decorator_invalid_class(): - class TestRegistryClass(ClassRegistryMixin): - pass + @pytest.mark.sanity + def test_registered_objects_invalid(self): + """Test registered_objects when no objects are registered.""" - with pytest.raises(TypeError) as exc_info: - TestRegistryClass.register_decorator("not_a_class") # type: ignore[arg-type] + class TestRegistryClass(RegistryMixin): + pass - assert "must be used as a class decorator" in str(exc_info.value) + with pytest.raises( + ValueError, match="must be called after registering objects" + ): + TestRegistryClass.registered_objects() + + @pytest.mark.smoke + @pytest.mark.parametrize( + ("register_name", "check_name", "expected"), + [ + ("test_name", "test_name", True), + ("TestName", "testname", True), + ("UPPERCASE", "uppercase", True), + ("test_name", "nonexistent", False), + ], + ) + def test_is_registered(self, valid_instances, register_name, check_name, expected): + """Test is_registered with various name combinations.""" + registry_class, _ = valid_instances + + @registry_class.register(register_name) + class TestClass: + pass + result = registry_class.is_registered(check_name) + assert result == expected + + @pytest.mark.smoke + @pytest.mark.parametrize( + ("register_name", "lookup_name"), + [ + ("test_name", "test_name"), + ("TestName", "testname"), + ("UPPERCASE", "uppercase"), + ], + ) + def test_get_registered_object(self, valid_instances, register_name, lookup_name): + """Test get_registered_object with valid names.""" + registry_class, _ = valid_instances + + @registry_class.register(register_name) + class TestClass: + pass -@pytest.mark.sanity -def test_register_decorator_invalid_name(): - class TestRegistryClass(ClassRegistryMixin): - pass + result = registry_class.get_registered_object(lookup_name) + assert result is TestClass - class TestClass: - pass + @pytest.mark.sanity + @pytest.mark.parametrize( + "lookup_name", + ["nonexistent", "wrong_name", "DIFFERENT_CASE"], + ) + def test_get_registered_object_invalid(self, valid_instances, lookup_name): + """Test get_registered_object with invalid names.""" + registry_class, _ = valid_instances - with pytest.raises(ValueError) as exc_info: - TestRegistryClass.register_decorator(TestClass, name=123) # type: ignore[arg-type] + @registry_class.register("valid_name") + class TestClass: + pass - assert "must be used as a class decorator" in str(exc_info.value) + result = registry_class.get_registered_object(lookup_name) + assert result is None + @pytest.mark.regression + def test_multiple_registries_isolation(self): + """Test that different registry classes maintain separate registries.""" -@pytest.mark.sanity -def test_register_duplicate_name(): - class TestRegistryClass(ClassRegistryMixin): - pass + class Registry1(RegistryMixin): + pass - @TestRegistryClass.register("test_name") - class TestClass1: - pass + class Registry2(RegistryMixin): + pass - with pytest.raises(ValueError) as exc_info: + @Registry1.register() + class TestClass1: + pass - @TestRegistryClass.register("test_name") + @Registry2.register() class TestClass2: pass - assert "already registered" in str(exc_info.value) - - -@pytest.mark.sanity -def test_registered_classes_empty(): - class TestRegistryClass(ClassRegistryMixin): - pass - - with pytest.raises(ValueError) as exc_info: - TestRegistryClass.registered_classes() - - assert "must be called after registering classes" in str(exc_info.value) - - -@pytest.mark.sanity -def test_registered_classes(): - class TestRegistryClass(ClassRegistryMixin): - pass - - @TestRegistryClass.register() - class TestClass1: - pass - - @TestRegistryClass.register("custom_name") - class TestClass2: - pass + assert Registry1.registry is not None # type: ignore[misc] + assert Registry2.registry is not None # type: ignore[misc] + assert Registry1.registry != Registry2.registry # type: ignore[misc] + assert "TestClass1" in Registry1.registry # type: ignore[misc] + assert "TestClass2" in Registry2.registry # type: ignore[misc] + assert "TestClass1" not in Registry2.registry # type: ignore[misc] + assert "TestClass2" not in Registry1.registry # type: ignore[misc] + + @pytest.mark.smoke + def test_auto_discovery_initialization(self): + """Test initialization of auto-discovery enabled registry.""" + + class TestAutoRegistry(RegistryMixin): + registry_auto_discovery = True + auto_package = "test_package.modules" + + assert TestAutoRegistry.registry is None # type: ignore[misc] + assert TestAutoRegistry.registry_populated is False + assert TestAutoRegistry.auto_package == "test_package.modules" + assert TestAutoRegistry.registry_auto_discovery is True + + @pytest.mark.smoke + def test_auto_discovery_registered_objects(self): + """Test automatic population during registered_objects call.""" + + class TestAutoRegistry(RegistryMixin): + registry_auto_discovery = True + auto_package = "test_package.modules" + + with mock.patch.object( + TestAutoRegistry, "auto_populate_registry" + ) as mock_populate: + TestAutoRegistry.registry = {"class1": "obj1", "class2": "obj2"} # type: ignore[misc] + objects = TestAutoRegistry.registered_objects() + mock_populate.assert_called_once() + assert objects == ("obj1", "obj2") + + @pytest.mark.sanity + def test_register_duplicate_registration(self, valid_instances): + """Test register method with duplicate names.""" + registry_class, _ = valid_instances + + @registry_class.register("duplicate_name") + class TestClass1: + pass - registered = TestRegistryClass.registered_classes() - assert isinstance(registered, tuple) - assert len(registered) == 2 - assert TestClass1 in registered - assert TestClass2 in registered + with pytest.raises(ValueError, match="already registered"): + @registry_class.register("duplicate_name") + class TestClass2: + pass -@pytest.mark.regression -def test_multiple_registries_isolation(): - class Registry1(ClassRegistryMixin): - pass + @pytest.mark.sanity + def test_register_decorator_duplicate_registration(self, valid_instances): + """Test register_decorator with duplicate names.""" + registry_class, _ = valid_instances - class Registry2(ClassRegistryMixin): - pass + class TestClass1: + pass - @Registry1.register() - class TestClass1: - pass + class TestClass2: + pass - @Registry2.register() - class TestClass2: - pass + registry_class.register_decorator(TestClass1, name="duplicate_name") + with pytest.raises(ValueError, match="already registered"): + registry_class.register_decorator(TestClass2, name="duplicate_name") - assert Registry1.registry is not None - assert Registry2.registry is not None - assert Registry1.registry != Registry2.registry - assert "TestClass1" in Registry1.registry - assert "TestClass2" in Registry2.registry - assert "TestClass1" not in Registry2.registry - assert "TestClass2" not in Registry1.registry + @pytest.mark.sanity + def test_register_decorator_invalid_list_element(self, valid_instances): + """Test register_decorator with invalid elements in name list.""" + registry_class, _ = valid_instances + class TestClass: + pass -# ===== Auto-Discovery Tests ===== + with pytest.raises( + ValueError, match="name must be a string or a list of strings" + ): + registry_class.register_decorator(TestClass, name=["valid", 123]) + @pytest.mark.sanity + def test_register_decorator_empty_string_name(self, valid_instances): + """Test register_decorator with empty string name.""" + registry_class, _ = valid_instances -@pytest.mark.smoke -def test_auto_discovery_registry_initialization(): - class TestAutoRegistry(ClassRegistryMixin): - registry_auto_discovery = True - auto_package = "test_package.modules" + class TestClass: + pass - assert TestAutoRegistry.registry is None - assert TestAutoRegistry.registry_populated is False - assert TestAutoRegistry.auto_package == "test_package.modules" - assert TestAutoRegistry.registry_auto_discovery is True + registry_class.register_decorator(TestClass, name="") + assert "" in registry_class.registry + assert registry_class.registry[""] is TestClass + @pytest.mark.sanity + def test_register_decorator_none_in_list(self, valid_instances): + """Test register_decorator with None in name list.""" + registry_class, _ = valid_instances -@pytest.mark.smoke -def test_auto_populate_registry(): - class TestAutoRegistry(ClassRegistryMixin): - registry_auto_discovery = True - auto_package = "test_package.modules" + class TestClass: + pass - with mock.patch.object( - TestAutoRegistry, "auto_import_package_modules" - ) as mock_import: - TestAutoRegistry.auto_populate_registry() - mock_import.assert_called_once() - assert TestAutoRegistry.registry_populated is True + with pytest.raises( + ValueError, match="name must be a string or a list of strings" + ): + registry_class.register_decorator(TestClass, name=["valid", None]) + + @pytest.mark.smoke + def test_is_registered_empty_registry(self, valid_instances): + """Test is_registered with empty registry.""" + registry_class, _ = valid_instances + + result = registry_class.is_registered("any_name") + assert result is False + + @pytest.mark.smoke + def test_get_registered_object_empty_registry(self, valid_instances): + """Test get_registered_object with empty registry.""" + registry_class, _ = valid_instances + + result = registry_class.get_registered_object("any_name") + assert result is None + + @pytest.mark.regression + def test_auto_registry_integration(self): + """Test complete auto-discovery workflow with mocked imports.""" + + class TestAutoRegistry(RegistryMixin): + registry_auto_discovery = True + auto_package = "test_package.modules" + + with ( + mock.patch("pkgutil.walk_packages") as mock_walk, + mock.patch("importlib.import_module") as mock_import, + ): + mock_package = mock.MagicMock() + mock_package.__path__ = ["test_package/modules"] + mock_package.__name__ = "test_package.modules" + + def import_module(name: str): + if name == "test_package.modules": + return mock_package + elif name == "test_package.modules.module1": + module = mock.MagicMock() + module.__name__ = "test_package.modules.module1" + + class Module1Class: + pass + + TestAutoRegistry.register_decorator(Module1Class, "Module1Class") + return module + else: + raise ImportError(f"No module named {name}") + + def walk_packages(package_path, package_name): + if package_name == "test_package.modules.": + return [(None, "test_package.modules.module1", False)] + else: + raise ValueError(f"Unknown package: {package_name}") + + mock_walk.side_effect = walk_packages + mock_import.side_effect = import_module + + objects = TestAutoRegistry.registered_objects() + assert len(objects) == 1 + assert TestAutoRegistry.registry_populated is True + assert TestAutoRegistry.registry is not None # type: ignore[misc] + assert "Module1Class" in TestAutoRegistry.registry # type: ignore[misc] + + @pytest.mark.smoke + def test_register_preserves_class_metadata(self): + """Test that registered classes retain docs, types, and methods.""" + + class TestRegistry(RegistryMixin): + pass - # Second call should not trigger another import since already populated - TestAutoRegistry.auto_populate_registry() - mock_import.assert_called_once() - - -@pytest.mark.sanity -def test_auto_populate_registry_disabled(): - class TestDisabledAutoRegistry(ClassRegistryMixin): - # registry_auto_discovery is False by default - auto_package = "test_package.modules" - - with pytest.raises(ValueError) as exc_info: - TestDisabledAutoRegistry.auto_populate_registry() - - assert "registry_auto_discovery is set to False" in str(exc_info.value) - - -@pytest.mark.sanity -def test_auto_registered_classes(): - class TestAutoRegistry(ClassRegistryMixin): - registry_auto_discovery = True - auto_package = "test_package.modules" - - with mock.patch.object(TestAutoRegistry, "auto_populate_registry") as mock_populate: - # Mock the registry content - TestAutoRegistry.registry = {"Class1": "class1", "Class2": "class2"} # type: ignore[dict-item] - classes = TestAutoRegistry.registered_classes() - mock_populate.assert_called_once() - assert classes == ("class1", "class2") - - -@pytest.mark.regression -def test_auto_registry_integration(): - class TestAutoRegistry(ClassRegistryMixin): - registry_auto_discovery = True - auto_package = "test_package.modules" - - with ( - mock.patch("pkgutil.walk_packages") as mock_walk, - mock.patch("importlib.import_module") as mock_import, - ): - # Create a mock package with the necessary attributes - mock_package = mock.MagicMock() - mock_package.__path__ = ["test_package/modules"] - mock_package.__name__ = "test_package.modules" - - def import_module(name: str): - if name == "test_package.modules": - return mock_package - elif name == "test_package.modules.module1": - module = mock.MagicMock() - module.__name__ = "test_package.modules.module1" - - class Module1Class: - pass - - TestAutoRegistry.register_decorator(Module1Class, "Module1Class") - return module - else: - raise ImportError(f"No module named {name}") - - def walk_packages(package_path, package_name): - if package_name == "test_package.modules.": - return [(None, "test_package.modules.module1", False)] - else: - raise ValueError(f"Unknown package: {package_name}") - - mock_walk.side_effect = walk_packages - mock_import.side_effect = import_module - - classes = TestAutoRegistry.registered_classes() - assert len(classes) == 1 - assert TestAutoRegistry.registry_populated is True - assert TestAutoRegistry.registry is not None - assert "Module1Class" in TestAutoRegistry.registry - - -@pytest.mark.regression -def test_auto_registry_with_multiple_packages(): - class TestMultiPackageRegistry(ClassRegistryMixin): - registry_auto_discovery = True - auto_package = ("package1", "package2") - - with mock.patch.object( - TestMultiPackageRegistry, "auto_import_package_modules" - ) as mock_import: - # Mock the registry to avoid ValueError when getting registered_classes - TestMultiPackageRegistry.registry = {} - TestMultiPackageRegistry.registered_classes() - mock_import.assert_called_once() - assert TestMultiPackageRegistry.registry_populated is True - - -@pytest.mark.regression -def test_auto_registry_no_package(): - class TestNoPackageRegistry(ClassRegistryMixin): - registry_auto_discovery = True - # No auto_package defined - - with mock.patch.object( - TestNoPackageRegistry, - "auto_import_package_modules", - side_effect=ValueError("auto_package must be set"), - ) as mock_import: - with pytest.raises(ValueError) as exc_info: - TestNoPackageRegistry.auto_populate_registry() - - mock_import.assert_called_once() - assert "auto_package must be set" in str(exc_info.value) + @TestRegistry.register("documented_class") + class DocumentedClass: + """This is a documented class with methods and type hints.""" + + def __init__(self, value: int) -> None: + """Initialize with a value. + + :param value: An integer value + """ + self.value = value + + def get_value(self) -> int: + """Get the stored value. + + :return: The stored integer value + """ + return self.value + + def set_value(self, new_value: int) -> None: + """Set a new value. + + :param new_value: The new integer value to set + """ + self.value = new_value + + @classmethod + def from_string(cls, value_str: str) -> DocumentedClass: + """Create instance from string. + + :param value_str: String representation of value + :return: New DocumentedClass instance + """ + return cls(int(value_str)) + + @staticmethod + def validate_value(value: int) -> bool: + """Validate that a value is positive. + + :param value: Value to validate + :return: True if positive, False otherwise + """ + return value > 0 + + # Check that the class was registered + assert TestRegistry.is_registered("documented_class") + registered_class = TestRegistry.get_registered_object("documented_class") + assert registered_class is DocumentedClass + + # Check that the class retains its documentation + assert registered_class.__doc__ is not None + assert "documented class with methods" in registered_class.__doc__ + assert registered_class.__init__.__doc__ is not None + assert "Initialize with a value" in registered_class.__init__.__doc__ + assert registered_class.get_value.__doc__ is not None + assert "Get the stored value" in registered_class.get_value.__doc__ + assert registered_class.set_value.__doc__ is not None + assert "Set a new value" in registered_class.set_value.__doc__ + assert registered_class.from_string.__doc__ is not None + assert "Create instance from string" in registered_class.from_string.__doc__ + assert registered_class.validate_value.__doc__ is not None + assert ( + "Validate that a value is positive" + in registered_class.validate_value.__doc__ + ) + + # Check that methods are callable and work correctly + instance = registered_class(42) + assert instance.get_value() == 42 + instance.set_value(100) + assert instance.get_value() == 100 + instance2 = registered_class.from_string("123") + assert instance2.get_value() == 123 + assert registered_class.validate_value(10) is True + assert registered_class.validate_value(-5) is False + + # Check that type annotations are preserved (if accessible) + if hasattr(inspect, "get_annotations"): + # Python 3.10+ + try: + annotations = inspect.get_annotations(registered_class.__init__) + assert "value" in annotations + assert annotations["value"] is int + return_ann = annotations.get("return") + assert return_ann is None or return_ann is type(None) + except (AttributeError, NameError): + # Fallback for older Python or missing annotations + pass + + # Check that the class name is preserved + assert registered_class.__name__ == "DocumentedClass" + assert registered_class.__qualname__.endswith("DocumentedClass") From 9676c610f21cffc6acb92c488c47e23b004303c0 Mon Sep 17 00:00:00 2001 From: Mark Kurtz Date: Thu, 4 Sep 2025 15:47:25 -0700 Subject: [PATCH 3/8] expand pydantic utils to auto reload parent modules when new registration happens Signed-off-by: Mark Kurtz --- src/speculators/utils/pydantic_utils.py | 82 ++++++++++++++++++++++++- 1 file changed, 80 insertions(+), 2 deletions(-) diff --git a/src/speculators/utils/pydantic_utils.py b/src/speculators/utils/pydantic_utils.py index 6d074796..abe9ee6a 100644 --- a/src/speculators/utils/pydantic_utils.py +++ b/src/speculators/utils/pydantic_utils.py @@ -11,7 +11,7 @@ from __future__ import annotations from abc import ABC, abstractmethod -from typing import Any, ClassVar, Generic, TypeVar +from typing import Any, ClassVar, Generic, TypeVar, get_args, get_origin from pydantic import BaseModel, GetCoreSchemaHandler from pydantic_core import CoreSchema, core_schema @@ -41,15 +41,93 @@ class ReloadableBaseModel(BaseModel): """ @classmethod - def reload_schema(cls) -> None: + def reload_schema(cls, parents: bool = True) -> None: """ Reload the class schema with updated registry information. Forces a complete rebuild of the Pydantic model schema to incorporate any changes made to associated registries or validation rules. + + :param parents: Whether to also rebuild schemas for any pydantic parent + types that reference this model. """ cls.model_rebuild(force=True) + if parents: + cls.reload_parent_schemas() + + @classmethod + def reload_parent_schemas(cls): + """ + Recursively reload schemas for all parent Pydantic models. + + Traverses the inheritance hierarchy to find all parent classes that + are Pydantic models and triggers schema rebuilding on each to ensure + that any changes in child models are reflected in parent schemas. + """ + potential_parents: set[type[BaseModel]] = {BaseModel} + stack: list[type[BaseModel]] = [BaseModel] + + while stack: + current = stack.pop() + for subclass in current.__subclasses__(): + if ( + issubclass(subclass, BaseModel) + and subclass is not cls + and subclass not in potential_parents + ): + potential_parents.add(subclass) + stack.append(subclass) + + for check in cls.__mro__: + if isinstance(check, type) and issubclass(check, BaseModel): + cls._reload_schemas_depending_on(check, potential_parents) + + @classmethod + def _reload_schemas_depending_on(cls, target: type[BaseModel], types: set[type]): + changed = True + while changed: + changed = False + for candidate in types: + if ( + isinstance(candidate, type) + and issubclass(candidate, BaseModel) + and any( + cls._uses_type(target, field_info.annotation) + for field_info in candidate.model_fields.values() + if field_info.annotation is not None + ) + ): + try: + before = candidate.model_json_schema() + except Exception: # noqa: BLE001 + before = None + candidate.model_rebuild(force=True) + if before is not None: + after = candidate.model_json_schema() + changed |= before != after + + @classmethod + def _uses_type(cls, target: type, candidate: type) -> bool: + if target is candidate: + return True + + origin = get_origin(candidate) + + if origin is None: + return isinstance(candidate, type) and issubclass(candidate, target) + + if isinstance(origin, type) and ( + target is origin or issubclass(origin, target) + ): + return True + + for arg in get_args(candidate) or []: + if isinstance(arg, type) and cls._uses_type(target, arg): + return True + + return False + class PydanticClassRegistryMixin( ReloadableBaseModel, RegistryMixin[type[BaseModelT]], ABC, Generic[BaseModelT] From c67c12f0ac3d52df0a56fb1558de56dfe5245ffc Mon Sep 17 00:00:00 2001 From: Mark Kurtz Date: Mon, 8 Sep 2025 17:21:34 -0600 Subject: [PATCH 4/8] Rebase on latest main and fixes for unit and integration tests Signed-off-by: Mark Kurtz --- src/speculators/convert/eagle/__init__.py | 5 +++-- src/speculators/convert/eagle/eagle3_converter.py | 2 ++ src/speculators/convert/eagle/eagle_converter.py | 2 ++ src/speculators/utils/auto_importer.py | 10 ++++------ src/speculators/utils/registry.py | 4 ++-- tests/unit/convert/test_eagle_utils.py | 2 +- 6 files changed, 14 insertions(+), 11 deletions(-) diff --git a/src/speculators/convert/eagle/__init__.py b/src/speculators/convert/eagle/__init__.py index 591498a7..9d007b3c 100644 --- a/src/speculators/convert/eagle/__init__.py +++ b/src/speculators/convert/eagle/__init__.py @@ -1,7 +1,8 @@ """ -Eagle-1 and Eagle-3 checkpoint conversion utilities. +EAGLE v1, EAGLE v2, EAGLE v3, and HASS checkpoint conversion utilities. """ +from speculators.convert.eagle.eagle3_converter import Eagle3Converter from speculators.convert.eagle.eagle_converter import EagleConverter -__all__ = ["EagleConverter"] +__all__ = ["Eagle3Converter", "EagleConverter"] diff --git a/src/speculators/convert/eagle/eagle3_converter.py b/src/speculators/convert/eagle/eagle3_converter.py index 95bd2c6b..7de975d6 100644 --- a/src/speculators/convert/eagle/eagle3_converter.py +++ b/src/speculators/convert/eagle/eagle3_converter.py @@ -18,6 +18,8 @@ from speculators.models.eagle3 import Eagle3Speculator, Eagle3SpeculatorConfig from speculators.proposals.greedy import GreedyTokenProposalConfig +__all__ = ["Eagle3Converter"] + class Eagle3Converter: """ diff --git a/src/speculators/convert/eagle/eagle_converter.py b/src/speculators/convert/eagle/eagle_converter.py index 3d2e4a19..8bf3d377 100644 --- a/src/speculators/convert/eagle/eagle_converter.py +++ b/src/speculators/convert/eagle/eagle_converter.py @@ -19,6 +19,8 @@ from speculators.models.eagle import EagleSpeculator, EagleSpeculatorConfig from speculators.proposals.greedy import GreedyTokenProposalConfig +__all__ = ["EagleConverter"] + class EagleConverter: """ diff --git a/src/speculators/utils/auto_importer.py b/src/speculators/utils/auto_importer.py index 88e11c85..254c2bd6 100644 --- a/src/speculators/utils/auto_importer.py +++ b/src/speculators/utils/auto_importer.py @@ -8,12 +8,10 @@ discovered when placed in the correct package structure. """ -from __future__ import annotations - import importlib import pkgutil import sys -from typing import ClassVar +from typing import ClassVar, Union __all__ = ["AutoImporterMixin"] @@ -41,9 +39,9 @@ class MyRegistry(AutoImporterMixin): :cvar auto_imported_modules: List tracking which modules have been imported """ - auto_package: ClassVar[str | tuple[str, ...] | None] = None - auto_ignore_modules: ClassVar[tuple[str, ...] | None] = None - auto_imported_modules: ClassVar[list[str] | None] = None + auto_package: ClassVar[Union[str, tuple[str, ...], None]] = None + auto_ignore_modules: ClassVar[Union[tuple[str, ...], None]] = None + auto_imported_modules: ClassVar[Union[list[str], None]] = None @classmethod def auto_import_package_modules(cls) -> None: diff --git a/src/speculators/utils/registry.py b/src/speculators/utils/registry.py index 34b74549..05448591 100644 --- a/src/speculators/utils/registry.py +++ b/src/speculators/utils/registry.py @@ -10,7 +10,7 @@ from __future__ import annotations -from typing import Callable, ClassVar, Generic, TypeVar, cast +from typing import Callable, ClassVar, Generic, TypeVar, Union, cast from speculators.utils.auto_importer import AutoImporterMixin @@ -62,7 +62,7 @@ class TokenProposal(RegistryMixin): :cvar registry_populated: Track whether auto-discovery has completed """ - registry: ClassVar[dict[str, RegistryObjT] | None] = None # type: ignore[misc] + registry: ClassVar[Union[dict[str, RegistryObjT], None]] = None # type: ignore[misc] # noqa: UP007 registry_auto_discovery: ClassVar[bool] = False registry_populated: ClassVar[bool] = False diff --git a/tests/unit/convert/test_eagle_utils.py b/tests/unit/convert/test_eagle_utils.py index facdcdbd..40532962 100644 --- a/tests/unit/convert/test_eagle_utils.py +++ b/tests/unit/convert/test_eagle_utils.py @@ -93,7 +93,7 @@ def test_download_with_cache_dir(self, mock_download, tmp_path): ensure_checkpoint_is_local("test-model/checkpoint", cache_dir=cache_dir) mock_download.assert_called_once_with( - model_id="test-model/checkpoint", cache_dir=cache_dir + model_id="test-model/checkpoint", cache_dir=str(cache_dir) ) From 2c6428ecfd75f49a51edfc9c196bdde34dfcd1f6 Mon Sep 17 00:00:00 2001 From: Mark Kurtz Date: Mon, 8 Sep 2025 17:44:28 -0600 Subject: [PATCH 5/8] fixes for latest python versions Signed-off-by: Mark Kurtz --- tests/unit/utils/test_auto_importer.py | 3 --- tests/unit/utils/test_registry.py | 2 +- 2 files changed, 1 insertion(+), 4 deletions(-) diff --git a/tests/unit/utils/test_auto_importer.py b/tests/unit/utils/test_auto_importer.py index ee3d68f9..dcae83c3 100644 --- a/tests/unit/utils/test_auto_importer.py +++ b/tests/unit/utils/test_auto_importer.py @@ -215,9 +215,6 @@ class TestClass(AutoImporterMixin): # Verify assert TestClass.auto_imported_modules == ["test.package.existing"] - mock_import.assert_called_once_with("test.package") - with pytest.raises(AssertionError): - mock_import.assert_any_call("test.package.existing") @pytest.mark.sanity @mock.patch("importlib.import_module") diff --git a/tests/unit/utils/test_registry.py b/tests/unit/utils/test_registry.py index bf78fcf6..cd2101f4 100644 --- a/tests/unit/utils/test_registry.py +++ b/tests/unit/utils/test_registry.py @@ -573,7 +573,7 @@ def validate_value(value: int) -> bool: try: annotations = inspect.get_annotations(registered_class.__init__) assert "value" in annotations - assert annotations["value"] is int + assert annotations["value"] in (int, "int") return_ann = annotations.get("return") assert return_ann is None or return_ann is type(None) except (AttributeError, NameError): From 52599991a7c8cbf7b443246d2d4307385234b406 Mon Sep 17 00:00:00 2001 From: Mark Kurtz Date: Mon, 8 Sep 2025 17:54:58 -0600 Subject: [PATCH 6/8] Fixes for newer python version tests Signed-off-by: Mark Kurtz --- tests/unit/utils/test_registry.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/tests/unit/utils/test_registry.py b/tests/unit/utils/test_registry.py index cd2101f4..393ef4ba 100644 --- a/tests/unit/utils/test_registry.py +++ b/tests/unit/utils/test_registry.py @@ -575,7 +575,11 @@ def validate_value(value: int) -> bool: assert "value" in annotations assert annotations["value"] in (int, "int") return_ann = annotations.get("return") - assert return_ann is None or return_ann is type(None) + assert ( + return_ann == "None" + or return_ann is None + or return_ann is type(None) + ) except (AttributeError, NameError): # Fallback for older Python or missing annotations pass From f8b908204683c8b41df3807e7b4d9d4fb43afd87 Mon Sep 17 00:00:00 2001 From: Mark Kurtz Date: Mon, 15 Sep 2025 12:30:31 -0400 Subject: [PATCH 7/8] Improvements to pydantic ReloadableBaseModel and PydanticClassRegistryMixin from review Signed-off-by: Mark Kurtz --- src/speculators/config.py | 14 +- src/speculators/utils/pydantic_utils.py | 224 +++++++++++++++--------- tests/unit/utils/test_pydantic_utils.py | 73 +++----- 3 files changed, 166 insertions(+), 145 deletions(-) diff --git a/src/speculators/config.py b/src/speculators/config.py index 36acea4c..c4a55f46 100644 --- a/src/speculators/config.py +++ b/src/speculators/config.py @@ -50,11 +50,8 @@ class TokenProposalConfig(PydanticClassRegistryMixin): """ @classmethod - def __pydantic_schema_base_type__(cls) -> type["TokenProposalConfig"]: - if cls.__name__ == "TokenProposalConfig": - return cls - - return TokenProposalConfig + def __pydantic_schema_base_name__(cls) -> str: + return "TokenProposalConfig" auto_package: ClassVar[str] = "speculators.proposals" registry_auto_discovery: ClassVar[bool] = True @@ -238,11 +235,8 @@ def from_dict( return cls.model_validate(dict_obj) @classmethod - def __pydantic_schema_base_type__(cls) -> type["SpeculatorModelConfig"]: - if cls.__name__ == "SpeculatorModelConfig": - return cls - - return SpeculatorModelConfig + def __pydantic_schema_base_name__(cls) -> str: + return "SpeculatorModelConfig" # Pydantic configuration model_config = ConfigDict(arbitrary_types_allowed=True, extra="allow") diff --git a/src/speculators/utils/pydantic_utils.py b/src/speculators/utils/pydantic_utils.py index abe9ee6a..74aed552 100644 --- a/src/speculators/utils/pydantic_utils.py +++ b/src/speculators/utils/pydantic_utils.py @@ -41,104 +41,152 @@ class ReloadableBaseModel(BaseModel): """ @classmethod - def reload_schema(cls, parents: bool = True) -> None: + def reload_schema(cls, dependencies: bool = True): """ - Reload the class schema with updated registry information. + Reload and rebuild the Pydantic model validation schema. - Forces a complete rebuild of the Pydantic model schema to incorporate - any changes made to associated registries or validation rules. + Forces reconstruction of the model schema and optionally rebuilds + schemas for all dependent models in the reloadable dependency chains. + Essential when new types are registered that affect polymorphic validation. - :param parents: Whether to also rebuild schemas for any pydantic parent - types that reference this model. + :param dependencies: Whether to reload dependent model schemas as well """ cls.model_rebuild(force=True) - if parents: - cls.reload_parent_schemas() + if dependencies: + for chain in cls.reloadable_dependency_chains(): + for clazz in chain: + clazz.model_rebuild(force=True) @classmethod - def reload_parent_schemas(cls): + def reloadable_dependency_chains( + cls, target: type[ReloadableBaseModel] | None = None + ) -> list[list[type[ReloadableBaseModel]]]: """ - Recursively reload schemas for all parent Pydantic models. + Find all dependency chains leading to the target model class. - Traverses the inheritance hierarchy to find all parent classes that - are Pydantic models and triggers schema rebuilding on each to ensure - that any changes in child models are reflected in parent schemas. + Uses depth-first search to identify dependency paths between reloadable + models, ensuring proper schema reload ordering to maintain validation + consistency across the polymorphic model hierarchy. + + :param target: Target model class to find chains for. Uses cls if None + :return: List of dependency chains ending at the target class """ - potential_parents: set[type[BaseModel]] = {BaseModel} - stack: list[type[BaseModel]] = [BaseModel] + if target is None: + target = cls - while stack: - current = stack.pop() - for subclass in current.__subclasses__(): - if ( - issubclass(subclass, BaseModel) - and subclass is not cls - and subclass not in potential_parents - ): - potential_parents.add(subclass) - stack.append(subclass) + # Build a map of all reloadable classes to their dependencies + dependencies: dict[ + type[ReloadableBaseModel], list[type[ReloadableBaseModel]] + ] = {} + + for reloadable in cls.reloadable_descendants(ReloadableBaseModel): + deps = [] + for field_deps in reloadable.reloadable_fields().values(): + deps.extend(field_deps) + dependencies[reloadable] = deps + + # Find all dependency chains ending at target using DFS + chains = [] - for check in cls.__mro__: - if isinstance(check, type) and issubclass(check, BaseModel): - cls._reload_schemas_depending_on(check, potential_parents) + def _find_chains( + current: type[ReloadableBaseModel], path: list[type[ReloadableBaseModel]] + ): + if current == target: + chains.append(path) + return + + for dependent in dependencies.get(current, []): + if dependent not in path: # Avoid cycles + _find_chains(dependent, [current] + path) + + for cls_type, deps in dependencies.items(): + if deps and cls_type != target: + _find_chains(cls_type, []) + + return chains @classmethod - def _reload_schemas_depending_on(cls, target: type[BaseModel], types: set[type]): - changed = True - while changed: - changed = False - for candidate in types: - if ( - isinstance(candidate, type) - and issubclass(candidate, BaseModel) - and any( - cls._uses_type(target, field_info.annotation) - for field_info in candidate.model_fields.values() - if field_info.annotation is not None - ) - ): - try: - before = candidate.model_json_schema() - except Exception: # noqa: BLE001 - before = None - candidate.model_rebuild(force=True) - if before is not None: - after = candidate.model_json_schema() - changed |= before != after + def reloadable_fields( + cls, + ) -> dict[str, list[type[ReloadableBaseModel]]]: + """ + Identify model fields containing reloadable model types. + + Recursively analyzes field type annotations to find all ReloadableBaseModel + subclasses used within the model schema, enabling dependency tracking for + proper schema reload ordering. + + :return: Mapping of field names to lists of reloadable model types + """ + + def _recursive_resolve_reloadable_types(type_: type | None) -> list[type]: + if type_ is None: + return [] + + if (origin := get_origin(type_)) is None: + return [type_] if issubclass(type_, ReloadableBaseModel) else [] + + resolved = [] + if issubclass(origin, ReloadableBaseModel): + resolved.append(origin) + + for arg in get_args(type_): + resolved.extend(_recursive_resolve_reloadable_types(arg)) + + return resolved + + fields = {} + + for name, info in cls.model_fields.items(): + if reloadable_types := _recursive_resolve_reloadable_types(info.annotation): + fields[name] = reloadable_types + + return fields @classmethod - def _uses_type(cls, target: type, candidate: type) -> bool: - if target is candidate: - return True + def reloadable_descendants( + cls, target: type[ReloadableBaseModel] | None = None + ) -> set[type[ReloadableBaseModel]]: + """ + Find all ReloadableBaseModel descendants of the target class. - origin = get_origin(candidate) + Traverses the inheritance hierarchy to collect all subclasses that inherit + from ReloadableBaseModel, enabling comprehensive dependency analysis for + schema reloading operations. - if origin is None: - return isinstance(candidate, type) and issubclass(candidate, target) + :param target: Base class to find descendants for. Uses cls if None + :return: Set of all descendant ReloadableBaseModel classes + """ + if target is None: + target = cls - if isinstance(origin, type) and ( - target is origin or issubclass(origin, target) - ): - return True + descendants: set[type[ReloadableBaseModel]] = set() + stack: list[type[ReloadableBaseModel]] = [target] - for arg in get_args(candidate) or []: - if isinstance(arg, type) and cls._uses_type(target, arg): - return True + while stack: + current = stack.pop() + for subclass in current.__subclasses__(): + if ( + issubclass(subclass, ReloadableBaseModel) + and subclass is not cls + and subclass not in descendants + ): + descendants.add(subclass) + stack.append(subclass) - return False + return descendants class PydanticClassRegistryMixin( ReloadableBaseModel, RegistryMixin[type[BaseModelT]], ABC, Generic[BaseModelT] ): """ - Polymorphic Pydantic model mixin enabling registry-based dynamic instantiation. + Polymorphic Pydantic model enabling registry-based dynamic type instantiation. - Integrates Pydantic validation with the registry system to enable polymorphic - serialization and deserialization based on a discriminator field. Automatically - instantiates the correct subclass during validation based on registry mappings, - providing a foundation for extensible plugin-style architectures. + Integrates Pydantic validation with the registry system for polymorphic + serialization and deserialization using a discriminator field. Automatically + instantiates the correct subclass during validation based on registry mappings. Example: :: @@ -149,8 +197,8 @@ class BaseConfig(PydanticClassRegistryMixin["BaseConfig"]): config_type: str = Field(description="Configuration type identifier") @classmethod - def __pydantic_schema_base_type__(cls) -> type["BaseConfig"]: - return BaseConfig + def __pydantic_schema_base_name__(cls) -> str: + return "BaseConfig" @BaseConfig.register("database") class DatabaseConfig(BaseConfig): @@ -163,7 +211,7 @@ class DatabaseConfig(BaseConfig): "connection_string": "postgresql://localhost:5432/db" }) - :cvar schema_discriminator: Field name used for polymorphic type discrimination + :cvar schema_discriminator: Field name for polymorphic type discrimination """ schema_discriminator: ClassVar[str] = "model_type" @@ -210,31 +258,35 @@ def __get_pydantic_core_schema__( :param handler: Pydantic core schema generation handler :return: Tagged union schema for polymorphic validation or base schema """ - if source_type == cls.__pydantic_schema_base_type__(): - if not cls.registry: - return cls.__pydantic_generate_base_schema__(handler) + if ( + source_type is None + or not isinstance(source_type, type) + or source_type.__name__ != cls.__pydantic_schema_base_name__() + ): + return handler(cls) - choices = { - name: handler(model_class) for name, model_class in cls.registry.items() - } + if not cls.registry: + return cls.__pydantic_generate_base_schema__(handler) - return core_schema.tagged_union_schema( - choices=choices, - discriminator=cls.schema_discriminator, - ) + choices = { + name: handler(model_class) for name, model_class in cls.registry.items() + } - return handler(cls) + return core_schema.tagged_union_schema( + choices=choices, + discriminator=cls.schema_discriminator, + ) @classmethod @abstractmethod - def __pydantic_schema_base_type__(cls) -> type[BaseModelT]: + def __pydantic_schema_base_name__(cls) -> str: """ - Define the base type for polymorphic validation hierarchy. + Define the name of the base type for polymorphic validation hierarchy. Must be implemented by subclasses to specify which type serves as the root of the polymorphic hierarchy for schema generation and validation. - :return: Base class type for the polymorphic model hierarchy + :return: Base class name for the polymorphic model hierarchy """ ... diff --git a/tests/unit/utils/test_pydantic_utils.py b/tests/unit/utils/test_pydantic_utils.py index 7d22f352..2c8a6e57 100644 --- a/tests/unit/utils/test_pydantic_utils.py +++ b/tests/unit/utils/test_pydantic_utils.py @@ -141,10 +141,8 @@ class TestBaseModel(PydanticClassRegistryMixin): test_type: str @classmethod - def __pydantic_schema_base_type__(cls) -> type[TestBaseModel]: - if cls.__name__ == "TestBaseModel": - return cls - return TestBaseModel + def __pydantic_schema_base_name__(cls) -> str: + return "TestBaseModel" @TestBaseModel.register("test_sub") class TestSubModel(TestBaseModel): @@ -195,10 +193,8 @@ class TestBaseModel(PydanticClassRegistryMixin): test_type: str @classmethod - def __pydantic_schema_base_type__(cls) -> type[TestBaseModel]: - if cls.__name__ == "TestBaseModel": - return cls - return TestBaseModel + def __pydantic_schema_base_name__(cls) -> str: + return "TestBaseModel" @TestBaseModel.register("test_sub") class TestSubModel(TestBaseModel): @@ -223,10 +219,8 @@ class TestBaseModel(PydanticClassRegistryMixin): test_type: str @classmethod - def __pydantic_schema_base_type__(cls) -> type[TestBaseModel]: - if cls.__name__ == "TestBaseModel": - return cls - return TestBaseModel + def __pydantic_schema_base_name__(cls) -> str: + return "TestBaseModel" @TestBaseModel.register("test_sub") class TestSubModel(TestBaseModel): @@ -245,10 +239,8 @@ class TestBaseModel(PydanticClassRegistryMixin): test_type: str @classmethod - def __pydantic_schema_base_type__(cls) -> type[TestBaseModel]: - if cls.__name__ == "TestBaseModel": - return cls - return TestBaseModel + def __pydantic_schema_base_name__(cls) -> str: + return "TestBaseModel" @TestBaseModel.register() class TestSubModel(TestBaseModel): @@ -268,10 +260,8 @@ class TestBaseModel(PydanticClassRegistryMixin): test_type: str @classmethod - def __pydantic_schema_base_type__(cls) -> type[TestBaseModel]: - if cls.__name__ == "TestBaseModel": - return cls - return TestBaseModel + def __pydantic_schema_base_name__(cls) -> str: + return "TestBaseModel" @TestBaseModel.register("custom_name") class TestSubModel(TestBaseModel): @@ -291,10 +281,8 @@ class TestBaseModel(PydanticClassRegistryMixin): test_type: str @classmethod - def __pydantic_schema_base_type__(cls) -> type[TestBaseModel]: - if cls.__name__ == "TestBaseModel": - return cls - return TestBaseModel + def __pydantic_schema_base_name__(cls) -> str: + return "TestBaseModel" class RegularClass: pass @@ -314,10 +302,8 @@ class TestBaseModel(PydanticClassRegistryMixin): registry_auto_discovery: ClassVar[bool] = True @classmethod - def __pydantic_schema_base_type__(cls) -> type[TestBaseModel]: - if cls.__name__ == "TestBaseModel": - return cls - return TestBaseModel + def __pydantic_schema_base_name__(cls) -> str: + return "TestBaseModel" with ( mock.patch.object(TestBaseModel, "reload_schema") as mock_reload, @@ -340,10 +326,8 @@ class TestBaseModel(PydanticClassRegistryMixin): registry_auto_discovery: ClassVar[bool] = False @classmethod - def __pydantic_schema_base_type__(cls) -> type[TestBaseModel]: - if cls.__name__ == "TestBaseModel": - return cls - return TestBaseModel + def __pydantic_schema_base_name__(cls) -> str: + return "TestBaseModel" @TestBaseModel.register("test_sub_a") class TestSubModelA(TestBaseModel): @@ -372,10 +356,8 @@ class TestBaseModel(PydanticClassRegistryMixin): registry_auto_discovery: ClassVar[bool] = True @classmethod - def __pydantic_schema_base_type__(cls) -> type[TestBaseModel]: - if cls.__name__ == "TestBaseModel": - return cls - return TestBaseModel + def __pydantic_schema_base_name__(cls) -> str: + return "TestBaseModel" with mock.patch.object( TestBaseModel, "auto_populate_registry" @@ -398,10 +380,8 @@ class TestBaseModel(PydanticClassRegistryMixin): test_type: str @classmethod - def __pydantic_schema_base_type__(cls) -> type[TestBaseModel]: - if cls.__name__ == "TestBaseModel": - return cls - return TestBaseModel + def __pydantic_schema_base_name__(cls) -> str: + return "TestBaseModel" # Ensure registry is None TestBaseModel.registry = None # type: ignore[misc] @@ -443,10 +423,8 @@ class TestBaseModel(PydanticClassRegistryMixin): test_type: str @classmethod - def __pydantic_schema_base_type__(cls) -> type[TestBaseModel]: - if cls.__name__ == "TestBaseModel": - return cls - return TestBaseModel + def __pydantic_schema_base_name__(cls) -> str: + return "TestBaseModel" @classmethod def __pydantic_generate_base_schema__(cls, handler): @@ -508,11 +486,8 @@ class TestBaseModel(PydanticClassRegistryMixin): model_type: str @classmethod - def __pydantic_schema_base_type__(cls) -> type[TestBaseModel]: - if cls.__name__ == "TestBaseModel": - return cls - - return TestBaseModel + def __pydantic_schema_base_name__(cls) -> str: + return "TestBaseModel" @TestBaseModel.register("documented_model") class DocumentedModel(TestBaseModel): From 7ca9fd42d74c398d9ae70d84d56bc0a510b4b69c Mon Sep 17 00:00:00 2001 From: Mark Kurtz Date: Mon, 15 Sep 2025 13:08:51 -0400 Subject: [PATCH 8/8] Remove unused, older reload_* pathways Signed-off-by: Mark Kurtz --- src/speculators/__init__.py | 9 +-------- src/speculators/config.py | 12 ------------ src/speculators/model.py | 15 --------------- tests/unit/test_config.py | 9 --------- tests/unit/test_model.py | 7 ------- 5 files changed, 1 insertion(+), 51 deletions(-) diff --git a/src/speculators/__init__.py b/src/speculators/__init__.py index a5f8bf28..529db76b 100644 --- a/src/speculators/__init__.py +++ b/src/speculators/__init__.py @@ -25,9 +25,8 @@ SpeculatorsConfig, TokenProposalConfig, VerifierConfig, - reload_and_populate_configs, ) -from .model import SpeculatorModel, reload_and_populate_models +from .model import SpeculatorModel __all__ = [ "SpeculatorModel", @@ -35,10 +34,4 @@ "SpeculatorsConfig", "TokenProposalConfig", "VerifierConfig", - "reload_and_populate_configs", - "reload_and_populate_models", ] - -# base imports complete, run auto loading for base classes -reload_and_populate_configs() -reload_and_populate_models() diff --git a/src/speculators/config.py b/src/speculators/config.py index c4a55f46..adf6fc55 100644 --- a/src/speculators/config.py +++ b/src/speculators/config.py @@ -32,7 +32,6 @@ "SpeculatorsConfig", "TokenProposalConfig", "VerifierConfig", - "reload_and_populate_configs", ] @@ -322,14 +321,3 @@ def to_diff_dict(self) -> dict[str, Any]: or set, along with all Pydantic fields. """ return super().to_diff_dict() - - -def reload_and_populate_configs(): - """ - Automatically populates the registry for all PydanticClassRegistryMixin subclasses - and reloads schemas for all Config classes to ensure their schemas are up-to-date - with the current registry state. - """ - TokenProposalConfig.auto_populate_registry() - SpeculatorsConfig.reload_schema() - SpeculatorModelConfig.auto_populate_registry() diff --git a/src/speculators/model.py b/src/speculators/model.py index ebe79554..3450c473 100644 --- a/src/speculators/model.py +++ b/src/speculators/model.py @@ -14,10 +14,6 @@ Classes: SpeculatorModel: Abstract base class for all speculator models with transformers compatibility, automatic registry support, and speculative generation methods - -Functions: - reload_and_populate_models: Automatically populates the model registry for - discovery and instantiation of registered speculator models """ import os @@ -561,14 +557,3 @@ def generate( raise NotImplementedError( "The generate method for speculator models is not implemented yet." ) - - -def reload_and_populate_models(): - """ - Triggers the automatic discovery and registration of all - SpeculatorModel subclasses found in the speculators.models package - that have been registered with `SpeculatorModel.register(NAME)`. This - enables dynamic model loading and instantiation based on configuration - types without requiring explicit imports. - """ - SpeculatorModel.auto_populate_registry() diff --git a/tests/unit/test_config.py b/tests/unit/test_config.py index dbe026e9..fd95da39 100644 --- a/tests/unit/test_config.py +++ b/tests/unit/test_config.py @@ -17,7 +17,6 @@ SpeculatorsConfig, TokenProposalConfig, VerifierConfig, - reload_and_populate_configs, ) # ===== TokenProposalConfig Tests ===== @@ -29,10 +28,6 @@ class TokenProposalConfigTest(TokenProposalConfig): test_field: int = 123 -# Ensure the schemas are reloaded to include the test proposal type -reload_and_populate_configs() - - @pytest.mark.smoke def test_token_proposal_config_initialization(): config: TokenProposalConfigTest = TokenProposalConfig( # type: ignore[assignment] @@ -239,10 +234,6 @@ class SpeculatorModelConfigTest(SpeculatorModelConfig): test_field: int = 456 -# Ensure the schemas are reloaded to include the test proposal type -reload_and_populate_configs() - - @pytest.fixture def sample_speculators_config(sample_token_proposal_config, sample_verifier_config): return SpeculatorsConfig( diff --git a/tests/unit/test_model.py b/tests/unit/test_model.py index acca4bca..a222cefc 100644 --- a/tests/unit/test_model.py +++ b/tests/unit/test_model.py @@ -17,8 +17,6 @@ SpeculatorModelConfig, SpeculatorsConfig, VerifierConfig, - reload_and_populate_configs, - reload_and_populate_models, ) from speculators.proposals import GreedyTokenProposalConfig @@ -58,11 +56,6 @@ def forward(self, *args, **kwargs): return {"logits": torch.randn(1, 10, 1000)} -# Reload registries to include test classes -reload_and_populate_configs() -reload_and_populate_models() - - @pytest.fixture def speculator_model_test_config(): return SpeculatorModelTestConfig(