diff --git a/README.md b/README.md
index 8690d06bf..9b2aeba45 100644
--- a/README.md
+++ b/README.md
@@ -234,6 +234,7 @@ guidellm benchmark \
   --warmup 0.1 \
   --cooldown 0.1 \
-  --max-errors 5
+  --max-errors 5 \
+  --detect-saturation
 ```

 **Key parameters:**
@@ -243,6 +244,7 @@ guidellm benchmark \
 - `--max-seconds`: Maximum duration in seconds for each benchmark before automatic termination
 - `--max-requests`: Maximum number of requests per benchmark before automatic termination
 - `--max-errors`: Maximum number of individual errors before stopping the benchmark entirely
+- `--detect-saturation`: Enable over-saturation detection to automatically stop benchmarks when the model becomes over-saturated (see also `--over-saturation` for more advanced control)

 ## Development and Contribution

diff --git a/docs/guides/index.md b/docs/guides/index.md
index a362dad7a..ddc9ad3a1 100644
--- a/docs/guides/index.md
+++ b/docs/guides/index.md
@@ -60,4 +60,12 @@ Whether you're interested in understanding the system architecture, exploring su
   [:octicons-arrow-right-24: SLO Guide](service_level_objectives.md)
+- :material-stop-circle-outline:{ .lg .middle } Over-Saturation Stopping
+
+  ______________________________________________________________________
+
+  Automatically detect and stop benchmarks when models become over-saturated to prevent wasted compute resources and ensure valid results.
+
+  [:octicons-arrow-right-24: Over-Saturation Guide](over_saturation_stopping.md)
+
diff --git a/docs/guides/over_saturation_stopping.md b/docs/guides/over_saturation_stopping.md
new file mode 100644
index 000000000..68c6e4c82
--- /dev/null
+++ b/docs/guides/over_saturation_stopping.md
@@ -0,0 +1,138 @@
+# Over-Saturation Stopping
+
+GuideLLM provides over-saturation detection (OSD) to automatically stop benchmarks when a model becomes over-saturated. This feature helps prevent wasted compute resources and ensures that benchmark results remain valid by detecting when the response rate can no longer keep up with the request rate.
+
+## What is Over-Saturation?
+
+Over-saturation occurs when an LLM inference server receives requests faster than it can process them, causing a queue to build up. As the queue grows, the server takes progressively longer to start handling each request, leading to degraded performance metrics. When a performance benchmarking tool oversaturates an LLM inference server, the metrics it measures become significantly skewed, rendering them useless.
+
+Think of it like a cashier getting flustered during a sudden rush. As the line grows (the load), the cashier can't keep up, the line gets longer, and there is no room for additional customers. This waste of costly machine time can be prevented by automatically detecting and stopping benchmarks when over-saturation is detected.
+
+## How It Works
+
+GuideLLM's Over-Saturation Detection (OSD) algorithm uses statistical slope detection to identify when a model becomes over-saturated. The algorithm tracks two key metrics over time:
+
+1. **Concurrent Requests**: The number of requests being processed simultaneously
+2. **Time-to-First-Token (TTFT)**: The latency for the first token of each response
+
+For each metric, the algorithm:
+
+- Maintains a sliding window of recent data points
+- Calculates the linear regression slope using online statistics
+- Computes the margin of error (MOE) using t-distribution confidence intervals
+- Detects positive slopes with low MOE, indicating degradation
+
+Over-saturation is detected when:
+
+- Both concurrent requests and TTFT show statistically significant positive slopes
+- The minimum duration threshold has been met
+- Sufficient data points are available for reliable slope estimation
+
+When over-saturation is detected, the constraint automatically stops request queuing and optionally stops processing of existing requests, preventing further resource waste.
+
+## Usage
+
+### Basic Usage
+
+Enable over-saturation detection with default settings:
+
+```bash
+guidellm benchmark \
+  --target http://localhost:8000 \
+  --profile throughput \
+  --rate 10 \
+  --detect-saturation
+```
+
+### Advanced Configuration
+
+Configure detection parameters using a JSON dictionary:
+
+```bash
+guidellm benchmark \
+  --target http://localhost:8000 \
+  --profile concurrent \
+  --rate 16 \
+  --over-saturation '{"enabled": true, "min_seconds": 60, "max_window_seconds": 300, "moe_threshold": 1.5}'
+```
+
+## Configuration Options
+
+The following parameters can be configured when enabling over-saturation detection:
+
+- **`enabled`** (bool, default: `True`): Whether to stop the benchmark if over-saturation is detected
+- **`min_seconds`** (float, default: `30.0`): Minimum seconds before checking for over-saturation. This prevents false positives during the initial warm-up phase.
+- **`max_window_seconds`** (float, default: `120.0`): Maximum time window in seconds for data retention. Older data points are automatically pruned to maintain bounded memory usage.
+- **`moe_threshold`** (float, default: `2.0`): Margin of error threshold for slope detection. Lower values make detection more sensitive to degradation.
+- **`minimum_ttft`** (float, default: `2.5`): Minimum TTFT threshold in seconds for violation counting. Only TTFT values above this threshold are counted as violations.
+- **`maximum_window_ratio`** (float, default: `0.75`): Maximum window size as a ratio of total requests. Limits memory usage by capping the number of tracked requests.
+- **`minimum_window_size`** (int, default: `5`): Minimum data points required for slope estimation. Ensures statistical reliability before making detection decisions.
+- **`confidence`** (float, default: `0.95`): Statistical confidence level for t-distribution calculations (0-1). Higher values require stronger evidence before detecting over-saturation.
+
+## Use Cases
+
+Over-saturation detection is particularly useful in the following scenarios:
+
+### Stress Testing and Capacity Planning
+
+When testing how your system handles increasing load, over-saturation detection automatically stops benchmarks once the system can no longer keep up, preventing wasted compute time on invalid results.
+
+```bash
+guidellm benchmark \
+  --target http://localhost:8000 \
+  --profile sweep \
+  --rate 5 \
+  --detect-saturation
+```
+
+### Cost-Effective Benchmarking
+
+When running large-scale benchmark matrices across multiple models, GPUs, and configurations, over-saturation detection can significantly reduce costs by stopping invalid runs early.
+
+### Finding Safe Operating Ranges
+
+Use over-saturation detection to identify the maximum sustainable throughput for your deployment, helping you set appropriate rate limits and capacity planning targets.
+
+## Interpreting Results
+
+When over-saturation detection is enabled, the benchmark output includes metadata about the detection state. This metadata is available in the scheduler action metadata and includes:
+
+- **`is_over_saturated`** (bool): Whether over-saturation was detected at the time of evaluation
+- **`concurrent_slope`** (float): The calculated slope for concurrent requests
+- **`concurrent_slope_moe`** (float): The margin of error for the concurrent requests slope
+- **`concurrent_n`** (int): The number of data points used for concurrent requests slope calculation
+- **`ttft_slope`** (float): The calculated slope for TTFT
+- **`ttft_slope_moe`** (float): The margin of error for the TTFT slope
+- **`ttft_n`** (int): The number of data points used for TTFT slope calculation
+- **`ttft_violations`** (int): The count of TTFT values exceeding the minimum threshold
+
+These metrics can help you understand why over-saturation was detected and fine-tune the detection parameters if needed.
+
+## Example: Complete Benchmark with Over-Saturation Detection
+
+```bash
+guidellm benchmark \
+  --target http://localhost:8000 \
+  --profile concurrent \
+  --rate 16 \
+  --data "prompt_tokens=256,output_tokens=128" \
+  --max-seconds 300 \
+  --over-saturation '{"enabled": true, "min_seconds": 30, "max_window_seconds": 120}' \
+  --outputs json,html
+```
+
+This example:
+
+- Runs a concurrent benchmark with 16 simultaneous requests
+- Uses synthetic data with 256 prompt tokens and 128 output tokens
+- Enables over-saturation detection with custom timing parameters
+- Sets a maximum duration of 300 seconds (as a fallback)
+- Outputs results in both JSON and HTML formats
+
+## Additional Resources
+
+For more in-depth information about over-saturation detection, including the algorithm development, evaluation metrics, and implementation details, see the following Red Hat Developer blog posts:
+
+- [Reduce LLM benchmarking costs with oversaturation detection](https://developers.redhat.com/articles/2025/11/18/reduce-llm-benchmarking-costs-oversaturation-detection) - An introduction to the problem of over-saturation and why it matters for LLM benchmarking
+- [Defining success: Evaluation metrics and data augmentation for oversaturation detection](https://developers.redhat.com/articles/2025/11/20/oversaturation-detection-evaluation-metrics) - How to evaluate the performance of an OSD algorithm through custom metrics, dataset labeling, and load augmentation techniques
+- [Building an oversaturation detector with iterative error analysis](https://developers.redhat.com/articles/2025/11/24/building-oversaturation-detector-iterative-error-analysis) - A detailed walkthrough of how the OSD algorithm was built
diff --git a/src/guidellm/__main__.py b/src/guidellm/__main__.py
index b5f918ae5..d0fc89a19 100644
--- a/src/guidellm/__main__.py
+++ b/src/guidellm/__main__.py
@@ -384,7 +384,27 @@ def benchmark():
     default=BenchmarkGenerativeTextArgs.get_default("max_global_error_rate"),
     help="Maximum global error rate across all benchmarks.",
 )
-def run(**kwargs):
+@click.option(
+    "--over-saturation",
+    "over_saturation",
+    callback=cli_tools.parse_json,
+    default=None,
+    help=(
+        "Enable over-saturation detection. "
+        "Pass a JSON dict with configuration "
+        '(e.g., \'{"enabled": true, "min_seconds": 30}\'). '
+        "Defaults to None (disabled)."
+    ),
+)
+@click.option(
+    "--detect-saturation",
+    "--default-over-saturation",
+    "over_saturation",
+    callback=cli_tools.parse_json,
+    flag_value='{"enabled": true}',
+    help="Enable over-saturation detection with default settings.",
+)
+def run(**kwargs):  # noqa: C901
     # Only set CLI args that differ from click defaults
     kwargs = cli_tools.set_if_not_default(click.get_current_context(), **kwargs)

diff --git a/src/guidellm/benchmark/entrypoints.py b/src/guidellm/benchmark/entrypoints.py
index 5b57b22fe..75c8c787b 100644
--- a/src/guidellm/benchmark/entrypoints.py
+++ b/src/guidellm/benchmark/entrypoints.py
@@ -323,6 +323,7 @@ async def resolve_profile(
     max_errors: int | None,
     max_error_rate: float | None,
     max_global_error_rate: float | None,
+    over_saturation: dict[str, Any] | None = None,
     console: Console | None = None,
 ) -> Profile:
     """
@@ -343,6 +344,7 @@ async def resolve_profile(
     :param max_errors: Maximum number of errors before stopping
     :param max_error_rate: Maximum error rate threshold before stopping
     :param max_global_error_rate: Maximum global error rate threshold before stopping
+    :param over_saturation: Over-saturation detection configuration (dict)
     :param console: Console instance for progress reporting, or None
     :return: Configured Profile instance ready for benchmarking
     :raises ValueError: If constraints are provided with a pre-configured Profile
@@ -359,6 +361,7 @@ async def resolve_profile(
             "max_errors": max_errors,
             "max_error_rate": max_error_rate,
             "max_global_error_rate": max_global_error_rate,
+            "over_saturation": over_saturation,
         }.items():
             if val is not None:
                 constraints[key] = val
@@ -500,6 +503,7 @@ async def benchmark_generative_text(
         max_errors=args.max_errors,
         max_error_rate=args.max_error_rate,
         max_global_error_rate=args.max_global_error_rate,
+        over_saturation=args.over_saturation,
         console=console,
     )
     output_formats = await resolve_output_formats(
diff --git a/src/guidellm/benchmark/progress.py b/src/guidellm/benchmark/progress.py
index d7372a40c..bf744dd22 100644
--- a/src/guidellm/benchmark/progress.py
+++ b/src/guidellm/benchmark/progress.py
@@ -12,7 +12,6 @@

 from abc import ABC, abstractmethod
 from dataclasses import dataclass
-from datetime import datetime
 from typing import Any, Generic, Literal

 from rich.console import Group
@@ -37,7 +36,7 @@
     GenerativeBenchmarkAccumulator,
 )
 from guidellm.scheduler import SchedulerState, SchedulingStrategy
-from guidellm.utils import Colors, format_value_display
+from guidellm.utils import Colors, format_value_display, safe_format_timestamp

 __all__ = ["BenchmarkerProgress", "GenerativeConsoleBenchmarkerProgress"]

@@ -390,7 +389,7 @@ def formatted_start_time(self) -> str:
         if self.start_time < 0.0:
             return "--:--:--"

-        return datetime.fromtimestamp(self.start_time).strftime("%H:%M:%S")
+        return safe_format_timestamp(self.start_time, format_="%H:%M:%S")

     @property
     def formatted_progress_status(self) -> str:
diff --git a/src/guidellm/benchmark/schemas/generative/entrypoints.py b/src/guidellm/benchmark/schemas/generative/entrypoints.py
index a080daa03..fff2bec37 100644
--- a/src/guidellm/benchmark/schemas/generative/entrypoints.py
+++ b/src/guidellm/benchmark/schemas/generative/entrypoints.py
@@ -283,6 +283,14 @@ def get_default(cls: type[BenchmarkGenerativeTextArgs], field: str) -> Any:
     max_global_error_rate: float | None = Field(
         default=None, description="Maximum global error rate (0-1) before stopping"
     )
+    over_saturation: dict[str, Any] | None = Field(
+        default=None,
+        description=(
+            "Over-saturation detection configuration. A dict with configuration "
+            "parameters (enabled, min_seconds, max_window_seconds, "
+            "moe_threshold, etc.)."
+        ),
+    )

     @field_validator("data", "data_args", "rate", mode="wrap")
     @classmethod
diff --git a/src/guidellm/scheduler/__init__.py b/src/guidellm/scheduler/__init__.py
index c03410767..ab4aeef7b 100644
--- a/src/guidellm/scheduler/__init__.py
+++ b/src/guidellm/scheduler/__init__.py
@@ -19,6 +19,8 @@
     MaxErrorsConstraint,
     MaxGlobalErrorRateConstraint,
     MaxNumberConstraint,
+    OverSaturationConstraint,
+    OverSaturationConstraintInitializer,
     PydanticConstraintInitializer,
     SerializableConstraintInitializer,
     UnserializableConstraintInitializer,
@@ -66,6 +68,8 @@
     "MaxNumberConstraint",
     "MultiTurnRequestT",
     "NonDistributedEnvironment",
+    "OverSaturationConstraint",
+    "OverSaturationConstraintInitializer",
     "PydanticConstraintInitializer",
     "RequestT",
     "ResponseT",
diff --git a/src/guidellm/scheduler/constraints.py b/src/guidellm/scheduler/constraints.py
deleted file mode 100644
index 21e0fe967..000000000
--- a/src/guidellm/scheduler/constraints.py
+++ /dev/null
@@ -1,1037 +0,0 @@
-"""
-Constraint system for scheduler behavior control and request processing limits.
-
-Provides flexible constraints for managing scheduler behavior with configurable
-thresholds based on time, error rates, and request counts. Constraints evaluate
-scheduler state and individual requests to determine whether processing should
-continue or stop based on predefined limits. The constraint system enables
-sophisticated benchmark stopping criteria through composable constraint types.
-"""
-
-from __future__ import annotations
-
-import time
-from abc import ABC, abstractmethod
-from typing import Any, Literal, Protocol, cast, runtime_checkable
-
-from pydantic import Field, field_validator
-
-from guidellm.scheduler.schemas import (
-    SchedulerProgress,
-    SchedulerState,
-    SchedulerUpdateAction,
-)
-from guidellm.schemas import RequestInfo, StandardBaseModel
-from guidellm.settings import settings
-from guidellm.utils import InfoMixin, RegistryMixin
-
-__all__ = [
-    "Constraint",
-    "ConstraintInitializer",
-    "ConstraintsInitializerFactory",
-    "MaxDurationConstraint",
-    "MaxErrorRateConstraint",
-    "MaxErrorsConstraint",
-    "MaxGlobalErrorRateConstraint",
-    "MaxNumberConstraint",
-    "PydanticConstraintInitializer",
-    "RequestsExhaustedConstraint",
-    "SerializableConstraintInitializer",
-    "UnserializableConstraintInitializer",
-]
-
-
-@runtime_checkable
-class Constraint(Protocol):
-    """Protocol for constraint evaluation functions that control scheduler behavior."""
-
-    def __call__(
-        self, state: SchedulerState, request: RequestInfo
-    ) -> SchedulerUpdateAction:
-        """
-        Evaluate constraint against scheduler state and request information.
-
-        :param state: Current scheduler state with metrics and timing information
-        :param request: Individual request information and metadata
-        :return: Action indicating whether to continue or stop scheduler operations
-        """
-
-
-@runtime_checkable
-class ConstraintInitializer(Protocol):
-    """Protocol for constraint initializer factory functions that create constraints."""
-
-    def create_constraint(self, **kwargs) -> Constraint:
-        """
-        Create a constraint instance from configuration parameters.
-
-        :param kwargs: Configuration parameters for constraint creation
-        :return: Configured constraint evaluation function
-        """
-
-
-@runtime_checkable
-class SerializableConstraintInitializer(Protocol):
-    """Protocol for serializable constraint initializers supporting persistence."""
-
-    @classmethod
-    def validated_kwargs(cls, *args, **kwargs) -> dict[str, Any]:
-        """
-        Validate and process arguments for constraint creation.
-
-        :param args: Positional arguments for constraint configuration
-        :param kwargs: Keyword arguments for constraint configuration
-        :return: Validated parameter dictionary for constraint creation
-        """
-
-    @classmethod
-    def model_validate(cls, **kwargs) -> ConstraintInitializer:
-        """
-        Create validated constraint initializer from configuration.
-
-        :param kwargs: Configuration dictionary for initializer creation
-        :return: Validated constraint initializer instance
-        """
-
-    def model_dump(self) -> dict[str, Any]:
-        """
-        Serialize constraint initializer to dictionary format.
-
-        :return: Dictionary representation of constraint initializer
-        """
-
-    def create_constraint(self, **kwargs) -> Constraint:
-        """
-        Create constraint instance from this initializer.
-
-        :param kwargs: Additional configuration parameters
-        :return: Configured constraint evaluation function
-        """
-
-
-class ConstraintsInitializerFactory(RegistryMixin[ConstraintInitializer]):
-    """
-    Registry factory for creating and managing constraint initializers.
-
-    Provides centralized access to registered constraint types with support for
-    creating constraints from configuration dictionaries, simple values, or
-    pre-configured instances. Handles constraint resolution and type validation
-    for the scheduler constraint system.
-
-    Example:
-    ::
-        from guidellm.scheduler import ConstraintsInitializerFactory
-
-        # Register new constraint type
-        @ConstraintsInitializerFactory.register("new_constraint")
-        class NewConstraint:
-            def create_constraint(self, **kwargs) -> Constraint:
-                return lambda state, request: SchedulerUpdateAction()
-
-        # Create and use constraint
-        constraint = ConstraintsInitializerFactory.create_constraint("new_constraint")
-    """
-
-    @classmethod
-    def create(cls, key: str, *args, **kwargs) -> ConstraintInitializer:
-        """
-        Create a constraint initializer for the specified key.
-
-        :param key: Registered constraint initializer key
-        :param args: Positional arguments for initializer creation
-        :param kwargs: Keyword arguments for initializer creation
-        :return: Configured constraint initializer instance
-        :raises ValueError: If the key is not registered in the factory
-        """
-        if cls.registry is None or key not in cls.registry:
-            raise ValueError(f"Unknown constraint initializer key: {key}")
-
-        initializer_class = cls.registry[key]
-
-        return (
-            initializer_class(*args, **kwargs)  # type: ignore[operator]
-            if not isinstance(initializer_class, type)
-            or not issubclass(initializer_class, SerializableConstraintInitializer)
-            else initializer_class(
-                **initializer_class.validated_kwargs(*args, **kwargs)  # type: ignore[misc]
-            )
-        )
-
-    @classmethod
-    def serialize(cls, initializer: ConstraintInitializer) -> dict[str, Any]:
-        """
-        Serialize constraint initializer to dictionary format.
-
-        :param initializer: Constraint initializer to serialize
-        :return: Dictionary representation or unserializable placeholder
-        """
-        if isinstance(initializer, SerializableConstraintInitializer):
-            return initializer.model_dump()
-        else:
-            unserializable = UnserializableConstraintInitializer(
-                orig_info=InfoMixin.extract_from_obj(initializer)
-            )
-            return unserializable.model_dump()
-
-    @classmethod
-    def deserialize(
-        cls, initializer_dict: dict[str, Any]
-    ) -> SerializableConstraintInitializer | UnserializableConstraintInitializer:
-        """
-        Deserialize constraint initializer from dictionary format.
-
-        :param initializer_dict: Dictionary representation of constraint initializer
-        :return: Reconstructed constraint initializer instance
-        :raises ValueError: If constraint type is unknown or cannot be deserialized
-        """
-        if initializer_dict.get("type_") == "unserializable":
-            return UnserializableConstraintInitializer.model_validate(initializer_dict)
-
-        if (
-            cls.registry is not None
-            and initializer_dict.get("type_")
-            and initializer_dict["type_"] in cls.registry
-        ):
-            initializer_class = cls.registry[initializer_dict["type_"]]
-            if hasattr(initializer_class, "model_validate"):
-                return initializer_class.model_validate(initializer_dict)  # type: ignore[return-value]
-            else:
-                return initializer_class(**initializer_dict)  # type: ignore[return-value,operator]
-
-        raise ValueError(
-            f"Cannot deserialize unknown constraint initializer: "
-            f"{initializer_dict.get('type_', 'unknown')}"
-        )
-
-    @classmethod
-    def create_constraint(cls, key: str, *args, **kwargs) -> Constraint:
-        """
-        Create a constraint instance for the specified key.
-
-        :param key: Registered constraint initializer key
-        :param args: Positional arguments for constraint creation
-        :param kwargs: Keyword arguments for constraint creation
-        :return: Configured constraint function ready for evaluation
-        :raises ValueError: If the key is not registered in the factory
-        """
-        return cls.create(key, *args, **kwargs).create_constraint()
-
-    @classmethod
-    def resolve(
-        cls,
-        initializers: dict[
-            str,
-            Any | dict[str, Any] | Constraint | ConstraintInitializer,
-        ],
-    ) -> dict[str, Constraint]:
-        """
-        Resolve mixed constraint specifications to callable constraints.
-
-        :param initializers: Dictionary mapping constraint keys to specifications
-        :return: Dictionary mapping constraint keys to callable functions
-        :raises ValueError: If any key is not registered in the factory
-        """
-        constraints = {}
-
-        for key, val in initializers.items():
-            if isinstance(val, Constraint):
-                constraints[key] = val
-            elif isinstance(val, ConstraintInitializer):
-                constraints[key] = val.create_constraint()
-            elif isinstance(val, dict):
-                constraints[key] = cls.create_constraint(key, **val)
-            else:
-                constraints[key] = cls.create_constraint(key, val)
-
-        return constraints
-
-    @classmethod
-    def resolve_constraints(
-        cls,
-        constraints: dict[str, Any | dict[str, Any] | Constraint],
-    ) -> dict[str, Constraint]:
-        """
-        Resolve constraints from mixed constraint specifications.
-
-        :param constraints: Dictionary mapping constraint keys to specifications
-        :return: Dictionary mapping constraint keys to callable functions
-        :raises ValueError: If any constraint key is not registered
-        """
-        resolved_constraints = {}
-
-        for key, val in constraints.items():
-            if isinstance(val, Constraint):
-                resolved_constraints[key] = val
-            elif isinstance(val, dict):
-                resolved_constraints[key] = cls.create_constraint(key, **val)
-            else:
-                resolved_constraints[key] = cls.create_constraint(key, val)
-
-        return resolved_constraints
-
-
-class PydanticConstraintInitializer(StandardBaseModel, ABC, InfoMixin):
-    """
-    Abstract base for Pydantic-based constraint initializers.
-
-    Provides standardized serialization, validation, and metadata handling for
-    constraint initializers using Pydantic models. Subclasses implement specific
-    constraint creation logic while inheriting validation and persistence support.
-    """
-
-    type_: str = Field(description="Type identifier for the constraint initializer")
-
-    @property
-    def info(self) -> dict[str, Any]:
-        """
-        Extract serializable information from this constraint initializer.
-
-        :return: Dictionary containing constraint configuration and metadata
-        """
-        return self.model_dump()
-
-    @classmethod
-    @abstractmethod
-    def validated_kwargs(cls, *args, **kwargs) -> dict[str, Any]:
-        """
-        Validate and process arguments for constraint creation.
-
-        Must be implemented by subclasses to handle their specific parameter patterns
-        and validation requirements.
-
-        :param args: Positional arguments passed to the constraint
-        :param kwargs: Keyword arguments passed to the constraint
-        :return: Validated dictionary of parameters for constraint creation
-        :raises NotImplementedError: Must be implemented by subclasses
-        """
-        ...
-
-    @abstractmethod
-    def create_constraint(self, **kwargs) -> Constraint:
-        """
-        Create a constraint instance.
-
-        Must be implemented by subclasses to return their specific constraint type
-        with appropriate configuration and validation.
-
-        :param kwargs: Additional keyword arguments (usually unused)
-        :return: Configured constraint instance
-        :raises NotImplementedError: Must be implemented by subclasses
-        """
-        ...
-
-
-class UnserializableConstraintInitializer(PydanticConstraintInitializer):
-    """
-    Placeholder for constraints that cannot be serialized or executed.
-
-    Represents constraint initializers that failed serialization or contain
-    non-serializable components. Cannot be executed and raises errors when
-    invoked to prevent runtime failures from invalid constraint state.
-    """
-
-    type_: Literal["unserializable"] = "unserializable"  # type: ignore[assignment]
-    orig_info: dict[str, Any] = Field(
-        default_factory=dict,
-        description="Original constraint information before serialization failure",
-    )
-
-    @classmethod
-    def validated_kwargs(
-        cls, orig_info: dict[str, Any] | None = None, **_kwargs
-    ) -> dict[str, Any]:
-        """
-        Validate arguments for unserializable constraint creation.
-
-        :param orig_info: Original constraint information before serialization failure
-        :param kwargs: Additional arguments (ignored)
-        :return: Validated parameters for unserializable constraint creation
-        """
-        return {"orig_info": orig_info or {}}
-
-    def create_constraint(self, **_kwargs) -> Constraint:
-        """
-        Raise error for unserializable constraint creation attempt.
-
-        :param kwargs: Additional keyword arguments (unused)
-        :raises RuntimeError: Always raised since unserializable constraints
-            cannot be executed
-        """
-        raise RuntimeError(
-            "Cannot create constraint from unserializable constraint instance. "
-            "This constraint cannot be serialized and therefore cannot be executed."
-        )
-
-    def __call__(
-        self, state: SchedulerState, request: RequestInfo
-    ) -> SchedulerUpdateAction:
-        """
-        Raise error since unserializable constraints cannot be invoked.
-
-        :param state: Current scheduler state (unused)
-        :param request: Individual request information (unused)
-        :raises RuntimeError: Always raised for unserializable constraints
-        """
-        _ = (state, request)  # Unused parameters
-        raise RuntimeError(
-            "Cannot invoke unserializable constraint instance. "
-            "This constraint was not properly serialized and cannot be executed."
-        )
-
-
-@ConstraintsInitializerFactory.register(  # type: ignore[arg-type]
-    ["max_number", "max_num", "max_requests", "max_req"]
-)
-class MaxNumberConstraint(PydanticConstraintInitializer):
-    """
-    Constraint that limits execution based on maximum request counts.
-
-    Stops request queuing when created requests reach the limit and stops local
-    request processing when processed requests reach the limit. Provides progress
-    tracking based on remaining requests and completion fraction.
-    """
-
-    type_: Literal["max_number"] = "max_number"  # type: ignore[assignment]
-    max_num: int | float | list[int | float] = Field(
-        description="Maximum number of requests allowed before triggering constraint",
-    )
-    current_index: int = Field(
-        default=-1, description="Current index for list-based max_num values"
-    )
-
-    @classmethod
-    def validated_kwargs(
-        cls, max_num: int | float | list[int | float], **kwargs
-    ) -> dict[str, Any]:
-        """
-        Validate and process arguments for MaxNumberConstraint creation.
-
-        :param max_num: Maximum number of requests to allow
-        :param kwargs: Supports max_num, max_number, max_requests, max_req,
-            and optional type_
-        :return: Validated dictionary with max_num and type_ fields
-        """
-        aliases = ["max_number", "max_num", "max_requests", "max_req"]
-        for alias in aliases:
-            if max_num is None:
-                max_num = kwargs.get(alias)
-
-        return {"max_num": max_num, "current_index": kwargs.get("current_index", -1)}
-
-    def create_constraint(self, **_kwargs) -> Constraint:
-        """
-        Return self as the constraint instance.
-
-        :param kwargs: Additional keyword arguments (unused)
-        :return: Self instance as the constraint
-        """
-        self.current_index += 1
-
-        return cast("Constraint", self.model_copy())
-
-    def __call__(
-        self, state: SchedulerState, request_info: RequestInfo
-    ) -> SchedulerUpdateAction:
-        """
-        Evaluate constraint against current scheduler state and request count.
-
-        :param state: Current scheduler state with request counts
-        :param request_info: Individual request information (unused)
-        :return: Action indicating whether to continue or stop operations
-        """
-        _ = request_info  # Unused parameters
-        current_index = max(0, self.current_index)
-        max_num = (
-            self.max_num
-            if isinstance(self.max_num, int | float)
-            else self.max_num[min(current_index, len(self.max_num) - 1)]
-        )
-
-        create_exceeded = state.created_requests >= max_num
-        processed_exceeded = state.processed_requests >= max_num
-        remaining_requests = min(max(0, max_num - state.processed_requests), max_num)
-        stop_time = (
-            None if remaining_requests > 0 else request_info.completed_at or time.time()
-        )
-
-        return SchedulerUpdateAction(
-            request_queuing="stop" if create_exceeded else "continue",
-            request_processing="stop_local" if processed_exceeded else "continue",
-            metadata={
-                "max_number": max_num,
-                "create_exceeded": create_exceeded,
-                "processed_exceeded": processed_exceeded,
-                "created_requests": state.created_requests,
-                "processed_requests": state.processed_requests,
-                "remaining_requests": remaining_requests,
-                "stop_time": stop_time,
-            },
-            progress=SchedulerProgress(
-                remaining_requests=remaining_requests,
-                total_requests=max_num,
-                stop_time=stop_time,
-            ),
-        )
-
-    @field_validator("max_num")
-    @classmethod
-    def _validate_max_num(
-        cls, value: int | float | list[int | float]
-    ) -> int | float | list[int | float]:
-        if not isinstance(value, list):
-            value = [value]
-        for val in value:
-            if not val:
-                raise ValueError(
-                    f"max_num must be set and truthful, received {value} ({val} failed)"
-                )
-            if not isinstance(val, int | float) or val <= 0:
-                raise ValueError(
-                    f"max_num must be a positive num, received {value} ({val} failed)"
-                )
-
-        return value[0] if isinstance(value, list) and len(value) == 1 else value
-
-
-@ConstraintsInitializerFactory.register(
-    ["max_duration", "max_dur", "max_sec", "max_seconds", "max_min", "max_minutes"]
-)
-class MaxDurationConstraint(PydanticConstraintInitializer):
-    """
-    Constraint that limits execution based on maximum time duration.
-
-    Stops both request queuing and processing when the elapsed time since scheduler
-    start exceeds the maximum duration. Provides progress tracking based on
-    remaining time and completion fraction.
-    """
-
-    type_: Literal["max_duration"] = "max_duration"  # type: ignore[assignment]
-    max_duration: int | float | list[int | float] = Field(
-        description="Maximum duration in seconds before triggering constraint"
-    )
-    current_index: int = Field(default=-1, description="Current index in duration list")
-
-    @classmethod
-    def validated_kwargs(
-        cls, max_duration: int | float | list[int | float] | None = None, **kwargs
-    ) -> dict[str, Any]:
-        """
-        Validate and process arguments for MaxDurationConstraint creation.
- - :param max_duration: Maximum duration in seconds - :param kwargs: Supports max_duration, max_dur, max_sec, max_seconds, - max_min, max_minutes, and optional type_ - :return: Validated dictionary with max_duration and type_ fields - """ - seconds_aliases = ["max_dur", "max_sec", "max_seconds"] - for alias in seconds_aliases: - if max_duration is None: - max_duration = kwargs.get(alias) - minutes_aliases = ["max_min", "max_minutes"] - for alias in minutes_aliases: - minutes = kwargs.get(alias) - if minutes is not None and max_duration is None: - max_duration = minutes * 60 - - return { - "max_duration": max_duration, - "current_index": kwargs.get("current_index", -1), - } - - def create_constraint(self, **_kwargs) -> Constraint: - """ - Return self as the constraint instance. - - :param kwargs: Additional keyword arguments (unused) - :return: Self instance as the constraint - """ - self.current_index += 1 - - return cast("Constraint", self.model_copy()) - - def __call__( - self, state: SchedulerState, request_info: RequestInfo - ) -> SchedulerUpdateAction: - """ - Evaluate constraint against current scheduler state and elapsed time. 
- - :param state: Current scheduler state with start time - :param request_info: Individual request information (unused) - :return: Action indicating whether to continue or stop operations - """ - _ = request_info # Unused parameters - current_index = max(0, self.current_index) - max_duration = ( - self.max_duration - if isinstance(self.max_duration, int | float) - else self.max_duration[min(current_index, len(self.max_duration) - 1)] - ) - - current_time = time.time() - elapsed = current_time - state.start_time - duration_exceeded = elapsed >= max_duration - remaining_duration = min(max(0.0, max_duration - elapsed), max_duration) - stop_time = None if not duration_exceeded else state.start_time + max_duration - - return SchedulerUpdateAction( - request_queuing="stop" if duration_exceeded else "continue", - request_processing="stop_local" if duration_exceeded else "continue", - metadata={ - "max_duration": max_duration, - "elapsed_time": elapsed, - "duration_exceeded": duration_exceeded, - "start_time": state.start_time, - "current_time": current_time, - "stop_time": stop_time, - }, - progress=SchedulerProgress( - remaining_duration=remaining_duration, - total_duration=max_duration, - stop_time=stop_time, - ), - ) - - @field_validator("max_duration") - @classmethod - def _validate_max_duration( - cls, value: int | float | list[int | float] - ) -> int | float | list[int | float]: - if not isinstance(value, list): - value = [value] - for val in value: - if not val: - raise ValueError( - "max_duration must be set and truthful, " - f"received {value} ({val} failed)" - ) - if not isinstance(val, int | float) or val <= 0: - raise ValueError( - "max_duration must be a positive num," - f"received {value} ({val} failed)" - ) - - return value[0] if isinstance(value, list) and len(value) == 1 else value - - -@ConstraintsInitializerFactory.register( - ["max_errors", "max_err", "max_error", "max_errs"] -) -class MaxErrorsConstraint(PydanticConstraintInitializer): - """ - 
Constraint that limits execution based on absolute error count. - - Stops both request queuing and all request processing when the total number - of errored requests reaches the maximum threshold. Uses global error tracking - across all requests for immediate constraint evaluation. - """ - - type_: Literal["max_errors"] = "max_errors" # type: ignore[assignment] - max_errors: int | float | list[int | float] = Field( - description="Maximum number of errors allowed before triggering constraint", - ) - current_index: int = Field(default=-1, description="Current index in error list") - - @classmethod - def validated_kwargs( - cls, max_errors: int | float | list[int | float] | None = None, **kwargs - ) -> dict[str, Any]: - """ - Validate and process arguments for MaxErrorsConstraint creation. - - :param max_errors: Maximum number of errors to allow - :param kwargs: Supports max_errors, max_err, max_error, max_errs, - and optional type_ - :return: Validated dictionary with max_errors and type_ fields - """ - aliases = ["max_errors", "max_err", "max_error", "max_errs"] - for alias in aliases: - if max_errors is None: - max_errors = kwargs.get(alias) - - return { - "max_errors": max_errors, - "current_index": kwargs.get("current_index", -1), - } - - def create_constraint(self, **_kwargs) -> Constraint: - """ - Return self as the constraint instance. - - :param kwargs: Additional keyword arguments (unused) - :return: Self instance as the constraint - """ - self.current_index += 1 - - return cast("Constraint", self.model_copy()) - - def __call__( - self, state: SchedulerState, request_info: RequestInfo - ) -> SchedulerUpdateAction: - """ - Evaluate constraint against current error count. 
- - :param state: Current scheduler state with error counts - :param request_info: Individual request information (unused) - :return: Action indicating whether to continue or stop operations - """ - _ = request_info # Unused parameters - current_index = max(0, self.current_index) - max_errors = ( - self.max_errors - if isinstance(self.max_errors, int | float) - else self.max_errors[min(current_index, len(self.max_errors) - 1)] - ) - errors_exceeded = state.errored_requests >= max_errors - stop_time = ( - None if not errors_exceeded else request_info.completed_at or time.time() - ) - - return SchedulerUpdateAction( - request_queuing="stop" if errors_exceeded else "continue", - request_processing="stop_all" if errors_exceeded else "continue", - metadata={ - "max_errors": max_errors, - "errors_exceeded": errors_exceeded, - "current_errors": state.errored_requests, - "stop_time": stop_time, - }, - progress=SchedulerProgress(stop_time=stop_time), - ) - - @field_validator("max_errors") - @classmethod - def _validate_max_errors( - cls, value: int | float | list[int | float] - ) -> int | float | list[int | float]: - if not isinstance(value, list): - value = [value] - for val in value: - if not val: - raise ValueError( - "max_errors must be set and truthful, " - f"received {value} ({val} failed)" - ) - if not isinstance(val, int | float) or val <= 0: - raise ValueError( - f"max_errors must be a positive num,received {value} ({val} failed)" - ) - - return value[0] if isinstance(value, list) and len(value) == 1 else value - - -@ConstraintsInitializerFactory.register( - ["max_error_rate", "max_err_rate", "max_errors_rate"] -) -class MaxErrorRateConstraint(PydanticConstraintInitializer): - """ - Constraint that limits execution based on sliding window error rate. - - Tracks error status of recent requests in a sliding window and stops all - processing when the error rate exceeds the threshold. 
Only applies the - constraint after processing enough requests to fill the minimum window size - for statistical significance. - """ - - type_: Literal["max_error_rate"] = "max_error_rate" # type: ignore[assignment] - max_error_rate: int | float | list[int | float] = Field( - description="Maximum error rate allowed (0.0, 1.0)" - ) - window_size: int | float = Field( - default=30, - gt=0, - description="Size of sliding window for calculating error rate", - ) - error_window: list[bool] = Field( - default_factory=list, - description="Sliding window tracking error status of recent requests", - ) - current_index: int = Field( - default=-1, description="Current index in the error window" - ) - - @classmethod - def validated_kwargs( - cls, max_error_rate: int | float | list[int | float], **kwargs - ) -> dict[str, Any]: - """ - Validate and process arguments for MaxErrorRateConstraint creation. - - :param max_error_rate: Maximum error rate to allow - :param kwargs: Supports max_error_rate, max_err_rate, max_errors_rate, - optional window_size, and optional type_ - :return: Validated dictionary with max_error_rate, window_size, - and type_ fields - """ - aliases = ["max_error_rate", "max_err_rate", "max_errors_rate"] - for alias in aliases: - if max_error_rate is None: - max_error_rate = kwargs.get(alias) - - return { - "max_error_rate": max_error_rate, - "window_size": kwargs.get( - "window_size", settings.constraint_error_window_size - ), - "error_window": kwargs.get("error_window", []), - "current_index": kwargs.get("current_index", -1), - } - - def create_constraint(self, **_kwargs) -> Constraint: - """ - Create a new instance of MaxErrorRateConstraint (due to stateful window). 
- - :param kwargs: Additional keyword arguments (unused) - :return: New instance of the constraint - """ - self.current_index += 1 - - return cast("Constraint", self.model_copy()) - - def __call__( - self, state: SchedulerState, request_info: RequestInfo - ) -> SchedulerUpdateAction: - """ - Evaluate constraint against sliding window error rate. - - :param state: Current scheduler state with request counts - :param request_info: Individual request with completion status - :return: Action indicating whether to continue or stop operations - """ - current_index = max(0, self.current_index) - max_error_rate = ( - self.max_error_rate - if isinstance(self.max_error_rate, int | float) - else self.max_error_rate[min(current_index, len(self.max_error_rate) - 1)] - ) - - if request_info.status in ["completed", "errored", "cancelled"]: - self.error_window.append(request_info.status == "errored") - if len(self.error_window) > self.window_size: - self.error_window.pop(0) - - error_count = sum(self.error_window) - window_requests = len(self.error_window) - error_rate = ( - error_count / float(window_requests) if window_requests > 0 else 0.0 - ) - exceeded_min_processed = state.processed_requests >= self.window_size - exceeded_error_rate = error_rate >= max_error_rate - exceeded = exceeded_min_processed and exceeded_error_rate - stop_time = None if not exceeded else request_info.completed_at or time.time() - - return SchedulerUpdateAction( - request_queuing="stop" if exceeded else "continue", - request_processing="stop_all" if exceeded else "continue", - metadata={ - "max_error_rate": max_error_rate, - "window_size": self.window_size, - "error_count": error_count, - "processed_count": state.processed_requests, - "current_window_size": len(self.error_window), - "current_error_rate": error_rate, - "exceeded_min_processed": exceeded_min_processed, - "exceeded_error_rate": exceeded_error_rate, - "exceeded": exceeded, - "stop_time": stop_time, - }, - ) - - 
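The sliding-window check in `MaxErrorRateConstraint.__call__` above can be sketched in isolation. This is a minimal illustration, not the code in this change: the `SlidingErrorRate` class and its `record` method are hypothetical names, and the processed-request count is tracked locally here, whereas the diff reads it from `state.processed_requests`.

```python
from collections import deque


class SlidingErrorRate:
    """Hypothetical helper mirroring the windowed error-rate check (assumption)."""

    def __init__(self, max_error_rate: float, window_size: int = 30) -> None:
        self.max_error_rate = max_error_rate
        self.window_size = window_size
        # deque(maxlen=...) discards the oldest entry once the window is full,
        # which replaces the explicit pop(0) used with a plain list.
        self.window = deque(maxlen=window_size)
        # Assumption: counted locally instead of read from scheduler state.
        self.processed = 0

    def record(self, status: str) -> bool:
        """Record one request status; return True when the run should stop."""
        if status in ("completed", "errored", "cancelled"):
            self.window.append(status == "errored")
            self.processed += 1
        rate = sum(self.window) / len(self.window) if self.window else 0.0
        # Only enforce the limit once enough requests have been processed.
        return self.processed >= self.window_size and rate >= self.max_error_rate
```

Using a bounded `deque` keeps each update constant-time; the list-based version in the diff behaves the same but shifts every element when the oldest entry is dropped.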
@field_validator("max_error_rate") - @classmethod - def _validate_max_error_rate( - cls, value: int | float | list[int | float] - ) -> int | float | list[int | float]: - if not isinstance(value, list): - value = [value] - for val in value: - if not val: - raise ValueError( - "max_error_rate must be set and truthful, " - f"received {value} ({val} failed)" - ) - if not isinstance(val, int | float) or val <= 0 or val >= 1: - raise ValueError( - "max_error_rate must be a number between 0 and 1," - f"received {value} ({val} failed)" - ) - - return value[0] if isinstance(value, list) and len(value) == 1 else value - - -@ConstraintsInitializerFactory.register( - ["max_global_error_rate", "max_global_err_rate", "max_global_errors_rate"] -) -class MaxGlobalErrorRateConstraint(PydanticConstraintInitializer): - """ - Constraint that limits execution based on global error rate. - - Calculates error rate across all processed requests and stops all processing - when the rate exceeds the threshold. Only applies the constraint after - processing the minimum number of requests to ensure statistical significance - for global error rate calculations. - """ - - type_: Literal["max_global_error_rate"] = "max_global_error_rate" # type: ignore[assignment] - max_error_rate: int | float = Field( - description="Maximum error rate allowed (0.0 to 1.0)" - ) - min_processed: int | float | None = Field( - default=30, - gt=0, - description="Minimum requests processed before applying error rate constraint", - ) - current_index: int = Field( - default=-1, description="Current index for list-based max_error_rate values" - ) - - @classmethod - def validated_kwargs( - cls, max_error_rate: int | float | list[int | float], **kwargs - ) -> dict[str, Any]: - """ - Validate and process arguments for MaxGlobalErrorRateConstraint creation. 
- - :param max_error_rate: Maximum error rate to allow - :param kwargs: Supports max_global_error_rate, max_global_err_rate, - max_global_errors_rate, optional min_processed, and optional type_ - :return: Validated dictionary with max_error_rate, min_processed, - and type_ fields - """ - for alias in [ - "max_global_error_rate", - "max_global_err_rate", - "max_global_errors_rate", - ]: - if max_error_rate is None: - max_error_rate = kwargs.get(alias) - - return { - "max_error_rate": max_error_rate, - "min_processed": kwargs.get( - "min_processed", settings.constraint_error_min_processed - ), - "current_index": kwargs.get("current_index", -1), - } - - def create_constraint(self, **_kwargs) -> Constraint: - """ - Return self as the constraint instance. - - :param kwargs: Additional keyword arguments (unused) - :return: Self instance as the constraint - """ - self.current_index += 1 - - return cast("Constraint", self.model_copy()) - - def __call__( - self, state: SchedulerState, request_info: RequestInfo - ) -> SchedulerUpdateAction: - """ - Evaluate constraint against global error rate. 
- - :param state: Current scheduler state with global request and error counts - :param request_info: Individual request information (unused) - :return: Action indicating whether to continue or stop operations - """ - _ = request_info # Unused parameters - current_index = max(0, self.current_index) - max_error_rate = ( - self.max_error_rate - if isinstance(self.max_error_rate, int | float) - else self.max_error_rate[min(current_index, len(self.max_error_rate) - 1)] - ) - - exceeded_min_processed = ( - self.min_processed is None or state.processed_requests >= self.min_processed - ) - error_rate = ( - state.errored_requests / float(state.processed_requests) - if state.processed_requests > 0 - else 0.0 - ) - exceeded_error_rate = error_rate >= max_error_rate - exceeded = exceeded_min_processed and exceeded_error_rate - stop_time = None if not exceeded else request_info.completed_at or time.time() - - return SchedulerUpdateAction( - request_queuing="stop" if exceeded else "continue", - request_processing="stop_all" if exceeded else "continue", - metadata={ - "max_error_rate": max_error_rate, - "min_processed": self.min_processed, - "processed_requests": state.processed_requests, - "errored_requests": state.errored_requests, - "error_rate": error_rate, - "exceeded_min_processed": exceeded_min_processed, - "exceeded_error_rate": exceeded_error_rate, - "exceeded": exceeded, - "stop_time": stop_time, - }, - progress=SchedulerProgress(stop_time=stop_time), - ) - - @field_validator("max_error_rate") - @classmethod - def _validate_max_error_rate( - cls, value: int | float | list[int | float] - ) -> int | float | list[int | float]: - if not isinstance(value, list): - value = [value] - for val in value: - if not val: - raise ValueError( - "max_error_rate must be set and truthful, " - f"received {value} ({val} failed)" - ) - if not isinstance(val, int | float) or val <= 0 or val >= 1: - raise ValueError( - "max_error_rate must be a number between 0 and 1," - f"received {value} 
({val} failed)" - ) - - return value[0] if isinstance(value, list) and len(value) == 1 else value - - -class RequestsExhaustedConstraint(StandardBaseModel, InfoMixin): - type_: Literal["requests_exhausted"] = "requests_exhausted" # type: ignore[assignment] - num_requests: int - - @property - def info(self) -> dict[str, Any]: - """ - Extract serializable information from this constraint initializer. - - :return: Dictionary containing constraint configuration and metadata - """ - return self.model_dump() - - def __call__( - self, state: SchedulerState, request: RequestInfo - ) -> SchedulerUpdateAction: - _ = request # Unused parameter - create_exceeded = state.created_requests >= self.num_requests - processed_exceeded = state.processed_requests >= self.num_requests - remaining_requests = max(0, self.num_requests - state.processed_requests) - stop_time = ( - None if remaining_requests > 0 else request.completed_at or time.time() - ) - - return SchedulerUpdateAction( - request_queuing="stop" if create_exceeded else "continue", - request_processing="stop_local" if processed_exceeded else "continue", - metadata={ - "num_requests": self.num_requests, - "create_exceeded": create_exceeded, - "processed_exceeded": processed_exceeded, - "created_requests": state.created_requests, - "processed_requests": state.processed_requests, - "remaining_requests": remaining_requests, - "stop_time": stop_time, - }, - progress=SchedulerProgress( - remaining_requests=remaining_requests, - total_requests=self.num_requests, - stop_time=stop_time, - ), - ) diff --git a/src/guidellm/scheduler/constraints/__init__.py b/src/guidellm/scheduler/constraints/__init__.py new file mode 100644 index 000000000..1f5343a93 --- /dev/null +++ b/src/guidellm/scheduler/constraints/__init__.py @@ -0,0 +1,49 @@ +""" +Constraint system for scheduler behavior control and request processing limits. 
+ +Provides flexible constraints for managing scheduler behavior with configurable +thresholds based on time, error rates, and request counts. Constraints evaluate +scheduler state and individual requests to determine whether processing should +continue or stop based on predefined limits. The constraint system enables +sophisticated benchmark stopping criteria through composable constraint types. +""" + +from .constraint import ( + Constraint, + ConstraintInitializer, + PydanticConstraintInitializer, + SerializableConstraintInitializer, + UnserializableConstraintInitializer, +) +from .error import ( + MaxErrorRateConstraint, + MaxErrorsConstraint, + MaxGlobalErrorRateConstraint, +) +from .factory import ConstraintsInitializerFactory +from .request import ( + MaxDurationConstraint, + MaxNumberConstraint, + RequestsExhaustedConstraint, +) +from .saturation import ( + OverSaturationConstraint, + OverSaturationConstraintInitializer, +) + +__all__ = [ + "Constraint", + "ConstraintInitializer", + "ConstraintsInitializerFactory", + "MaxDurationConstraint", + "MaxErrorRateConstraint", + "MaxErrorsConstraint", + "MaxGlobalErrorRateConstraint", + "MaxNumberConstraint", + "OverSaturationConstraint", + "OverSaturationConstraintInitializer", + "PydanticConstraintInitializer", + "RequestsExhaustedConstraint", + "SerializableConstraintInitializer", + "UnserializableConstraintInitializer", +] diff --git a/src/guidellm/scheduler/constraints/constraint.py b/src/guidellm/scheduler/constraints/constraint.py new file mode 100644 index 000000000..dd901acfa --- /dev/null +++ b/src/guidellm/scheduler/constraints/constraint.py @@ -0,0 +1,325 @@ +""" +Core constraint system protocols and base classes. + +Defines the fundamental protocols and base classes that form the foundation of the +constraint system. Constraints control scheduler behavior by evaluating scheduler +state and individual requests to determine whether processing should continue or +stop based on predefined limits. 
The constraint system enables sophisticated +benchmark stopping criteria through composable constraint types with support for +serialization, validation, and dynamic instantiation. + +The module provides: +- Protocols defining the constraint interface contract + (Constraint, ConstraintInitializer) +- Base classes for Pydantic-based constraint initializers with serialization support +- Placeholder classes for handling unserializable constraint states + +Example: +:: + from guidellm.scheduler.constraints import ( + Constraint, + PydanticConstraintInitializer, + ) + + class MyConstraint(PydanticConstraintInitializer): + type_: str = "my_constraint" + + def create_constraint(self) -> Constraint: + def evaluate(state, request): + return SchedulerUpdateAction(request_queuing="continue") + return evaluate +""" + +from __future__ import annotations + +from abc import ABC, abstractmethod +from typing import Any, Literal, Protocol, runtime_checkable + +from pydantic import Field + +from guidellm.scheduler.schemas import SchedulerState, SchedulerUpdateAction +from guidellm.schemas import RequestInfo, StandardBaseModel +from guidellm.utils import InfoMixin + +__all__ = [ + "Constraint", + "ConstraintInitializer", + "PydanticConstraintInitializer", + "SerializableConstraintInitializer", + "UnserializableConstraintInitializer", +] + + +@runtime_checkable +class Constraint(Protocol): + """ + Protocol for constraint evaluation functions that control scheduler behavior. + + Defines the interface that all constraint implementations must follow. Constraints + are callable objects that evaluate scheduler state and request information to + determine whether processing should continue or stop. The protocol enables type + checking and runtime validation of constraint implementations while allowing + flexible implementation approaches (functions, classes, closures). 
+ + Example: + :: + def my_constraint( + state: SchedulerState, request: RequestInfo + ) -> SchedulerUpdateAction: + if state.processing_requests > 100: + return SchedulerUpdateAction(request_queuing="stop") + return SchedulerUpdateAction(request_queuing="continue") + """ + + def __call__( + self, state: SchedulerState, request: RequestInfo + ) -> SchedulerUpdateAction: + """ + Evaluate constraint against scheduler state and request information. + + :param state: Current scheduler state with metrics and timing information + :param request: Individual request information and metadata + :return: Action indicating whether to continue or stop scheduler operations + """ + + +@runtime_checkable +class ConstraintInitializer(Protocol): + """ + Protocol for constraint initializer factory functions that create constraints. + + Defines the interface for factory objects that create constraint instances from + configuration parameters. Constraint initializers enable dynamic constraint + creation and configuration, supporting both simple boolean flags and complex + parameter dictionaries. The protocol allows type checking while maintaining + flexibility for different initialization patterns. + + Example: + :: + class MaxRequestsInitializer: + def __init__(self, max_requests: int): + self.max_requests = max_requests + + def create_constraint(self) -> Constraint: + def evaluate(state, request): + if state.total_requests >= self.max_requests: + return SchedulerUpdateAction(request_queuing="stop") + return SchedulerUpdateAction(request_queuing="continue") + return evaluate + """ + + def create_constraint(self, **kwargs) -> Constraint: + """ + Create a constraint instance from configuration parameters. + + :param kwargs: Configuration parameters for constraint creation + :return: Configured constraint evaluation function + """ + + +@runtime_checkable +class SerializableConstraintInitializer(Protocol): + """ + Protocol for serializable constraint initializers supporting persistence. 
+ + Extends ConstraintInitializer with serialization capabilities, enabling constraint + configurations to be saved, loaded, and transmitted. Serializable initializers + support validation, model-based configuration, and dictionary-based serialization + for integration with configuration systems and persistence layers. + + Example: + :: + class SerializableInitializer: + @classmethod + def validated_kwargs(cls, **kwargs) -> dict[str, Any]: + return {"max_requests": kwargs.get("max_requests", 100)} + + @classmethod + def model_validate(cls, data: dict) -> ConstraintInitializer: + return cls(**cls.validated_kwargs(**data)) + + def model_dump(self) -> dict[str, Any]: + return {"type_": "max_requests", "max_requests": self.max_requests} + + def create_constraint(self) -> Constraint: + # ... create constraint + """ + + @classmethod + def validated_kwargs(cls, *args, **kwargs) -> dict[str, Any]: + """ + Validate and process arguments for constraint creation. + + :param args: Positional arguments for constraint configuration + :param kwargs: Keyword arguments for constraint configuration + :return: Validated parameter dictionary for constraint creation + """ + + @classmethod + def model_validate(cls, **kwargs) -> ConstraintInitializer: + """ + Create validated constraint initializer from configuration. + + :param kwargs: Configuration dictionary for initializer creation + :return: Validated constraint initializer instance + """ + + def model_dump(self) -> dict[str, Any]: + """ + Serialize constraint initializer to dictionary format. + + :return: Dictionary representation of constraint initializer + """ + + def create_constraint(self, **kwargs) -> Constraint: + """ + Create constraint instance from this initializer. + + :param kwargs: Additional configuration parameters + :return: Configured constraint evaluation function + """ + + +class PydanticConstraintInitializer(StandardBaseModel, ABC, InfoMixin): + """ + Abstract base for Pydantic-based constraint initializers. 
+ + Provides standardized serialization, validation, and metadata handling for + constraint initializers using Pydantic models. Subclasses implement specific + constraint creation logic while inheriting validation and persistence support. + Integrates with the constraint factory system for dynamic instantiation and + configuration management. + + Example: + :: + @ConstraintsInitializerFactory.register("max_duration") + class MaxDurationConstraintInitializer(PydanticConstraintInitializer): + type_: str = "max_duration" + max_seconds: float = Field(description="Maximum duration in seconds") + + def create_constraint(self) -> Constraint: + def evaluate(state, request): + if time.time() - state.start_time > self.max_seconds: + return SchedulerUpdateAction(request_queuing="stop") + return SchedulerUpdateAction(request_queuing="continue") + return evaluate + + :cvar type_: Type identifier for the constraint initializer + """ + + type_: str = Field(description="Type identifier for the constraint initializer") + + @property + def info(self) -> dict[str, Any]: + """ + Extract serializable information from this constraint initializer. + + :return: Dictionary containing constraint configuration and metadata + """ + return self.model_dump() + + @classmethod + @abstractmethod + def validated_kwargs(cls, *args, **kwargs) -> dict[str, Any]: + """ + Validate and process arguments for constraint creation. + + Must be implemented by subclasses to handle their specific parameter patterns + and validation requirements. This method processes raw input (booleans, dicts, + etc.) and converts them into validated parameter dictionaries suitable for + constraint initialization. + + :param args: Positional arguments passed to the constraint + :param kwargs: Keyword arguments passed to the constraint + :return: Validated dictionary of parameters for constraint creation + :raises NotImplementedError: Must be implemented by subclasses + """ + ... 
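The `validated_kwargs` contract described above (raw input in, validated parameter dictionary out) can be sketched without Pydantic. This is only an illustration under stated assumptions: `validated_duration_kwargs` is a hypothetical stand-alone function, the alias names are borrowed from the duration constraint elsewhere in this diff, and the positivity check is placed inline here although the diff performs it in a separate field validator.

```python
from typing import Any, Optional


def validated_duration_kwargs(
    max_duration: Optional[float] = None, **kwargs: Any
) -> dict[str, Any]:
    """Hypothetical stand-alone take on a validated_kwargs implementation."""
    # Aliases expressed in seconds: the first non-None value wins.
    for alias in ("max_dur", "max_sec", "max_seconds"):
        if max_duration is None:
            max_duration = kwargs.get(alias)
    # Aliases expressed in minutes are converted to seconds.
    for alias in ("max_min", "max_minutes"):
        minutes = kwargs.get(alias)
        if max_duration is None and minutes is not None:
            max_duration = minutes * 60
    # Assumption: validation is inlined here for brevity.
    if max_duration is None or max_duration <= 0:
        raise ValueError(
            f"max_duration must be a positive number, received {max_duration}"
        )
    return {
        "max_duration": max_duration,
        "current_index": kwargs.get("current_index", -1),
    }
```

An explicit `max_duration` takes precedence over every alias, so `validated_duration_kwargs(30, max_minutes=2)` keeps the 30-second value.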
+ + @abstractmethod + def create_constraint(self, **kwargs) -> Constraint: + """ + Create a constraint instance. + + Must be implemented by subclasses to return their specific constraint type + with appropriate configuration and validation. The returned constraint should + be ready for evaluation against scheduler state and requests. + + :param kwargs: Additional keyword arguments (usually unused) + :return: Configured constraint instance + :raises NotImplementedError: Must be implemented by subclasses + """ + ... + + +class UnserializableConstraintInitializer(PydanticConstraintInitializer): + """ + Placeholder for constraints that cannot be serialized or executed. + + Represents constraint initializers that failed serialization or contain + non-serializable components. Cannot be executed and raises errors when + invoked to prevent runtime failures from invalid constraint state. Used + by the factory system to preserve constraint information even when full + serialization is not possible. + + Example: + :: + # Created automatically by factory when serialization fails + unserializable = UnserializableConstraintInitializer( + orig_info={"type_": "custom", "data": non_serializable_object} + ) + + # Attempting to use it raises RuntimeError + constraint = unserializable.create_constraint() # Raises RuntimeError + + :cvar type_: Always "unserializable" to identify placeholder constraints + :cvar orig_info: Original constraint information before serialization failure + """ + + type_: Literal["unserializable"] = "unserializable" # type: ignore[assignment] + orig_info: dict[str, Any] = Field( + default_factory=dict, + description="Original constraint information before serialization failure", + ) + + @classmethod + def validated_kwargs( + cls, orig_info: dict[str, Any] | None = None, **_kwargs + ) -> dict[str, Any]: + """ + Validate arguments for unserializable constraint creation. 
+ + :param orig_info: Original constraint information before serialization failure + :param kwargs: Additional arguments (ignored) + :return: Validated parameters for unserializable constraint creation + """ + return {"orig_info": orig_info or {}} + + def create_constraint(self, **_kwargs) -> Constraint: + """ + Raise error for unserializable constraint creation attempt. + + :param kwargs: Additional keyword arguments (unused) + :raises RuntimeError: Always raised since unserializable constraints + cannot be executed + """ + raise RuntimeError( + "Cannot create constraint from unserializable constraint instance. " + "This constraint cannot be serialized and therefore cannot be executed." + ) + + def __call__( + self, state: SchedulerState, request: RequestInfo + ) -> SchedulerUpdateAction: + """ + Raise error since unserializable constraints cannot be invoked. + + :param state: Current scheduler state (unused) + :param request: Individual request information (unused) + :raises RuntimeError: Always raised for unserializable constraints + """ + _ = (state, request) # Unused parameters + raise RuntimeError( + "Cannot invoke unserializable constraint instance. " + "This constraint was not properly serialized and cannot be executed." + ) diff --git a/src/guidellm/scheduler/constraints/error.py b/src/guidellm/scheduler/constraints/error.py new file mode 100644 index 000000000..d9ed7ca95 --- /dev/null +++ b/src/guidellm/scheduler/constraints/error.py @@ -0,0 +1,411 @@ +""" +Error-based constraint implementations. + +Provides constraint types for limiting benchmark execution based on error rates +and error counts. These constraints monitor request error status to determine +when to stop benchmark execution due to excessive errors. 
+""" + +from __future__ import annotations + +import time +from typing import Any, Literal, cast + +from pydantic import Field, field_validator + +from guidellm.scheduler.schemas import ( + SchedulerProgress, + SchedulerState, + SchedulerUpdateAction, +) +from guidellm.schemas import RequestInfo +from guidellm.settings import settings + +from .constraint import Constraint, PydanticConstraintInitializer +from .factory import ConstraintsInitializerFactory + +__all__ = [ + "MaxErrorRateConstraint", + "MaxErrorsConstraint", + "MaxGlobalErrorRateConstraint", +] + + +@ConstraintsInitializerFactory.register( + ["max_errors", "max_err", "max_error", "max_errs"] +) +class MaxErrorsConstraint(PydanticConstraintInitializer): + """ + Constraint that limits execution based on absolute error count. + + Stops both request queuing and all request processing when the total number + of errored requests reaches the maximum threshold. Uses global error tracking + across all requests for immediate constraint evaluation. + """ + + type_: Literal["max_errors"] = "max_errors" # type: ignore[assignment] + max_errors: int | float | list[int | float] = Field( + description="Maximum number of errors allowed before triggering constraint", + ) + current_index: int = Field(default=-1, description="Current index in error list") + + @classmethod + def validated_kwargs( + cls, max_errors: int | float | list[int | float] | None = None, **kwargs + ) -> dict[str, Any]: + """ + Validate and process arguments for MaxErrorsConstraint creation. 
+ + :param max_errors: Maximum number of errors to allow + :param kwargs: Supports max_errors, max_err, max_error, max_errs, + and optional type_ + :return: Validated dictionary with max_errors and type_ fields + """ + aliases = ["max_errors", "max_err", "max_error", "max_errs"] + for alias in aliases: + if max_errors is None: + max_errors = kwargs.get(alias) + + return { + "max_errors": max_errors, + "current_index": kwargs.get("current_index", -1), + } + + def create_constraint(self, **_kwargs) -> Constraint: + """ + Return self as the constraint instance. + + :param kwargs: Additional keyword arguments (unused) + :return: Self instance as the constraint + """ + self.current_index += 1 + + return cast("Constraint", self.model_copy()) + + def __call__( + self, state: SchedulerState, request_info: RequestInfo + ) -> SchedulerUpdateAction: + """ + Evaluate constraint against current error count. + + :param state: Current scheduler state with error counts + :param request_info: Individual request information (unused) + :return: Action indicating whether to continue or stop operations + """ + _ = request_info # Unused parameters + current_index = max(0, self.current_index) + max_errors = ( + self.max_errors + if isinstance(self.max_errors, int | float) + else self.max_errors[min(current_index, len(self.max_errors) - 1)] + ) + errors_exceeded = state.errored_requests >= max_errors + stop_time = ( + None if not errors_exceeded else request_info.completed_at or time.time() + ) + + return SchedulerUpdateAction( + request_queuing="stop" if errors_exceeded else "continue", + request_processing="stop_all" if errors_exceeded else "continue", + metadata={ + "max_errors": max_errors, + "errors_exceeded": errors_exceeded, + "current_errors": state.errored_requests, + "stop_time": stop_time, + }, + progress=SchedulerProgress(stop_time=stop_time), + ) + + @field_validator("max_errors") + @classmethod + def _validate_max_errors( + cls, value: int | float | list[int | float] + ) -> 
int | float | list[int | float]: + if not isinstance(value, list): + value = [value] + for val in value: + if not val: + raise ValueError( + "max_errors must be set and truthful, " + f"received {value} ({val} failed)" + ) + if not isinstance(val, int | float) or val <= 0: + raise ValueError( + f"max_errors must be a positive num, received {value} ({val} failed)" + ) + + return value[0] if isinstance(value, list) and len(value) == 1 else value + + +@ConstraintsInitializerFactory.register( + ["max_error_rate", "max_err_rate", "max_errors_rate"] +) +class MaxErrorRateConstraint(PydanticConstraintInitializer): + """ + Constraint that limits execution based on sliding window error rate. + + Tracks error status of recent requests in a sliding window and stops all + processing when the error rate exceeds the threshold. Only applies the + constraint after processing enough requests to fill the minimum window size + for statistical significance. + """ + + type_: Literal["max_error_rate"] = "max_error_rate" # type: ignore[assignment] + max_error_rate: int | float | list[int | float] = Field( + description="Maximum error rate allowed (0.0, 1.0)" + ) + window_size: int | float = Field( + default=30, + gt=0, + description="Size of sliding window for calculating error rate", + ) + error_window: list[bool] = Field( + default_factory=list, + description="Sliding window tracking error status of recent requests", + ) + current_index: int = Field( + default=-1, description="Current index for list-based max_error_rate values" + ) + + @classmethod + def validated_kwargs( + cls, max_error_rate: int | float | list[int | float] | None = None, **kwargs + ) -> dict[str, Any]: + """ + Validate and process arguments for MaxErrorRateConstraint creation. 
+ + :param max_error_rate: Maximum error rate to allow + :param kwargs: Supports max_error_rate, max_err_rate, max_errors_rate, + optional window_size, and optional type_ + :return: Validated dictionary with max_error_rate, window_size, + and type_ fields + """ + aliases = ["max_error_rate", "max_err_rate", "max_errors_rate"] + for alias in aliases: + if max_error_rate is None: + max_error_rate = kwargs.get(alias) + + return { + "max_error_rate": max_error_rate, + "window_size": kwargs.get( + "window_size", settings.constraint_error_window_size + ), + "error_window": kwargs.get("error_window", []), + "current_index": kwargs.get("current_index", -1), + } + + def create_constraint(self, **_kwargs) -> Constraint: + """ + Create a new instance of MaxErrorRateConstraint (due to stateful window). + + :param kwargs: Additional keyword arguments (unused) + :return: New instance of the constraint + """ + self.current_index += 1 + + return cast("Constraint", self.model_copy()) + + def __call__( + self, state: SchedulerState, request_info: RequestInfo + ) -> SchedulerUpdateAction: + """ + Evaluate constraint against sliding window error rate. 
+ + :param state: Current scheduler state with request counts + :param request_info: Individual request with completion status + :return: Action indicating whether to continue or stop operations + """ + current_index = max(0, self.current_index) + max_error_rate = ( + self.max_error_rate + if isinstance(self.max_error_rate, int | float) + else self.max_error_rate[min(current_index, len(self.max_error_rate) - 1)] + ) + + if request_info.status in ["completed", "errored", "cancelled"]: + self.error_window.append(request_info.status == "errored") + if len(self.error_window) > self.window_size: + self.error_window.pop(0) + + error_count = sum(self.error_window) + window_requests = len(self.error_window) + error_rate = ( + error_count / float(window_requests) if window_requests > 0 else 0.0 + ) + exceeded_min_processed = state.processed_requests >= self.window_size + exceeded_error_rate = error_rate >= max_error_rate + exceeded = exceeded_min_processed and exceeded_error_rate + stop_time = None if not exceeded else request_info.completed_at or time.time() + + return SchedulerUpdateAction( + request_queuing="stop" if exceeded else "continue", + request_processing="stop_all" if exceeded else "continue", + metadata={ + "max_error_rate": max_error_rate, + "window_size": self.window_size, + "error_count": error_count, + "processed_count": state.processed_requests, + "current_window_size": len(self.error_window), + "current_error_rate": error_rate, + "exceeded_min_processed": exceeded_min_processed, + "exceeded_error_rate": exceeded_error_rate, + "exceeded": exceeded, + "stop_time": stop_time, + }, + ) + + @field_validator("max_error_rate") + @classmethod + def _validate_max_error_rate( + cls, value: int | float | list[int | float] + ) -> int | float | list[int | float]: + if not isinstance(value, list): + value = [value] + for val in value: + if not val: + raise ValueError( + "max_error_rate must be set and truthful, " + f"received {value} ({val} failed)" + ) + if not 
isinstance(val, int | float) or val <= 0 or val >= 1: + raise ValueError( + "max_error_rate must be a number between 0 and 1, " + f"received {value} ({val} failed)" + ) + + return value[0] if isinstance(value, list) and len(value) == 1 else value + + +@ConstraintsInitializerFactory.register( + ["max_global_error_rate", "max_global_err_rate", "max_global_errors_rate"] +) +class MaxGlobalErrorRateConstraint(PydanticConstraintInitializer): + """ + Constraint that limits execution based on global error rate. + + Calculates error rate across all processed requests and stops all processing + when the rate exceeds the threshold. Only applies the constraint after + processing the minimum number of requests to ensure statistical significance + for global error rate calculations. + """ + + type_: Literal["max_global_error_rate"] = "max_global_error_rate" # type: ignore[assignment] + max_error_rate: int | float | list[int | float] = Field( + description="Maximum error rate allowed (0.0 to 1.0)" + ) + min_processed: int | float | None = Field( + default=30, + gt=0, + description="Minimum requests processed before applying error rate constraint", + ) + current_index: int = Field( + default=-1, description="Current index for list-based max_error_rate values" + ) + + @classmethod + def validated_kwargs( + cls, max_error_rate: int | float | list[int | float] | None = None, **kwargs + ) -> dict[str, Any]: + """ + Validate and process arguments for MaxGlobalErrorRateConstraint creation. 
+ + :param max_error_rate: Maximum error rate to allow + :param kwargs: Supports max_global_error_rate, max_global_err_rate, + max_global_errors_rate, optional min_processed, and optional type_ + :return: Validated dictionary with max_error_rate, min_processed, + and type_ fields + """ + for alias in [ + "max_global_error_rate", + "max_global_err_rate", + "max_global_errors_rate", + ]: + if max_error_rate is None: + max_error_rate = kwargs.get(alias) + + return { + "max_error_rate": max_error_rate, + "min_processed": kwargs.get( + "min_processed", settings.constraint_error_min_processed + ), + "current_index": kwargs.get("current_index", -1), + } + + def create_constraint(self, **_kwargs) -> Constraint: + """ + Return self as the constraint instance. + + :param kwargs: Additional keyword arguments (unused) + :return: Self instance as the constraint + """ + self.current_index += 1 + + return cast("Constraint", self.model_copy()) + + def __call__( + self, state: SchedulerState, request_info: RequestInfo + ) -> SchedulerUpdateAction: + """ + Evaluate constraint against global error rate. 
+ + :param state: Current scheduler state with global request and error counts + :param request_info: Individual request information (completion time is + used as the stop time when the constraint triggers) + :return: Action indicating whether to continue or stop operations + """ + current_index = max(0, self.current_index) + max_error_rate = ( + self.max_error_rate + if isinstance(self.max_error_rate, int | float) + else self.max_error_rate[min(current_index, len(self.max_error_rate) - 1)] + ) + + exceeded_min_processed = ( + self.min_processed is None or state.processed_requests >= self.min_processed + ) + error_rate = ( + state.errored_requests / float(state.processed_requests) + if state.processed_requests > 0 + else 0.0 + ) + exceeded_error_rate = error_rate >= max_error_rate + exceeded = exceeded_min_processed and exceeded_error_rate + stop_time = None if not exceeded else request_info.completed_at or time.time() + + return SchedulerUpdateAction( + request_queuing="stop" if exceeded else "continue", + request_processing="stop_all" if exceeded else "continue", + metadata={ + "max_error_rate": max_error_rate, + "min_processed": self.min_processed, + "processed_requests": state.processed_requests, + "errored_requests": state.errored_requests, + "error_rate": error_rate, + "exceeded_min_processed": exceeded_min_processed, + "exceeded_error_rate": exceeded_error_rate, + "exceeded": exceeded, + "stop_time": stop_time, + }, + progress=SchedulerProgress(stop_time=stop_time), + ) + + @field_validator("max_error_rate") + @classmethod + def _validate_max_error_rate( + cls, value: int | float | list[int | float] + ) -> int | float | list[int | float]: + if not isinstance(value, list): + value = [value] + for val in value: + if not val: + raise ValueError( + "max_error_rate must be set and truthful, " + f"received {value} ({val} failed)" + ) + if not isinstance(val, int | float) or val <= 0 or val >= 1: + raise ValueError( + "max_error_rate must be a number between 0 and 1, " + f"received {value} 
({val} failed)" + ) + + return value[0] if isinstance(value, list) and len(value) == 1 else value diff --git a/src/guidellm/scheduler/constraints/factory.py b/src/guidellm/scheduler/constraints/factory.py new file mode 100644 index 000000000..89627f9cc --- /dev/null +++ b/src/guidellm/scheduler/constraints/factory.py @@ -0,0 +1,182 @@ +""" +Factory for creating and managing constraint initializers. + +Provides centralized access to registered constraint types with support for +creating constraints from configuration dictionaries, simple values, or +pre-configured instances. +""" + +from __future__ import annotations + +from typing import Any + +from guidellm.scheduler.constraints.constraint import ( + Constraint, + ConstraintInitializer, + SerializableConstraintInitializer, + UnserializableConstraintInitializer, +) +from guidellm.utils import InfoMixin, RegistryMixin + +__all__ = ["ConstraintsInitializerFactory"] + + +class ConstraintsInitializerFactory(RegistryMixin[ConstraintInitializer]): + """ + Registry factory for creating and managing constraint initializers. + + Provides centralized access to registered constraint types with support for + creating constraints from configuration dictionaries, simple values, or + pre-configured instances. Handles constraint resolution and type validation + for the scheduler constraint system. + + Example: + :: + from guidellm.scheduler import ConstraintsInitializerFactory + + # Register new constraint type + @ConstraintsInitializerFactory.register("new_constraint") + class NewConstraint: + def create_constraint(self, **kwargs) -> Constraint: + return lambda state, request: SchedulerUpdateAction() + + # Create and use constraint + constraint = ConstraintsInitializerFactory.create_constraint("new_constraint") + """ + + @classmethod + def create(cls, key: str, *args, **kwargs) -> ConstraintInitializer: + """ + Create a constraint initializer for the specified key. 
+ + :param key: Registered constraint initializer key + :param args: Positional arguments for initializer creation + :param kwargs: Keyword arguments for initializer creation + :return: Configured constraint initializer instance + :raises ValueError: If the key is not registered in the factory + """ + if cls.registry is None or key not in cls.registry: + raise ValueError(f"Unknown constraint initializer key: {key}") + + initializer_class = cls.registry[key] + + return ( + initializer_class(*args, **kwargs) # type: ignore[operator] + if not isinstance(initializer_class, type) + or not issubclass(initializer_class, SerializableConstraintInitializer) + else initializer_class( + **initializer_class.validated_kwargs(*args, **kwargs) # type: ignore[misc] + ) + ) + + @classmethod + def serialize(cls, initializer: ConstraintInitializer) -> dict[str, Any]: + """ + Serialize constraint initializer to dictionary format. + + :param initializer: Constraint initializer to serialize + :return: Dictionary representation or unserializable placeholder + """ + if isinstance(initializer, SerializableConstraintInitializer): + return initializer.model_dump() + else: + unserializable = UnserializableConstraintInitializer( + orig_info=InfoMixin.extract_from_obj(initializer) + ) + return unserializable.model_dump() + + @classmethod + def deserialize( + cls, initializer_dict: dict[str, Any] + ) -> SerializableConstraintInitializer | UnserializableConstraintInitializer: + """ + Deserialize constraint initializer from dictionary format. 
+ + :param initializer_dict: Dictionary representation of constraint initializer + :return: Reconstructed constraint initializer instance + :raises ValueError: If constraint type is unknown or cannot be deserialized + """ + if initializer_dict.get("type_") == "unserializable": + return UnserializableConstraintInitializer.model_validate(initializer_dict) + + if ( + cls.registry is not None + and initializer_dict.get("type_") + and initializer_dict["type_"] in cls.registry + ): + initializer_class = cls.registry[initializer_dict["type_"]] + if hasattr(initializer_class, "model_validate"): + return initializer_class.model_validate(initializer_dict) # type: ignore[return-value] + else: + return initializer_class(**initializer_dict) # type: ignore[return-value,operator] + + raise ValueError( + f"Cannot deserialize unknown constraint initializer: " + f"{initializer_dict.get('type_', 'unknown')}" + ) + + @classmethod + def create_constraint(cls, key: str, *args, **kwargs) -> Constraint: + """ + Create a constraint instance for the specified key. + + :param key: Registered constraint initializer key + :param args: Positional arguments for constraint creation + :param kwargs: Keyword arguments for constraint creation + :return: Configured constraint function ready for evaluation + :raises ValueError: If the key is not registered in the factory + """ + return cls.create(key, *args, **kwargs).create_constraint() + + @classmethod + def resolve( + cls, + initializers: dict[ + str, + Any | dict[str, Any] | Constraint | ConstraintInitializer, + ], + ) -> dict[str, Constraint]: + """ + Resolve mixed constraint specifications to callable constraints. 
+ + :param initializers: Dictionary mapping constraint keys to specifications + :return: Dictionary mapping constraint keys to callable functions + :raises ValueError: If any key is not registered in the factory + """ + constraints = {} + + for key, val in initializers.items(): + if isinstance(val, Constraint): + constraints[key] = val + elif isinstance(val, ConstraintInitializer): + constraints[key] = val.create_constraint() + elif isinstance(val, dict): + constraints[key] = cls.create_constraint(key, **val) + else: + constraints[key] = cls.create_constraint(key, val) + + return constraints + + @classmethod + def resolve_constraints( + cls, + constraints: dict[str, Any | dict[str, Any] | Constraint], + ) -> dict[str, Constraint]: + """ + Resolve constraints from mixed constraint specifications. + + :param constraints: Dictionary mapping constraint keys to specifications + :return: Dictionary mapping constraint keys to callable functions + :raises ValueError: If any constraint key is not registered + """ + resolved_constraints = {} + + for key, val in constraints.items(): + if isinstance(val, Constraint): + resolved_constraints[key] = val + elif isinstance(val, dict): + resolved_constraints[key] = cls.create_constraint(key, **val) + else: + resolved_constraints[key] = cls.create_constraint(key, val) + + return resolved_constraints diff --git a/src/guidellm/scheduler/constraints/request.py b/src/guidellm/scheduler/constraints/request.py new file mode 100644 index 000000000..764087673 --- /dev/null +++ b/src/guidellm/scheduler/constraints/request.py @@ -0,0 +1,311 @@ +""" +Request-based constraint implementations. + +Provides constraint types for limiting benchmark execution based on request counts +and time duration. These constraints monitor request creation, processing, and +elapsed time to determine when to stop benchmark execution. 
+""" + +from __future__ import annotations + +import time +from typing import Any, Literal, cast + +from pydantic import Field, field_validator + +from guidellm.scheduler.constraints.constraint import ( + Constraint, + PydanticConstraintInitializer, +) +from guidellm.scheduler.constraints.factory import ConstraintsInitializerFactory +from guidellm.scheduler.schemas import ( + SchedulerProgress, + SchedulerState, + SchedulerUpdateAction, +) +from guidellm.schemas import RequestInfo, StandardBaseModel +from guidellm.utils import InfoMixin + +__all__ = [ + "MaxDurationConstraint", + "MaxNumberConstraint", + "RequestsExhaustedConstraint", +] + + +@ConstraintsInitializerFactory.register( # type: ignore[arg-type] + ["max_number", "max_num", "max_requests", "max_req"] +) +class MaxNumberConstraint(PydanticConstraintInitializer): + """ + Constraint that limits execution based on maximum request counts. + + Stops request queuing when created requests reach the limit and stops local + request processing when processed requests reach the limit. Provides progress + tracking based on remaining requests and completion fraction. + """ + + type_: Literal["max_number"] = "max_number" # type: ignore[assignment] + max_num: int | float | list[int | float] = Field( + description="Maximum number of requests allowed before triggering constraint", + ) + current_index: int = Field( + default=-1, description="Current index for list-based max_num values" + ) + + @classmethod + def validated_kwargs( + cls, max_num: int | float | list[int | float], **kwargs + ) -> dict[str, Any]: + """ + Validate and process arguments for MaxNumberConstraint creation. 
+ + :param max_num: Maximum number of requests to allow + :param kwargs: Supports max_num, max_number, max_requests, max_req, + and optional type_ + :return: Validated dictionary with max_num and type_ fields + """ + aliases = ["max_number", "max_num", "max_requests", "max_req"] + for alias in aliases: + if max_num is None: + max_num = kwargs.get(alias) + + return {"max_num": max_num, "current_index": kwargs.get("current_index", -1)} + + def create_constraint(self, **_kwargs) -> Constraint: + """ + Return self as the constraint instance. + + :param kwargs: Additional keyword arguments (unused) + :return: Self instance as the constraint + """ + self.current_index += 1 + + return cast("Constraint", self.model_copy()) + + def __call__( + self, state: SchedulerState, request_info: RequestInfo + ) -> SchedulerUpdateAction: + """ + Evaluate constraint against current scheduler state and request count. + + :param state: Current scheduler state with request counts + :param request_info: Individual request information (unused) + :return: Action indicating whether to continue or stop operations + """ + _ = request_info # Unused parameters + current_index = max(0, self.current_index) + max_num = ( + self.max_num + if isinstance(self.max_num, int | float) + else self.max_num[min(current_index, len(self.max_num) - 1)] + ) + + create_exceeded = state.created_requests >= max_num + processed_exceeded = state.processed_requests >= max_num + remaining_requests = min(max(0, max_num - state.processed_requests), max_num) + stop_time = ( + None if remaining_requests > 0 else request_info.completed_at or time.time() + ) + + return SchedulerUpdateAction( + request_queuing="stop" if create_exceeded else "continue", + request_processing="stop_local" if processed_exceeded else "continue", + metadata={ + "max_number": max_num, + "create_exceeded": create_exceeded, + "processed_exceeded": processed_exceeded, + "created_requests": state.created_requests, + "processed_requests": 
state.processed_requests, + "remaining_requests": remaining_requests, + "stop_time": stop_time, + }, + progress=SchedulerProgress( + remaining_requests=remaining_requests, + total_requests=max_num, + stop_time=stop_time, + ), + ) + + @field_validator("max_num") + @classmethod + def _validate_max_num( + cls, value: int | float | list[int | float] + ) -> int | float | list[int | float]: + if not isinstance(value, list): + value = [value] + for val in value: + if not val: + raise ValueError( + f"max_num must be set and truthful, received {value} ({val} failed)" + ) + if not isinstance(val, int | float) or val <= 0: + raise ValueError( + f"max_num must be a positive num, received {value} ({val} failed)" + ) + + return value[0] if isinstance(value, list) and len(value) == 1 else value + + +@ConstraintsInitializerFactory.register( + ["max_duration", "max_dur", "max_sec", "max_seconds", "max_min", "max_minutes"] +) +class MaxDurationConstraint(PydanticConstraintInitializer): + """ + Constraint that limits execution based on maximum time duration. + + Stops both request queuing and processing when the elapsed time since scheduler + start exceeds the maximum duration. Provides progress tracking based on + remaining time and completion fraction. + """ + + type_: Literal["max_duration"] = "max_duration" # type: ignore[assignment] + max_duration: int | float | list[int | float] = Field( + description="Maximum duration in seconds before triggering constraint" + ) + current_index: int = Field(default=-1, description="Current index in duration list") + + @classmethod + def validated_kwargs( + cls, max_duration: int | float | list[int | float] | None = None, **kwargs + ) -> dict[str, Any]: + """ + Validate and process arguments for MaxDurationConstraint creation. 
+ + :param max_duration: Maximum duration in seconds + :param kwargs: Supports max_duration, max_dur, max_sec, max_seconds, + max_min, max_minutes, and optional type_ + :return: Validated dictionary with max_duration and type_ fields + """ + seconds_aliases = ["max_dur", "max_sec", "max_seconds"] + for alias in seconds_aliases: + if max_duration is None: + max_duration = kwargs.get(alias) + minutes_aliases = ["max_min", "max_minutes"] + for alias in minutes_aliases: + minutes = kwargs.get(alias) + if minutes is not None and max_duration is None: + max_duration = minutes * 60 + + return { + "max_duration": max_duration, + "current_index": kwargs.get("current_index", -1), + } + + def create_constraint(self, **_kwargs) -> Constraint: + """ + Return self as the constraint instance. + + :param kwargs: Additional keyword arguments (unused) + :return: Self instance as the constraint + """ + self.current_index += 1 + + return cast("Constraint", self.model_copy()) + + def __call__( + self, state: SchedulerState, request_info: RequestInfo + ) -> SchedulerUpdateAction: + """ + Evaluate constraint against current scheduler state and elapsed time. 
+ + :param state: Current scheduler state with start time + :param request_info: Individual request information (unused) + :return: Action indicating whether to continue or stop operations + """ + _ = request_info # Unused parameters + current_index = max(0, self.current_index) + max_duration = ( + self.max_duration + if isinstance(self.max_duration, int | float) + else self.max_duration[min(current_index, len(self.max_duration) - 1)] + ) + + current_time = time.time() + elapsed = current_time - state.start_time + duration_exceeded = elapsed >= max_duration + remaining_duration = min(max(0.0, max_duration - elapsed), max_duration) + stop_time = None if not duration_exceeded else state.start_time + max_duration + + return SchedulerUpdateAction( + request_queuing="stop" if duration_exceeded else "continue", + request_processing="stop_local" if duration_exceeded else "continue", + metadata={ + "max_duration": max_duration, + "elapsed_time": elapsed, + "duration_exceeded": duration_exceeded, + "start_time": state.start_time, + "current_time": current_time, + "stop_time": stop_time, + }, + progress=SchedulerProgress( + remaining_duration=remaining_duration, + total_duration=max_duration, + stop_time=stop_time, + ), + ) + + @field_validator("max_duration") + @classmethod + def _validate_max_duration( + cls, value: int | float | list[int | float] + ) -> int | float | list[int | float]: + if not isinstance(value, list): + value = [value] + for val in value: + if not val: + raise ValueError( + "max_duration must be set and truthful, " + f"received {value} ({val} failed)" + ) + if not isinstance(val, int | float) or val <= 0: + raise ValueError( + "max_duration must be a positive num, " + f"received {value} ({val} failed)" + ) + + return value[0] if isinstance(value, list) and len(value) == 1 else value + + +class RequestsExhaustedConstraint(StandardBaseModel, InfoMixin): + type_: Literal["requests_exhausted"] = "requests_exhausted" # type: ignore[assignment] + num_requests: 
int + + @property + def info(self) -> dict[str, Any]: + """ + Extract serializable information from this constraint initializer. + + :return: Dictionary containing constraint configuration and metadata + """ + return self.model_dump() + + def __call__( + self, state: SchedulerState, request: RequestInfo + ) -> SchedulerUpdateAction: + _ = request # Unused parameter + create_exceeded = state.created_requests >= self.num_requests + processed_exceeded = state.processed_requests >= self.num_requests + remaining_requests = max(0, self.num_requests - state.processed_requests) + stop_time = ( + None if remaining_requests > 0 else request.completed_at or time.time() + ) + + return SchedulerUpdateAction( + request_queuing="stop" if create_exceeded else "continue", + request_processing="stop_local" if processed_exceeded else "continue", + metadata={ + "num_requests": self.num_requests, + "create_exceeded": create_exceeded, + "processed_exceeded": processed_exceeded, + "created_requests": state.created_requests, + "processed_requests": state.processed_requests, + "remaining_requests": remaining_requests, + "stop_time": stop_time, + }, + progress=SchedulerProgress( + remaining_requests=remaining_requests, + total_requests=self.num_requests, + stop_time=stop_time, + ), + ) diff --git a/src/guidellm/scheduler/constraints/saturation.py b/src/guidellm/scheduler/constraints/saturation.py new file mode 100644 index 000000000..b02013ed6 --- /dev/null +++ b/src/guidellm/scheduler/constraints/saturation.py @@ -0,0 +1,722 @@ +""" +Over-saturation detection constraint implementation. + +This module implements the Over-Saturation Detection (OSD) algorithm for detecting +when a model becomes over-saturated during benchmarking. Over-saturation occurs when +the response rate doesn't keep up with the request rate, leading to degraded +performance. + +Algorithm Overview: +------------------- +The OSD algorithm uses statistical slope detection to identify over-saturation: + +1. 
**Slope Detection**: The algorithm tracks two key metrics over time: + - Concurrent requests: Number of requests being processed simultaneously + - Time-to-first-token (TTFT): Latency for the first token of each response + +2. **Statistical Analysis**: For each metric, the algorithm: + - Maintains a sliding window of recent data points + - Calculates the linear regression slope using online statistics + - Computes the margin of error (MOE) using t-distribution confidence intervals + - Detects positive slopes with low MOE, indicating degradation + +3. **Detection Criteria**: Over-saturation is detected when: + - Both concurrent requests and TTFT show statistically significant positive slopes + - The minimum duration threshold has been met + - Sufficient data points are available for reliable slope estimation + +4. **Window Management**: The algorithm maintains bounded memory by: + - Limiting window size by time (maximum_window_seconds) + - Limiting window size by ratio of total requests (maximum_window_ratio) + - Automatically pruning old data points + +5. 
**Constraint Integration**: When over-saturation is detected, the constraint: + - Stops request queuing to prevent further degradation + - Stops processing of existing requests (if enabled) + - Provides detailed metadata about detection state + +Key Parameters: +--------------- +- minimum_duration: Minimum seconds before checking for over-saturation (default: 30.0) +- minimum_ttft: Minimum TTFT threshold for violation counting (default: 2.5) +- maximum_window_seconds: Maximum time window for data retention (default: 120.0) +- moe_threshold: Margin of error threshold for slope detection (default: 2.0) +- maximum_window_ratio: Maximum window size as ratio of total requests (default: 0.75) +- minimum_window_size: Minimum data points required for slope estimation (default: 5) +- confidence: Statistical confidence level for t-distribution (default: 0.95) + +The constraint integrates with the scheduler by evaluating each request update and +providing scheduler actions (continue/stop) based on the current over-saturation state. +""" + +from __future__ import annotations + +import math +import time +from typing import Any, Literal + +from pydantic import Field + +from guidellm.scheduler.constraints.constraint import ( + Constraint, + PydanticConstraintInitializer, +) +from guidellm.scheduler.constraints.factory import ConstraintsInitializerFactory +from guidellm.scheduler.schemas import ( + SchedulerState, + SchedulerUpdateAction, +) +from guidellm.schemas import RequestInfo + +__all__ = [ + "OverSaturationConstraint", + "OverSaturationConstraintInitializer", + "SlopeChecker", + "approx_t_ppf", +] + + +def approx_t_ppf(p: float, df: float) -> float: + """ + Approximate the percent point function (PPF) for the t-distribution. + + Provides a fast approximation of the t-distribution PPF using numerical + methods from Abramowitz & Stegun. 
This function is significantly faster + than scipy.stats.t.ppf while providing sufficient accuracy for statistical + slope detection in over-saturation detection. Used internally by SlopeChecker + for calculating confidence intervals and margin of error. + + Reference: + Milton Abramowitz and Irene A. Stegun (Eds.). (1965). + Handbook of Mathematical Functions: with Formulas, Graphs, + and Mathematical Tables. Dover Publications. + + An electronic version of this book is available at: + https://personal.math.ubc.ca/~cbm/aands/. + + :param p: The probability value (e.g., 0.975 for a 95% confidence interval) + :param df: The degrees of freedom for the t-distribution + :return: Approximate t-distribution PPF value, or NaN if df <= 0 + """ + dof = df + if dof <= 0: + return float("nan") + + # 1. Approximate the PPF of the Normal distribution (z-score) + # Uses Abramowitz & Stegun formula 26.2.23. + c = [2.515517, 0.802853, 0.010328] + d = [1.432788, 0.189269, 0.001308] + + numerical_stability_threshold = 0.5 + if p < numerical_stability_threshold: + t = math.sqrt(-2.0 * math.log(p)) + z = -( + t + - ((c[2] * t + c[1]) * t + c[0]) + / (((d[2] * t + d[1]) * t + d[0]) * t + 1.0) + ) + else: + t = math.sqrt(-2.0 * math.log(1.0 - p)) + z = t - ((c[2] * t + c[1]) * t + c[0]) / ( + ((d[2] * t + d[1]) * t + d[0]) * t + 1.0 + ) + + # 2. Convert the z-score to a t-score + # Uses the Cornish-Fisher expansion (first two correction terms): + # t ~ z + (z^3 + z) / (4 * dof) + (5z^5 + 16z^3 + 3z) / (96 * dof^2) + z2 = z * z + z3 = z2 * z + z5 = z3 * z2 + + g1 = (z3 + z) / 4.0 + g2 = (5.0 * z5 + 16.0 * z3 + 3.0 * z) / 96.0 + + # Adjust z using the degrees of freedom (dof) + return z + g1 / dof + g2 / (dof * dof) + + +class SlopeChecker: + """ + Helper class for online slope detection using linear regression. + + Maintains running statistics for efficient O(1) updates and provides + statistical slope detection with margin of error calculation. 
Uses online + algorithms to compute linear regression statistics incrementally without + storing all data points, enabling memory-efficient slope detection for + over-saturation detection. Supports adding and removing data points + dynamically while maintaining accurate statistical measures. + + Example: + :: + checker = SlopeChecker(moe_threshold=2.0, confidence=0.95) + checker.add_data_point(1.0, 2.0) + checker.add_data_point(2.0, 3.0) + checker.add_data_point(3.0, 4.0) + is_positive = checker.check_slope(3.0) # True for positive slope + """ + + def __init__( + self, moe_threshold: float = 1.0, confidence: float = 0.95, eps: float = 1e-12 + ) -> None: + """ + Initialize slope checker with statistical parameters. + + :param moe_threshold: Maximum margin of error threshold for slope detection + :param confidence: Statistical confidence level for t-distribution (0-1) + :param eps: Epsilon value for numerical stability in calculations + """ + self.n = 0 + self.sum_x = 0.0 + self.sum_y = 0.0 + self.sum_xy = 0.0 + self.sum_x2 = 0.0 + self.sum_y2 = 0.0 + self.moe_threshold = moe_threshold + self.eps = eps + self.confidence = confidence + self.slope: float | None = None + self.margin_of_error: float | None = None + + def add_data_point(self, x_new: float, y_new: float) -> None: + """ + Integrate a new data point into the accumulated statistics. + + Updates running sums for linear regression calculation in O(1) time. + The data point is incorporated into the statistical model without + storing the individual value, enabling memory-efficient slope detection. 
+ + :param x_new: The new x-coordinate (typically time or duration) + :param y_new: The new y-coordinate (typically metric value like TTFT + or concurrent requests) + """ + self.n += 1 + self.sum_x += x_new + self.sum_y += y_new + self.sum_xy += x_new * y_new + self.sum_x2 += x_new**2 + self.sum_y2 += y_new**2 + + def remove_data_point(self, x_old: float, y_old: float) -> None: + """ + Remove a data point from the accumulated statistics. + + Updates running sums by subtracting the specified data point in O(1) time. + Used for window management when pruning old data points to maintain + bounded memory usage while preserving statistical accuracy. + + :param x_old: The x-coordinate to remove (typically time or duration) + :param y_old: The y-coordinate to remove (typically metric value) + """ + self.n -= 1 + self.sum_x -= x_old + self.sum_y -= y_old + self.sum_xy -= x_old * y_old + self.sum_x2 -= x_old**2 + self.sum_y2 -= y_old**2 + + def check_slope(self, effective_n: float) -> bool: + """ + Check if there is a statistically significant positive slope. + + Calculates linear regression slope and margin of error using online + statistics. Returns True if the slope is positive and the margin of + error is below the threshold, indicating statistically significant + degradation. Updates internal slope and margin_of_error attributes + for external inspection. + + :param effective_n: Effective sample size for slope estimation (may differ + from actual n for correlation adjustment) + :return: True if positive slope detected with margin of error below threshold + """ + minimal_n_for_slope_estimation = 3 + if effective_n < minimal_n_for_slope_estimation: + return False + + # Calculate sums of squares and cross-products + # One-pass formulas: O(1) updates, but prone to floating-point cancellation. 
+ centered_sum_xx = self.sum_x2 - (self.sum_x**2) / self.n + centered_sum_xy = self.sum_xy - (self.sum_x * self.sum_y) / self.n + centered_sum_yy = self.sum_y2 - (self.sum_y**2) / self.n + + # Safeguard against division by zero for SS_xx + centered_sum_xx_safe = max(centered_sum_xx, self.eps) + + slope = centered_sum_xy / centered_sum_xx_safe + + # Calculate Residual Sum of Squares (RSS) + # This is a direct calculation using the sums of squares. + residual_sum_of_squares = centered_sum_yy - ( + centered_sum_xy**2 / centered_sum_xx_safe + ) + + # Ensure RSS is non-negative due to potential floating point inaccuracies + residual_sum_of_squares = max(residual_sum_of_squares, 0.0) + + # Degrees of freedom for standard error (n - 2 for simple linear regression) + dof = effective_n - 2 + + residual_variance = residual_sum_of_squares / dof + standard_error = (residual_variance / centered_sum_xx_safe) ** 0.5 + + # t-critical value + alpha = 1 - self.confidence + t_crit = approx_t_ppf(1 - alpha / 2, df=dof) + + # Margin Of Error + margin_of_error = t_crit * standard_error / max(slope, self.eps) + + self.slope = slope + self.margin_of_error = margin_of_error + return (slope > 0) and (margin_of_error < self.moe_threshold) + + +class OverSaturationConstraint(Constraint): + """ + Constraint that detects and stops execution when over-saturation is detected. + + This constraint implements the Over-Saturation Detection (OSD) algorithm to + identify when a model becomes over-saturated (response rate doesn't keep up with + request rate). When over-saturation is detected, the constraint stops request + queuing and optionally stops processing of existing requests. + + The constraint maintains internal state for tracking concurrent requests and + time-to-first-token (TTFT) metrics, using statistical slope detection to identify + performance degradation patterns. 
+ """ + + def __init__( + self, + minimum_duration: float = 30.0, + minimum_ttft: float = 2.5, + maximum_window_seconds: float = 120.0, + moe_threshold: float = 2.0, + maximum_window_ratio: float = 0.75, + minimum_window_size: int = 5, + confidence: float = 0.95, + eps: float = 1e-12, + enabled: bool = True, + ) -> None: # noqa: PLR0913 + """ + Initialize the over-saturation constraint. + + Creates a new constraint instance with specified detection parameters. + The constraint will track concurrent requests and TTFT metrics, using + statistical slope detection to identify when the model becomes + over-saturated. All parameters have sensible defaults suitable for + most benchmarking scenarios. + + :param minimum_duration: Minimum seconds before checking for over-saturation + (default: 30.0) + :param minimum_ttft: Minimum TTFT threshold in seconds for violation counting + (default: 2.5) + :param maximum_window_seconds: Maximum time window in seconds for data retention + (default: 120.0) + :param moe_threshold: Margin of error threshold for slope detection + (default: 2.0) + :param maximum_window_ratio: Maximum window size as ratio of total requests + (default: 0.75) + :param minimum_window_size: Minimum data points required for slope estimation + (default: 5) + :param confidence: Statistical confidence level for t-distribution (0-1) + (default: 0.95) + :param eps: Epsilon for numerical stability in calculations + (default: 1e-12) + :param enabled: Whether to actually stop when over-saturation is detected + (default: True) + """ + self.minimum_duration = minimum_duration + self.minimum_ttft = minimum_ttft + self.maximum_window_seconds = maximum_window_seconds + self.maximum_window_ratio = maximum_window_ratio + self.minimum_window_size = minimum_window_size + self.moe_threshold = moe_threshold + self.confidence = confidence + self.eps = eps + self.enabled = enabled + self.reset() + + @property + def info(self) -> dict[str, Any]: + """ + Get current constraint 
configuration and state information. + :return: Dictionary containing configuration parameters. + """ + + return { + "type_": "over_saturation", + "minimum_duration": self.minimum_duration, + "minimum_ttft": self.minimum_ttft, + "maximum_window_seconds": self.maximum_window_seconds, + "maximum_window_ratio": self.maximum_window_ratio, + "minimum_window_size": self.minimum_window_size, + "moe_threshold": self.moe_threshold, + "confidence": self.confidence, + "enabled": self.enabled, + } + + def reset(self) -> None: + """ + Reset all internal state to initial values. + + Clears all tracked requests, resets counters, and reinitializes slope + checkers. Useful for reusing constraint instances across multiple + benchmark runs or resetting state after configuration changes. + """ + self.duration = 0.0 + self.started_requests: list[dict[str, Any]] = [] + self.finished_requests: list[dict[str, Any]] = [] + self.ttft_violations_counter = 0 + self.total_finished_ever = 0 + self.total_started_ever = 0 + self.concurrent_slope_checker = SlopeChecker( + moe_threshold=self.moe_threshold, confidence=self.confidence, eps=self.eps + ) + self.ttft_slope_checker = SlopeChecker( + moe_threshold=self.moe_threshold, confidence=self.confidence, eps=self.eps + ) + + def _add_finished(self, request: dict[str, Any]) -> None: + """ + Add a finished request to tracking. + + :param request: Dictionary containing request data with 'ttft' and + 'duration' keys. + """ + ttft = request["ttft"] + duration = request["duration"] + if ttft is not None: + self.total_finished_ever += 1 + self.finished_requests.append(request) + if ttft > self.minimum_ttft: + self.ttft_violations_counter += 1 + self.ttft_slope_checker.add_data_point(duration, ttft) + + def _remove_finished(self, request: dict[str, Any]) -> None: + """ + Remove a finished request from tracking. + + :param request: Dictionary containing request data with 'ttft' and + 'duration' keys. 
+ """ + del self.finished_requests[0] + ttft = request["ttft"] + duration = request["duration"] + if ttft > self.minimum_ttft: + self.ttft_violations_counter -= 1 + self.ttft_slope_checker.remove_data_point(duration, ttft) + + def _add_started(self, request: dict[str, Any]) -> None: + """ + Add a started request to tracking. + + :param request: Dictionary containing request data with + 'concurrent_requests' and 'duration' keys. + """ + concurrent = request["concurrent_requests"] + duration = request["duration"] + if concurrent is not None: + self.total_started_ever += 1 + self.started_requests.append(request) + self.concurrent_slope_checker.add_data_point(duration, concurrent) + + def _remove_started(self, request: dict[str, Any]) -> None: + """ + Remove a started request from tracking. + + :param request: Dictionary containing request data with + 'concurrent_requests' and 'duration' keys. + """ + del self.started_requests[0] + concurrent = request["concurrent_requests"] + duration = request["duration"] + self.concurrent_slope_checker.remove_data_point(duration, concurrent) + + def _update_duration(self, duration: float) -> None: + """ + Update duration and prune old data points. + + Updates the current duration and removes data points that exceed the maximum + window size (by ratio or time) to maintain bounded memory usage. + + :param duration: Current duration in seconds since benchmark start. 
+ """ + self.duration = duration + + maximum_finished_window_size = int( + self.total_finished_ever * self.maximum_window_ratio + ) + while len(self.finished_requests) > maximum_finished_window_size: + self._remove_finished(self.finished_requests[0]) + + while (len(self.finished_requests) > 0) and ( + ( + time_since_earliest_request := duration + - self.finished_requests[0]["duration"] + ) + > self.maximum_window_seconds + ): + self._remove_finished(self.finished_requests[0]) + + maximum_started_window_size = int( + self.total_started_ever * self.maximum_window_ratio + ) + while len(self.started_requests) > maximum_started_window_size: + self._remove_started(self.started_requests[0]) + + while (len(self.started_requests) > 0) and ( + ( + time_since_earliest_request := duration # noqa: F841 + - self.started_requests[0]["duration"] + ) + > self.maximum_window_seconds + ): + self._remove_started(self.started_requests[0]) + + def _check_alert(self) -> bool: + """ + Check if over-saturation is currently detected. + + :return: True if over-saturation is detected, False otherwise. + """ + # Use duration as the maximum n value since requests from the + # same second are highly correlated, this is simple and good enough + # given that the MOE has a custom threshold anyway. 
+ concurrent_n = min(self.duration, self.concurrent_slope_checker.n) + ttft_n = min(self.duration, self.ttft_slope_checker.n) + + if ( + (self.duration < self.minimum_duration) + or (self.ttft_slope_checker.n > self.ttft_violations_counter * 2) + or (self.duration < self.minimum_ttft) + or (concurrent_n < self.minimum_window_size) + ): + return False + + is_concurrent_slope_positive = self.concurrent_slope_checker.check_slope( + concurrent_n + ) + + if ttft_n < self.minimum_window_size: + return is_concurrent_slope_positive + + is_ttft_slope_positive = self.ttft_slope_checker.check_slope(ttft_n) + + return is_concurrent_slope_positive and is_ttft_slope_positive + + def __call__( + self, state: SchedulerState, request_info: RequestInfo + ) -> SchedulerUpdateAction: + """ + Evaluate constraint against current scheduler state. + + :param state: Current scheduler state. + :param request_info: Individual request information. + :return: Action indicating whether to continue or stop operations. 
+ """ + duration = time.time() - state.start_time + + if request_info.status == "in_progress": + concurrent_requests = state.processing_requests + self._add_started( + {"concurrent_requests": concurrent_requests, "duration": duration} + ) + elif ( + request_info.status == "completed" + and request_info.timings + and request_info.timings.first_token_iteration + and request_info.timings.request_start + ): + ttft = ( + request_info.timings.first_token_iteration + - request_info.timings.request_start + ) + self._add_finished({"ttft": ttft, "duration": duration}) + + self._update_duration(duration) + is_over_saturated = self._check_alert() + + ttft_slope = self.ttft_slope_checker.slope + ttft_slope_moe = self.ttft_slope_checker.margin_of_error + ttft_n = self.ttft_slope_checker.n + ttft_violations = self.ttft_violations_counter + concurrent_slope = self.concurrent_slope_checker.slope + concurrent_slope_moe = self.concurrent_slope_checker.margin_of_error + concurrent_n = self.concurrent_slope_checker.n + + should_stop = is_over_saturated and self.enabled + return SchedulerUpdateAction( + request_queuing="stop" if should_stop else "continue", + request_processing="stop_all" if should_stop else "continue", + metadata={ + "ttft_slope": ttft_slope, + "ttft_slope_moe": ttft_slope_moe, + "ttft_n": ttft_n, + "ttft_violations": ttft_violations, + "concurrent_slope": concurrent_slope, + "concurrent_slope_moe": concurrent_slope_moe, + "concurrent_n": concurrent_n, + "is_over_saturated": is_over_saturated, + }, + ) + + +@ConstraintsInitializerFactory.register( # type: ignore[arg-type] + ["over_saturation", "detect_saturation"] +) +class OverSaturationConstraintInitializer(PydanticConstraintInitializer): + """ + Factory for creating OverSaturationConstraint instances from configuration. + + Provides a Pydantic-based initializer for over-saturation detection constraints + with support for flexible configuration patterns. 
Supports detailed configuration + dictionaries, enabling easy integration with CLI arguments, configuration files, + and programmatic constraint creation. + + Example: + :: + # Configuration with defaults + initializer = OverSaturationConstraintInitializer(enabled=True) + constraint = initializer.create_constraint() + + # Detailed configuration + initializer = OverSaturationConstraintInitializer( + enabled=True, + min_seconds=60.0, + max_window_seconds=300.0, + moe_threshold=1.5 + ) + constraint = initializer.create_constraint() + + :cvar type_: Always "over_saturation" to identify this constraint type + :cvar enabled: Whether to stop the benchmark if over-saturation is detected + :cvar min_seconds: Minimum seconds before checking for over-saturation + :cvar max_window_seconds: Maximum time window for data retention + :cvar moe_threshold: Margin of error threshold for slope detection + :cvar minimum_ttft: Minimum TTFT threshold for violation counting + :cvar maximum_window_ratio: Maximum window size as ratio of total requests + :cvar minimum_window_size: Minimum data points required for slope estimation + :cvar confidence: Statistical confidence level for t-distribution + """ + + type_: Literal["over_saturation"] = "over_saturation" # type: ignore[assignment] + enabled: bool = Field( + default=True, + description="Whether to stop the benchmark if the model is over-saturated", + ) + min_seconds: int | float = Field( + default=30.0, + ge=0, + description="Minimum seconds before checking for over-saturation", + ) + max_window_seconds: int | float = Field( + default=120.0, + ge=0, + description="Maximum over-saturation checking window size in seconds", + ) + moe_threshold: float = Field( + default=2.0, + ge=0, + description="Margin of error threshold for slope detection", + ) + minimum_ttft: float = Field( + default=2.5, + ge=0, + description="Minimum TTFT threshold for violation counting", + ) + maximum_window_ratio: float = Field( + default=0.75, + ge=0, + le=1.0, + 
description="Maximum window size as ratio of total requests", + ) + minimum_window_size: int = Field( + default=5, + ge=0, + description="Minimum data points required for slope estimation", + ) + confidence: float = Field( + default=0.95, + ge=0, + le=1.0, + description="Statistical confidence level for t-distribution", + ) + + def create_constraint(self, **_kwargs) -> Constraint: + """ + Create an OverSaturationConstraint instance from this initializer. + + Constructs a new OverSaturationConstraint with the configuration parameters + specified in this initializer. The constraint will be ready for evaluation + against scheduler state and requests. + + :param _kwargs: Additional keyword arguments (unused) + :return: Configured OverSaturationConstraint instance ready for use + """ + return OverSaturationConstraint( + minimum_duration=self.min_seconds, + minimum_ttft=self.minimum_ttft, + maximum_window_seconds=self.max_window_seconds, + moe_threshold=self.moe_threshold, + maximum_window_ratio=self.maximum_window_ratio, + minimum_window_size=self.minimum_window_size, + confidence=self.confidence, + enabled=self.enabled, + ) + + @classmethod + def validated_kwargs( + cls, over_saturation: dict[str, Any] | None = None, **kwargs + ) -> dict[str, Any]: + """ + Validate and process arguments for OverSaturationConstraint creation. + + Processes flexible input formats to create validated constraint + configuration. Supports dictionary inputs for detailed configuration, and + alias parameters for compatibility. Handles parameter normalization and + default value application. + + :param over_saturation: Dictionary with configuration parameters + (min_seconds, max_window_seconds, etc.) 
+ :param kwargs: Additional keyword arguments supporting aliases like + "detect_saturation" for compatibility, or unpacked dict values when + dict is passed to factory + :return: Validated dictionary with constraint configuration ready for + initializer creation + """ + # Check for aliases in kwargs + aliases = ["over_saturation", "detect_saturation"] + result: dict[str, Any] | None = over_saturation + + for alias in aliases: + alias_value = kwargs.get(alias) + if alias_value is not None: + result = alias_value + break + + # If over_saturation is None but kwargs contain constraint parameters, + # treat kwargs as an unpacked dict (happens when dict is passed to factory) + if result is None and kwargs: + constraint_keys = { + "enabled", + "min_seconds", + "max_window_seconds", + "moe_threshold", + "minimum_ttft", + "maximum_window_ratio", + "minimum_window_size", + "confidence", + } + if any(key in kwargs for key in constraint_keys): + # Reconstruct dict from kwargs + result = {key: kwargs[key] for key in constraint_keys if key in kwargs} + + if result is None: + return {"enabled": False} + + if isinstance(result, dict): + # Return dict as-is, defaults come from fields above + return result + else: + # Type signature only accepts dict or None, so this should never happen + raise TypeError( + f"over_saturation must be a dict or None, got {type(result).__name__}" + ) diff --git a/src/guidellm/utils/cli.py b/src/guidellm/utils/cli.py index d81f061db..90c0271ed 100644 --- a/src/guidellm/utils/cli.py +++ b/src/guidellm/utils/cli.py @@ -77,9 +77,16 @@ def parse_list_floats(ctx, param, value): ) from err -def parse_json(ctx, param, value): # noqa: ARG001 +def parse_json(ctx, param, value): # noqa: ARG001, C901, PLR0911 + if isinstance(value, dict): + return value + if value is None or value == [None]: return None + + if isinstance(value, str) and not value.strip(): + return None + if isinstance(value, list | tuple): return [parse_json(ctx, param, val) for val in value] diff 
--git a/tests/e2e/README.md b/tests/e2e/README.md index db5fcba5a..39a994340 100644 --- a/tests/e2e/README.md +++ b/tests/e2e/README.md @@ -6,10 +6,10 @@ The E2E tests in GuideLLM use the [vLLM simulator by llm-d](https://llm-d.ai/doc docker build . -f tests/e2e/vllm-sim.Dockerfile -o type=local,dest=./ ``` -On MacOS run: +For MacOS native: ```shell -docker build . -f tests/e2e/vllm-sim.Dockerfile -o type=local,dest=./ --build-arg BUILDOS=darwin +docker build . -f tests/e2e/vllm-sim-macos.Dockerfile -o type=local,dest=./ ``` Then to run the tests: diff --git a/tests/e2e/test_max_error_benchmark.py b/tests/e2e/test_max_error_benchmark.py index c222e9222..aa0542139 100644 --- a/tests/e2e/test_max_error_benchmark.py +++ b/tests/e2e/test_max_error_benchmark.py @@ -8,7 +8,6 @@ GuidellmClient, assert_constraint_triggered, assert_no_python_exceptions, - cleanup_report_file, load_benchmark_report, ) from tests.e2e.vllm_sim_server import VllmSimServer @@ -35,44 +34,45 @@ def server(): @pytest.mark.timeout(30) -def test_max_error_benchmark(server: VllmSimServer): +def test_max_error_benchmark(server: VllmSimServer, tmp_path: Path): """ Test that the max error rate constraint is properly triggered when server goes down. 
""" - report_path = Path("tests/e2e/max_error_benchmarks.json") + report_name = "max_error_benchmarks.json" + report_path = tmp_path / report_name rate = 10 max_error_rate = 0.1 # Create and configure the guidellm client - client = GuidellmClient(target=server.get_url(), output_path=report_path) - - try: - # Start the benchmark - client.start_benchmark( - rate=rate, - max_seconds=25, - max_error_rate=max_error_rate, - ) + client = GuidellmClient( + target=server.get_url(), + output_dir=tmp_path, + outputs=report_name, + ) - # Wait for the benchmark to complete (server will be stopped after 15 seconds) - client.wait_for_completion(timeout=30, stop_server_after=15, server=server) + # Start the benchmark + client.start_benchmark( + rate=rate, + max_seconds=25, + max_error_rate=max_error_rate, + ) - # Assert no Python exceptions occurred - assert_no_python_exceptions(client.stderr) + # Wait for the benchmark to complete (server will be stopped after 15 seconds) + client.wait_for_completion(timeout=30, stop_server_after=15, server=server) - # Load and validate the report - report = load_benchmark_report(report_path) - benchmark = report["benchmarks"][0] + # Assert no Python exceptions occurred + assert_no_python_exceptions(client.stderr) - # Check that the max error rate constraint was triggered - assert_constraint_triggered( - benchmark, - "max_error_rate", - { - "exceeded_error_rate": True, - "current_error_rate": lambda rate: rate >= max_error_rate, - }, - ) + # Load and validate the report + report = load_benchmark_report(report_path) + benchmark = report["benchmarks"][0] - finally: - cleanup_report_file(report_path) + # Check that the max error rate constraint was triggered + assert_constraint_triggered( + benchmark, + "max_error_rate", + { + "exceeded_error_rate": True, + "current_error_rate": lambda rate: rate >= max_error_rate, + }, + ) diff --git a/tests/e2e/test_over_saturated_benchmark.py b/tests/e2e/test_over_saturated_benchmark.py index 368e2c0f2..711e5e2ed 
100644 --- a/tests/e2e/test_over_saturated_benchmark.py +++ b/tests/e2e/test_over_saturated_benchmark.py @@ -6,7 +6,6 @@ GuidellmClient, assert_constraint_triggered, assert_no_python_exceptions, - cleanup_report_file, load_benchmark_report, ) from tests.e2e.vllm_sim_server import VllmSimServer @@ -22,7 +21,7 @@ def server(): port=8000, model="databricks/dolly-v2-12b", mode="random", - time_to_first_token=10000, + time_to_first_token=60000, inter_token_latency=100, max_num_seqs=1, ) @@ -33,26 +32,28 @@ def server(): server.stop() # Teardown: Stop the server after tests are done -@pytest.mark.skip(reason="Skipping future feature test") @pytest.mark.timeout(60) -def test_over_saturated_benchmark(server: VllmSimServer): +def test_over_saturated_benchmark(server: VllmSimServer, tmp_path: Path): """ - Another example test interacting with the server. + Test over-saturation detection using a minimal --over-saturation configuration. """ - report_path = Path("tests/e2e/over_saturated_benchmarks.json") - rate = 100 + report_name = "over_saturated_benchmarks.json" + report_path = tmp_path / report_name + rate = 10 # Create and configure the guidellm client - client = GuidellmClient(target=server.get_url(), output_path=report_path) + client = GuidellmClient( + target=server.get_url(), + output_dir=tmp_path, + outputs=report_name, + ) - cleanup_report_file(report_path) - # Start the benchmark + # Start the benchmark with a minimal --over-saturation configuration client.start_benchmark( rate=rate, max_seconds=20, - stop_over_saturated=True, + over_saturation={"enabled": True, "min_seconds": 0}, extra_env={ - "GUIDELLM__CONSTRAINT_OVER_SATURATION_MIN_SECONDS": "0", "GOMAXPROCS": "1", }, ) @@ -69,7 +70,55 @@ def test_over_saturated_benchmark(server: VllmSimServer): # Check that the max duration constraint was triggered assert_constraint_triggered( - benchmark, "stop_over_saturated", {"is_over_saturated": True} + benchmark, "over_saturation", {"is_over_saturated": True} + ) + + 
+@pytest.mark.timeout(60) +def test_over_saturated_benchmark_with_dict_config( + server: VllmSimServer, tmp_path: Path +): + """ + Test over-saturation detection with dictionary configuration instead of boolean. + """ + report_name = "over_saturated_benchmarks_dict.json" + report_path = tmp_path / report_name + rate = 10 + + # Create and configure the guidellm client + client = GuidellmClient( + target=server.get_url(), + output_dir=tmp_path, + outputs=report_name, + ) + + # Start the benchmark with dictionary configuration for over-saturation + client.start_benchmark( + rate=rate, + max_seconds=20, + over_saturation={ + "enabled": True, + "min_seconds": 0, + "max_window_seconds": 120.0, + "moe_threshold": 2.0, + "minimum_window_size": 5, + }, + extra_env={ + "GOMAXPROCS": "1", + }, ) - cleanup_report_file(report_path) + # Wait for the benchmark to complete + client.wait_for_completion(timeout=55) + + # Assert no Python exceptions occurred + assert_no_python_exceptions(client.stderr) + + # Load and validate the report + report = load_benchmark_report(report_path) + benchmark = report["benchmarks"][0] + + # Check that the over-saturation constraint was triggered + assert_constraint_triggered( + benchmark, "over_saturation", {"is_over_saturated": True} + ) diff --git a/tests/e2e/test_successful_benchmark.py b/tests/e2e/test_successful_benchmark.py index 3642eaf43..8703882b2 100644 --- a/tests/e2e/test_successful_benchmark.py +++ b/tests/e2e/test_successful_benchmark.py @@ -9,7 +9,6 @@ assert_constraint_triggered, assert_no_python_exceptions, assert_successful_requests_fields, - cleanup_report_file, load_benchmark_report, ) from tests.e2e.vllm_sim_server import VllmSimServer @@ -37,90 +36,87 @@ def server(): @pytest.mark.timeout(30) @pytest.mark.sanity -def test_max_seconds_benchmark(server: VllmSimServer): +def test_max_seconds_benchmark(server: VllmSimServer, tmp_path: Path): """ Test that the max seconds constraint is properly triggered. 
""" - report_path = Path("tests/e2e/max_duration_benchmarks.json") + report_name = "max_duration_benchmarks.json" + report_path = tmp_path / report_name rate = 4 duration = 5 max_seconds = duration # Create and configure the guidellm client - client = GuidellmClient(target=server.get_url(), output_path=report_path) + client = GuidellmClient( + target=server.get_url(), + output_dir=tmp_path, + outputs=report_name, + ) - try: - # Start the benchmark - client.start_benchmark( - rate=rate, - max_seconds=max_seconds, - ) + # Start the benchmark + client.start_benchmark( + rate=rate, + max_seconds=max_seconds, + ) - # Wait for the benchmark to complete - client.wait_for_completion(timeout=30) + # Wait for the benchmark to complete + client.wait_for_completion(timeout=30) - # Assert no Python exceptions occurred - assert_no_python_exceptions(client.stderr) + # Assert no Python exceptions occurred + assert_no_python_exceptions(client.stderr) - # Load and validate the report - report = load_benchmark_report(report_path) - benchmark = report["benchmarks"][0] + # Load and validate the report + report = load_benchmark_report(report_path) + benchmark = report["benchmarks"][0] - # Check that the max duration constraint was triggered - assert_constraint_triggered( - benchmark, "max_seconds", {"duration_exceeded": True} - ) + # Check that the max duration constraint was triggered + assert_constraint_triggered(benchmark, "max_seconds", {"duration_exceeded": True}) - # Validate successful requests have all expected fields - successful_requests = benchmark["requests"]["successful"] - assert_successful_requests_fields(successful_requests) - - finally: - cleanup_report_file(report_path) + # Validate successful requests have all expected fields + successful_requests = benchmark["requests"]["successful"] + assert_successful_requests_fields(successful_requests) @pytest.mark.timeout(30) @pytest.mark.sanity -def test_max_requests_benchmark(server: VllmSimServer): +def 
test_max_requests_benchmark(server: VllmSimServer, tmp_path: Path): """ Test that the max requests constraint is properly triggered. """ - report_path = Path("tests/e2e/max_number_benchmarks.json") + report_name = "max_number_benchmarks.json" + report_path = tmp_path / report_name rate = 4 duration = 5 max_requests = rate * duration # Create and configure the guidellm client - client = GuidellmClient(target=server.get_url(), output_path=report_path) + client = GuidellmClient( + target=server.get_url(), + output_dir=tmp_path, + outputs=report_name, + ) - try: - # Start the benchmark - client.start_benchmark( - rate=rate, - max_requests=max_requests, - ) - - # Wait for the benchmark to complete - client.wait_for_completion(timeout=30) - - # Assert no Python exceptions occurred - assert_no_python_exceptions(client.stderr) - - # Load and validate the report - report = load_benchmark_report(report_path) - benchmark = report["benchmarks"][0] - - # Check that the max requests constraint was triggered - assert_constraint_triggered( - benchmark, "max_requests", {"processed_exceeded": True} - ) - - # Validate successful requests have all expected fields - successful_requests = benchmark["requests"]["successful"] - assert len(successful_requests) == max_requests, ( - f"Expected {max_requests} successful requests, " - f"got {len(successful_requests)}" - ) - assert_successful_requests_fields(successful_requests) + # Start the benchmark + client.start_benchmark( + rate=rate, + max_requests=max_requests, + ) - finally: - cleanup_report_file(report_path) + # Wait for the benchmark to complete + client.wait_for_completion(timeout=30) + + # Assert no Python exceptions occurred + assert_no_python_exceptions(client.stderr) + + # Load and validate the report + report = load_benchmark_report(report_path) + benchmark = report["benchmarks"][0] + + # Check that the max requests constraint was triggered + assert_constraint_triggered(benchmark, "max_requests", {"processed_exceeded": True}) + 
+ # Validate successful requests have all expected fields + successful_requests = benchmark["requests"]["successful"] + assert len(successful_requests) == max_requests, ( + f"Expected {max_requests} successful requests, got {len(successful_requests)}" + ) + assert_successful_requests_fields(successful_requests) diff --git a/tests/e2e/utils.py b/tests/e2e/utils.py index e63587e4e..55baa89d2 100644 --- a/tests/e2e/utils.py +++ b/tests/e2e/utils.py @@ -1,10 +1,12 @@ """Utilities for E2E tests.""" import json +import shlex import subprocess import sys import time from pathlib import Path +from typing import Any from loguru import logger @@ -24,7 +26,9 @@ def get_guidellm_executable() -> str: class GuidellmClient: """Wrapper class for running guidellm benchmark commands.""" - def __init__(self, target: str, output_path: Path): + def __init__( + self, target: str, output_dir: Path, outputs: str = "benchmarks.json" + ) -> None: """ Initialize the guidellm client. @@ -32,7 +36,8 @@ def __init__(self, target: str, output_path: Path): :param output_path: Path where the benchmark report will be saved """ self.target = target - self.output_path = output_path + self.output_dir = output_dir + self.outputs = outputs self.process: subprocess.Popen | None = None self.stdout: str | None = None self.stderr: str | None = None @@ -44,7 +49,7 @@ def start_benchmark( max_seconds: int | None = None, max_requests: int | None = None, max_error_rate: float | None = None, - stop_over_saturated: bool | None = False, + over_saturation: dict[str, Any] | None = None, data: str = "prompt_tokens=256,output_tokens=128", processor: str = "gpt2", additional_args: str = "", @@ -53,24 +58,25 @@ def start_benchmark( """ Start a guidellm benchmark command. - :param rate_type: Type of rate control (constant, etc.) + :param profile: Type of rate control (constant, etc.) 
:param rate: Request rate :param max_seconds: Maximum duration in seconds :param max_requests: Maximum number of requests :param max_error_rate: Maximum error rate before stopping - :param stop_over_saturated: Whether to stop the benchmark if the model is - over-saturated. + :param over_saturation: Over-saturation detection configuration (dict). + Passed as JSON string to --over-saturation CLI argument. :param data: Data configuration string :param processor: Processor/tokenizer to use :param additional_args: Additional command line arguments + :param extra_env: Additional environment variables to set """ guidellm_exe = get_guidellm_executable() # Build command components cmd_parts = [ *([f"{k}={v}" for k, v in extra_env.items()] if extra_env else []), - "HF_HOME=/tmp/huggingface_cache", - f"{guidellm_exe} benchmark", + "HF_HOME=" + str(self.output_dir / "huggingface_cache"), + f"{guidellm_exe} benchmark run", f'--target "{self.target}"', f"--profile {profile}", f"--rate {rate}", @@ -85,14 +91,28 @@ def start_benchmark( if max_error_rate is not None: cmd_parts.append(f"--max-error-rate {max_error_rate}") - if stop_over_saturated: - cmd_parts.append("--stop-over-saturated") + if over_saturation is not None: + if isinstance(over_saturation, dict): + # Use --default-over-saturation flag for empty dict (defaults) + if over_saturation == {}: + cmd_parts.append("--default-over-saturation") + else: + # Escape the JSON string properly for shell + json_str = json.dumps(over_saturation) + # Use shlex.quote to properly escape for shell + cmd_parts.append(f"--over-saturation {shlex.quote(json_str)}") + else: + raise TypeError( + f"over_saturation must be a dict or None, " + f"got {type(over_saturation)}" + ) cmd_parts.extend( [ f'--data "{data}"', f'--processor "{processor}"', - f"--output-path {self.output_path}", + f"--output-dir {self.output_dir}", + f"--outputs {self.outputs}", ] ) @@ -104,7 +124,7 @@ def start_benchmark( logger.info(f"Client command: {command}") 
self.process = subprocess.Popen( # noqa: S603 - ["/bin/bash", "-c", command], + ["/bin/sh", "-c", command], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, @@ -321,13 +341,3 @@ def assert_constraint_triggered( assert actual_value == expected_value, ( f"Expected {key}={expected_value}, got {actual_value}" ) - - -def cleanup_report_file(report_path: Path) -> None: - """ - Clean up the report file if it exists. - - :param report_path: Path to the report file to remove - """ - if report_path.exists(): - report_path.unlink() diff --git a/tests/e2e/vllm-sim-macos.Dockerfile b/tests/e2e/vllm-sim-macos.Dockerfile new file mode 100644 index 000000000..a274c00e9 --- /dev/null +++ b/tests/e2e/vllm-sim-macos.Dockerfile @@ -0,0 +1,17 @@ +FROM golang AS base + +WORKDIR /app + +ARG BUILDARCH + +RUN apt-get update && \ + apt-get install -y libzmq3-dev pkg-config && \ + git clone https://github.com/llm-d/llm-d-inference-sim.git && \ + cd llm-d-inference-sim && \ + git checkout v0.3.0 && \ + GOOS=darwin GOARCH=${BUILDARCH} make build + +WORKDIR /app/llm-d-inference-sim + +FROM scratch +COPY --from=base /app/llm-d-inference-sim/bin /bin diff --git a/tests/unit/scheduler/test_over_saturation.py b/tests/unit/scheduler/test_over_saturation.py new file mode 100644 index 000000000..6e4a2955f --- /dev/null +++ b/tests/unit/scheduler/test_over_saturation.py @@ -0,0 +1,593 @@ +"""Unit tests for over-saturation constraint implementation.""" + +import inspect +import time + +import pytest +from pydantic import ValidationError + +from guidellm.scheduler import ( + Constraint, + ConstraintInitializer, + ConstraintsInitializerFactory, + OverSaturationConstraint, + OverSaturationConstraintInitializer, + PydanticConstraintInitializer, + SchedulerState, + SchedulerUpdateAction, + SerializableConstraintInitializer, +) +from guidellm.schemas import RequestInfo, RequestTimings + + +class TestOverSaturationConstraintInternal: + """Test the OverSaturationConstraint internal functionality.""" + 
+ @pytest.fixture( + params=[ + {"minimum_duration": 30.0, "maximum_window_seconds": 120.0}, + {"minimum_duration": 10.0, "maximum_window_seconds": 60.0}, + {"minimum_duration": 60.0, "maximum_window_seconds": 240.0}, + ] + ) + def valid_instances(self, request): + """Create OverSaturationConstraint instances with valid parameters.""" + constructor_args = request.param + instance = OverSaturationConstraint(**constructor_args, enabled=True) + return instance, constructor_args + + @pytest.mark.smoke + def test_initialization_valid(self, valid_instances): + """Test OverSaturationConstraint initialization with valid parameters.""" + instance, constructor_args = valid_instances + + for key, value in constructor_args.items(): + assert hasattr(instance, key) + assert getattr(instance, key) == value + + @pytest.mark.smoke + def test_initialization_defaults(self): + """Test that OverSaturationConstraint has correct default values.""" + constraint = OverSaturationConstraint(enabled=True) + + assert constraint.minimum_duration == 30.0 + assert constraint.minimum_ttft == 2.5 + assert constraint.maximum_window_seconds == 120.0 + assert constraint.moe_threshold == 2.0 + assert constraint.maximum_window_ratio == 0.75 + assert constraint.minimum_window_size == 5 + assert constraint.confidence == 0.95 + assert constraint.eps == 1e-12 + + @pytest.mark.smoke + def test_reset(self, valid_instances): + """Test that reset method properly initializes constraint state.""" + constraint, _ = valid_instances + constraint.reset() + + assert constraint.duration == 0.0 + assert constraint.started_requests == [] + assert constraint.finished_requests == [] + assert constraint.ttft_violations_counter == 0 + assert constraint.total_finished_ever == 0 + assert constraint.total_started_ever == 0 + assert hasattr(constraint, "concurrent_slope_checker") + assert hasattr(constraint, "ttft_slope_checker") + + @pytest.mark.sanity + def test_window_management_through_constraint(self): + """Test that 
constraint properly manages window sizes through usage.""" + constraint = OverSaturationConstraint( + minimum_duration=0.0, + maximum_window_seconds=100.0, + maximum_window_ratio=0.5, + enabled=True, + ) + start_time = time.time() + + # Add many requests through constraint calls + for i in range(100): + state = SchedulerState( + node_id=0, + num_processes=1, + start_time=start_time - i, + processing_requests=i, + ) + request = RequestInfo( + request_id=f"test-{i}", + status="in_progress", + scheduler_node_id=0, + scheduler_process_id=0, + scheduler_start_time=start_time - i, + ) + constraint(state, request) + + # Check that window management is working (through internal state) + # The constraint should have pruned old requests + assert len(constraint.started_requests) <= 50 # Should be limited by ratio + + +class TestOverSaturationConstraint: + """Test the OverSaturationConstraint implementation.""" + + @pytest.fixture + def constraint(self): + """Create a constraint for testing.""" + return OverSaturationConstraint( + minimum_duration=0.0, minimum_window_size=3, enabled=True + ) + + @pytest.fixture( + params=[ + {"enabled": True}, + {"enabled": False}, + ] + ) + def valid_instances(self, request): + """Create OverSaturationConstraint instances with valid parameters.""" + constructor_args = request.param + instance = OverSaturationConstraint( + minimum_duration=0.0, + minimum_window_size=3, + **constructor_args, + ) + return instance, constructor_args + + @pytest.mark.smoke + def test_is_constraint_protocol(self, valid_instances): + """Test that OverSaturationConstraint satisfies the Constraint protocol.""" + constraint, _ = valid_instances + assert isinstance(constraint, Constraint) + + @pytest.mark.smoke + def test_protocol_method_signature(self): + """Test that OverSaturationConstraint has the correct method signature.""" + constraint = OverSaturationConstraint(enabled=True) + call_method = constraint.__call__ + sig = inspect.signature(call_method) + + 
expected_params = ["state", "request_info"] + assert list(sig.parameters.keys()) == expected_params + + @pytest.mark.smoke + def test_initialization_valid(self, valid_instances): + """Test OverSaturationConstraint initialization with valid parameters.""" + constraint, constructor_args = valid_instances + + assert constraint.enabled == constructor_args["enabled"] + + @pytest.mark.sanity + def test_constraint_returns_continue_when_not_saturated(self, constraint): + """Test constraint returns continue when not over-saturated.""" + start_time = time.time() + + state = SchedulerState( + node_id=0, + num_processes=1, + start_time=start_time, + processing_requests=5, + ) + + request = RequestInfo( + request_id="test-1", + status="in_progress", + scheduler_node_id=0, + scheduler_process_id=0, + scheduler_start_time=start_time, + ) + + action = constraint(state, request) + assert isinstance(action, SchedulerUpdateAction) + assert action.request_queuing == "continue" + assert action.request_processing == "continue" + assert isinstance(action.metadata, dict) + assert "is_over_saturated" in action.metadata + + @pytest.mark.sanity + def test_constraint_with_completed_request(self, constraint): + """Test constraint with completed request including timings.""" + start_time = time.time() + + # Create timings with first_iteration + timings = RequestTimings( + request_start=start_time + 0.1, first_iteration=start_time + 0.2 + ) + + state = SchedulerState( + node_id=0, + num_processes=1, + start_time=start_time, + processing_requests=5, + ) + + request = RequestInfo( + request_id="test-1", + status="completed", + scheduler_node_id=0, + scheduler_process_id=0, + scheduler_start_time=start_time, + timings=timings, + ) + + action = constraint(state, request) + assert isinstance(action, SchedulerUpdateAction) + assert "ttft_slope" in action.metadata + assert "ttft_n" in action.metadata + + @pytest.mark.sanity + def test_constraint_stops_when_over_saturated(self, constraint): + """Test 
constraint stops when over-saturated and flag is enabled.""" + start_time = time.time() + + # Simulate over-saturation by creating positive slopes through constraint calls + # Add many started requests with increasing concurrent count + for i in range(20): + state = SchedulerState( + node_id=0, + num_processes=1, + start_time=start_time - i, + processing_requests=i * 2, + ) + request = RequestInfo( + request_id=f"test-{i}", + status="in_progress", + scheduler_node_id=0, + scheduler_process_id=0, + scheduler_start_time=start_time - i, + ) + constraint(state, request) + + # Add finished requests with increasing TTFT + for i in range(20): + timings = RequestTimings( + request_start=start_time - i - 10.0, + first_iteration=start_time - i - 10.0 + (1.0 + i * 0.1), + ) + state = SchedulerState( + node_id=0, + num_processes=1, + start_time=start_time - i - 10.0, + processing_requests=5, + ) + request = RequestInfo( + request_id=f"test-finished-{i}", + status="completed", + scheduler_node_id=0, + scheduler_process_id=0, + scheduler_start_time=start_time - i - 10.0, + timings=timings, + ) + constraint(state, request) + + state = SchedulerState( + node_id=0, + num_processes=1, + start_time=start_time, + processing_requests=40, + ) + + request = RequestInfo( + request_id="test-1", + status="in_progress", + scheduler_node_id=0, + scheduler_process_id=0, + scheduler_start_time=start_time, + ) + + # If over-saturated, should stop (but depends on slope detection) + action = constraint(state, request) + assert isinstance(action, SchedulerUpdateAction) + # The exact action depends on whether detection triggers + assert action.request_queuing in ["continue", "stop"] + assert "is_over_saturated" in action.metadata + + @pytest.mark.sanity + def test_constraint_never_stops_when_flag_disabled(self): + """Test constraint never stops when enabled is False.""" + constraint = OverSaturationConstraint( + minimum_duration=0.0, + minimum_window_size=3, + enabled=False, + ) + start_time = 
time.time() + + state = SchedulerState( + node_id=0, + num_processes=1, + start_time=start_time, + processing_requests=100, # High concurrent requests + ) + + request = RequestInfo( + request_id="test-1", + status="in_progress", + scheduler_node_id=0, + scheduler_process_id=0, + scheduler_start_time=start_time, + ) + + # Even if over-saturated, should continue when flag is False + action = constraint(state, request) + assert isinstance(action, SchedulerUpdateAction) + assert action.request_queuing == "continue" + assert action.request_processing == "continue" + + +class TestOverSaturationConstraintInitializer: + """Test the OverSaturationConstraintInitializer implementation.""" + + @pytest.fixture( + params=[ + {"enabled": True}, + {"enabled": False}, + { + "enabled": True, + "min_seconds": 10.0, + "max_window_seconds": 60.0, + }, + ] + ) + def valid_instances(self, request): + """Create OverSaturationConstraintInitializer with valid parameters.""" + constructor_args = request.param + instance = OverSaturationConstraintInitializer(**constructor_args) + return instance, constructor_args + + @pytest.mark.smoke + def test_is_pydantic_constraint_initializer(self, valid_instances): + """Test that initializer is a PydanticConstraintInitializer.""" + instance, _ = valid_instances + assert isinstance(instance, PydanticConstraintInitializer) + assert isinstance(instance, SerializableConstraintInitializer) + + @pytest.mark.smoke + def test_is_constraint_initializer_protocol(self, valid_instances): + """Test that initializer satisfies ConstraintInitializer protocol.""" + instance, _ = valid_instances + assert isinstance(instance, ConstraintInitializer) + + @pytest.mark.smoke + def test_initialization_valid(self, valid_instances): + """Test that initializer can be initialized with valid parameters.""" + instance, constructor_args = valid_instances + + assert instance.type_ == "over_saturation" + assert instance.enabled == constructor_args["enabled"] + + if "min_seconds" in 
constructor_args: + assert instance.min_seconds == constructor_args["min_seconds"] + if "max_window_seconds" in constructor_args: + assert instance.max_window_seconds == constructor_args["max_window_seconds"] + + @pytest.mark.sanity + def test_initialization_invalid(self): + """Test that initializer rejects invalid parameters.""" + # Invalid type for enabled + with pytest.raises(ValidationError): + OverSaturationConstraintInitializer(enabled="invalid") + + # Invalid type for min_seconds + with pytest.raises(ValidationError): + OverSaturationConstraintInitializer(enabled=True, min_seconds="invalid") + + @pytest.mark.smoke + def test_create_constraint(self, valid_instances): + """Test that create_constraint returns OverSaturationConstraint.""" + instance, _ = valid_instances + constraint = instance.create_constraint() + + assert isinstance(constraint, OverSaturationConstraint) + assert constraint.enabled == instance.enabled + + @pytest.mark.smoke + def test_validated_kwargs(self): + """Test validated_kwargs method with various inputs.""" + + # Test with empty dict (uses defaults, enabled=True by default) + result = OverSaturationConstraintInitializer.validated_kwargs( + over_saturation={} + ) + # enabled defaults to True in the Pydantic model + assert result == {} + + # Test with dict input (enabled=False) + result = OverSaturationConstraintInitializer.validated_kwargs( + over_saturation={"enabled": False} + ) + assert result["enabled"] is False + + # Test with dict input with min_seconds + result = OverSaturationConstraintInitializer.validated_kwargs( + over_saturation={"enabled": True, "min_seconds": 20.0} + ) + assert result["enabled"] is True + assert result["min_seconds"] == 20.0 + + # Test with None (should return enabled=False) + result = OverSaturationConstraintInitializer.validated_kwargs( + over_saturation=None + ) + assert result["enabled"] is False + + # Test with aliases + result = OverSaturationConstraintInitializer.validated_kwargs( + 
detect_saturation={"enabled": True} + ) + assert result["enabled"] is True + + @pytest.mark.smoke + def test_marshalling(self, valid_instances): + """Test that initializer can be serialized and deserialized.""" + instance, constructor_args = valid_instances + + data = instance.model_dump() + assert data["type_"] == "over_saturation" + assert data["enabled"] == constructor_args["enabled"] + + reconstructed = OverSaturationConstraintInitializer.model_validate(data) + assert reconstructed.enabled == instance.enabled + + @pytest.mark.smoke + def test_factory_registration(self): + """Test that initializer is properly registered with expected aliases.""" + expected_aliases = [ + "over_saturation", + "detect_saturation", + ] + + for alias in expected_aliases: + assert ConstraintsInitializerFactory.is_registered(alias) + registered_class = ConstraintsInitializerFactory.get_registered_object( + alias + ) + assert registered_class == OverSaturationConstraintInitializer + + @pytest.mark.smoke + @pytest.mark.parametrize("alias", ["over_saturation", "detect_saturation"]) + def test_factory_creation_with_aliases(self, alias): + """Test factory creation using different aliases.""" + # Test with dict configuration using kwargs + constraint = ConstraintsInitializerFactory.create_constraint( + alias, enabled=True + ) + assert isinstance(constraint, OverSaturationConstraint) + assert constraint.enabled is True + + # Test with empty dict (uses defaults, enabled=True by default) + constraint = ConstraintsInitializerFactory.create_constraint(alias, {}) + assert isinstance(constraint, OverSaturationConstraint) + assert constraint.enabled is True + + # Test with dict value (enabled=False) + constraint = ConstraintsInitializerFactory.create_constraint( + alias, {"enabled": False} + ) + assert isinstance(constraint, OverSaturationConstraint) + assert constraint.enabled is False + + @pytest.mark.smoke + def test_factory_resolve_methods(self): + """Test factory resolve methods with various 
input formats.""" + # Test with dict config + resolved = ConstraintsInitializerFactory.resolve( + {"over_saturation": {"enabled": True}} + ) + assert isinstance(resolved["over_saturation"], OverSaturationConstraint) + assert resolved["over_saturation"].enabled is True + + # Test with dict value + resolved = ConstraintsInitializerFactory.resolve( + {"detect_saturation": {"enabled": True}} + ) + assert isinstance(resolved["detect_saturation"], OverSaturationConstraint) + assert resolved["detect_saturation"].enabled is True + + # Test with instance + instance = OverSaturationConstraintInitializer(enabled=False) + constraint_instance = instance.create_constraint() + resolved = ConstraintsInitializerFactory.resolve( + {"over_saturation": constraint_instance} + ) + assert resolved["over_saturation"] is constraint_instance + + @pytest.mark.smoke + def test_functional_constraint_creation(self): + """Test that created constraints are functionally correct.""" + constraint = ConstraintsInitializerFactory.create_constraint( + "over_saturation", enabled=True + ) + start_time = time.time() + state = SchedulerState( + node_id=0, + num_processes=1, + start_time=start_time, + created_requests=5, + processed_requests=5, + processing_requests=3, + ) + request = RequestInfo( + request_id="test-request", + status="in_progress", + scheduler_node_id=0, + scheduler_process_id=0, + scheduler_start_time=start_time, + ) + + action = constraint(state, request) + assert isinstance(action, SchedulerUpdateAction) + # Should continue when not over-saturated + assert action.request_queuing == "continue" + assert action.request_processing == "continue" + assert "is_over_saturated" in action.metadata + + +class TestSlopeChecker: + """Test the SlopeChecker implementation used by OverSaturationDetector.""" + + @pytest.fixture + def slope_checker(self): + """Create a SlopeChecker instance for testing.""" + from guidellm.scheduler.constraints.saturation import ( + SlopeChecker, + ) + + return 
SlopeChecker(moe_threshold=1.0, confidence=0.95) + + @pytest.mark.smoke + def test_initialization(self, slope_checker): + """Test SlopeChecker initialization.""" + assert slope_checker.n == 0 + assert slope_checker.sum_x == 0.0 + assert slope_checker.sum_y == 0.0 + assert slope_checker.moe_threshold == 1.0 + assert slope_checker.confidence == 0.95 + + @pytest.mark.sanity + def test_add_and_remove_data_points(self, slope_checker): + """Test adding and removing data points.""" + # Add data points + slope_checker.add_data_point(1.0, 2.0) + slope_checker.add_data_point(2.0, 4.0) + slope_checker.add_data_point(3.0, 6.0) + + assert slope_checker.n == 3 + assert slope_checker.sum_x == 6.0 + assert slope_checker.sum_y == 12.0 + + # Remove data point + slope_checker.remove_data_point(1.0, 2.0) + + assert slope_checker.n == 2 + assert slope_checker.sum_x == 5.0 + assert slope_checker.sum_y == 10.0 + + @pytest.mark.sanity + def test_check_slope_with_positive_slope(self, slope_checker): + """Test check_slope with clear positive slope.""" + # Create data with clear positive slope + for i in range(10): + slope_checker.add_data_point(float(i), float(i * 2)) + + result = slope_checker.check_slope(10.0) + assert result is True + assert slope_checker.slope is not None + assert slope_checker.slope > 0 + assert slope_checker.margin_of_error is not None + + @pytest.mark.sanity + def test_check_slope_requires_minimum_samples(self, slope_checker): + """Test that check_slope requires minimum samples.""" + # Not enough samples + slope_checker.add_data_point(1.0, 2.0) + result = slope_checker.check_slope(1.0) + assert result is False + + # Still not enough with 2 points + slope_checker.add_data_point(2.0, 4.0) + result = slope_checker.check_slope(2.0) + assert result is False + + # Should work with 3+ points + slope_checker.add_data_point(3.0, 6.0) + result = slope_checker.check_slope(3.0) + # Might be True or False depending on confidence intervals diff --git 
a/tests/unit/scheduler/test_over_saturation_comprehensive.py b/tests/unit/scheduler/test_over_saturation_comprehensive.py new file mode 100644 index 000000000..8bd641760 --- /dev/null +++ b/tests/unit/scheduler/test_over_saturation_comprehensive.py @@ -0,0 +1,870 @@ +"""Comprehensive unit tests for over-saturation constraint implementation. + +This module provides thorough testing to validate that over-saturation detection +and stopping features work correctly under various conditions and edge cases. +""" + +import math +import time +from unittest.mock import patch + +import pytest + +from guidellm.scheduler import ( + OverSaturationConstraint, + OverSaturationConstraintInitializer, + SchedulerState, + SchedulerUpdateAction, +) +from guidellm.scheduler.constraints.saturation import ( + SlopeChecker, + approx_t_ppf, +) +from guidellm.schemas import RequestInfo, RequestTimings + + +class TestSlopeCheckerStatisticalAccuracy: + """Test the statistical accuracy of SlopeChecker implementation.""" + + @pytest.mark.sanity + def test_approx_t_ppf_accuracy(self): + """Test that approx_t_ppf produces reasonable approximations.""" + # Test known values for t-distribution + # For df=10, p=0.975 (95% confidence, two-tailed), t ≈ 2.228 + result = approx_t_ppf(0.975, 10) + assert 2.0 < result < 2.5, f"Expected ~2.228, got {result}" + + # For df=30, p=0.975, t ≈ 2.042 + result = approx_t_ppf(0.975, 30) + assert 1.9 < result < 2.2, f"Expected ~2.042, got {result}" + + # For large df, should approach normal distribution (z=1.96) + result = approx_t_ppf(0.975, 1000) + assert 1.8 < result < 2.1, f"Expected ~1.96, got {result}" + + @pytest.mark.sanity + def test_approx_t_ppf_edge_cases(self): + """Test approx_t_ppf with edge cases.""" + # Very small df + result = approx_t_ppf(0.975, 1) + assert result > 5.0, "t-value should be large for df=1" + + # Invalid df should return NaN + result = approx_t_ppf(0.975, 0) + assert math.isnan(result) + + result = approx_t_ppf(0.975, -1) + assert 
math.isnan(result) + + @pytest.mark.smoke + def test_slope_calculation_perfect_line(self): + """Test slope calculation with perfect linear data.""" + checker = SlopeChecker(moe_threshold=0.1, confidence=0.95) + + # Perfect line: y = 2x + 1 + for i in range(10): + x = float(i) + y = 2.0 * x + 1.0 + checker.add_data_point(x, y) + + result = checker.check_slope(10.0) + assert result is True + assert abs(checker.slope - 2.0) < 0.001, ( + f"Expected slope ~2.0, got {checker.slope}" + ) + + @pytest.mark.smoke + def test_slope_calculation_zero_slope(self): + """Test slope calculation with horizontal line.""" + checker = SlopeChecker(moe_threshold=0.1, confidence=0.95) + + # Horizontal line: y = 5 + for i in range(10): + x = float(i) + y = 5.0 + checker.add_data_point(x, y) + + result = checker.check_slope(10.0) + # Should not detect positive slope + if result: + assert checker.slope <= 0.1, f"Slope should be ~0, got {checker.slope}" + + @pytest.mark.sanity + def test_slope_calculation_negative_slope(self): + """Test slope calculation with negative slope.""" + checker = SlopeChecker(moe_threshold=0.1, confidence=0.95) + + # Negative slope: y = -1.5x + 10 + for i in range(10): + x = float(i) + y = -1.5 * x + 10.0 + checker.add_data_point(x, y) + + result = checker.check_slope(10.0) + # Should not detect positive slope + assert result is False or checker.slope <= 0 + + @pytest.mark.sanity + def test_slope_calculation_with_noise(self): + """Test slope calculation with noisy data.""" + import random + + random.seed(42) # Reproducible results + + checker = SlopeChecker(moe_threshold=1.0, confidence=0.90) + + # Positive slope with noise: y = 1.5x + noise + for i in range(50): + x = float(i) + noise = random.uniform(-2.0, 2.0) + y = 1.5 * x + noise + checker.add_data_point(x, y) + + result = checker.check_slope(50.0) + if result: + assert 1.0 < checker.slope < 2.0, ( + f"Expected slope ~1.5, got {checker.slope}" + ) + + @pytest.mark.sanity + def 
test_margin_of_error_calculation(self): + """Test that margin of error is calculated correctly.""" + checker = SlopeChecker(moe_threshold=0.5, confidence=0.95) + + # Add data with known properties + for i in range(20): + x = float(i) + y = 2.0 * x + 1.0 + checker.add_data_point(x, y) + + result = checker.check_slope(20.0) + assert result is True + assert checker.margin_of_error is not None + assert checker.margin_of_error >= 0 + # For perfect data, margin of error should be very small + assert checker.margin_of_error < 0.1 + + +class TestOverSaturationConstraintRobustness: + """Test the robustness of OverSaturationConstraint under various conditions.""" + + @pytest.mark.sanity + def test_constraint_with_empty_data(self): + """Test constraint behavior with no data.""" + constraint = OverSaturationConstraint(minimum_duration=0.0, enabled=True) + + # Should not alert with no data + assert constraint._check_alert() is False + + # Should handle update_duration gracefully + constraint._update_duration(100.0) + assert constraint._check_alert() is False + + @pytest.mark.sanity + def test_constraint_with_single_request(self): + """Test constraint behavior with single request.""" + constraint = OverSaturationConstraint( + minimum_duration=0.0, minimum_window_size=1, enabled=True + ) + + constraint._add_started({"concurrent_requests": 5, "duration": 1.0}) + constraint._add_finished({"ttft": 2.0, "duration": 2.0}) + constraint._update_duration(10.0) + + # Should not alert with insufficient data + assert constraint._check_alert() is False + + @pytest.mark.sanity + def test_constraint_with_identical_values(self): + """Test constraint with identical values (zero variance).""" + constraint = OverSaturationConstraint( + minimum_duration=0.0, minimum_window_size=3, enabled=True + ) + + # Add identical values + for i in range(10): + constraint._add_started({"concurrent_requests": 5, "duration": float(i)}) + constraint._add_finished({"ttft": 1.0, "duration": float(i)}) + + 
constraint._update_duration(20.0) + result = constraint._check_alert() + + # Should not alert for flat data + assert result is False + + @pytest.mark.sanity + def test_constraint_extreme_values(self): + """Test constraint with extreme values.""" + constraint = OverSaturationConstraint( + minimum_duration=0.0, minimum_window_size=3, enabled=True + ) + + # Add extreme values + values = [0.1, 1000.0, 0.01, 5000.0, 0.001] + for i, val in enumerate(values): + constraint._add_started( + {"concurrent_requests": int(val), "duration": float(i)} + ) + constraint._add_finished({"ttft": val, "duration": float(i)}) + + constraint._update_duration(20.0) + # Should handle without crashing + result = constraint._check_alert() + assert result in [True, False] + + @pytest.mark.sanity + def test_constraint_precision_edge_cases(self): + """Test constraint with floating point precision edge cases.""" + constraint = OverSaturationConstraint( + minimum_duration=0.0, minimum_window_size=3, enabled=True + ) + + # Very small increments + base = 1e10 + for i in range(10): + constraint._add_started( + {"concurrent_requests": 5, "duration": base + i * 1e-10} + ) + constraint._add_finished({"ttft": 1.0, "duration": base + i * 1e-10}) + + constraint._update_duration(base + 100.0) + # Should handle without numerical issues + result = constraint._check_alert() + assert result in [True, False] + + @pytest.mark.sanity + def test_constraint_window_management_stress(self): + """Test constraint window management under stress.""" + constraint = OverSaturationConstraint( + minimum_duration=0.0, + maximum_window_seconds=10.0, + minimum_window_size=5, + enabled=True, + ) + + # Add many requests over time + for i in range(1000): + duration = float(i * 0.1) # 100 seconds total + constraint._add_started( + {"concurrent_requests": i % 50, "duration": duration} + ) + constraint._add_finished({"ttft": (i % 100) * 0.01, "duration": duration}) + + # Periodic window updates + if i % 100 == 0: + 
constraint._update_duration(duration + 5.0) + + # Should maintain reasonable window size + assert len(constraint.started_requests) <= 200 # Should be pruned + assert len(constraint.finished_requests) <= 200 + + +class TestOverSaturationConstraintRealisticScenarios: + """Test detector with realistic request patterns.""" + + @pytest.mark.sanity + def test_gradual_performance_degradation(self): + """Test detection of gradual performance degradation.""" + constraint = OverSaturationConstraint( + minimum_duration=5.0, + minimum_window_size=10, + moe_threshold=1.5, + enabled=True, + ) + + # Simulate gradual degradation + for i in range(50): + # Gradually increasing concurrent requests + concurrent = 10 + i * 0.5 + # Gradually increasing TTFT + ttft = 1.0 + i * 0.1 + duration = float(i) + + constraint._add_started( + {"concurrent_requests": int(concurrent), "duration": duration} + ) + constraint._add_finished({"ttft": ttft, "duration": duration}) + + constraint._update_duration(60.0) + result = constraint._check_alert() + + # Should detect the degradation + assert result is True, "Should detect gradual performance degradation" + + @pytest.mark.sanity + def test_sudden_load_spike(self): + """Test detection of sudden load spike.""" + constraint = OverSaturationConstraint( + minimum_duration=5.0, + minimum_window_size=10, + moe_threshold=1.0, + enabled=True, + ) + + # Normal operations first + for i in range(20): + constraint._add_started({"concurrent_requests": 5, "duration": float(i)}) + constraint._add_finished({"ttft": 1.0, "duration": float(i)}) + + # Sudden spike + for i in range(20, 40): + constraint._add_started({"concurrent_requests": 50, "duration": float(i)}) + constraint._add_finished({"ttft": 5.0, "duration": float(i)}) + + constraint._update_duration(50.0) + result = constraint._check_alert() + + # Should detect the spike + assert result is True, "Should detect sudden load spike" + + @pytest.mark.sanity + def test_variable_but_stable_performance(self): + 
"""Test that variable but stable performance doesn't trigger false positives.""" + constraint = OverSaturationConstraint( + minimum_duration=5.0, + minimum_window_size=10, + moe_threshold=2.0, + enabled=True, + ) + + import random + + random.seed(123) # Reproducible + + # Variable but centered around stable values + for i in range(100): + concurrent = 15 + random.randint(-5, 5) # 10-20 range + ttft = 2.0 + random.uniform(-0.5, 0.5) # 1.5-2.5 range + duration = float(i) + + constraint._add_started( + {"concurrent_requests": concurrent, "duration": duration} + ) + constraint._add_finished({"ttft": ttft, "duration": duration}) + + constraint._update_duration(120.0) + result = constraint._check_alert() + + # Should not trigger false positive + assert result is False, ( + "Should not trigger false positive for stable performance" + ) + + @pytest.mark.sanity + def test_recovery_after_degradation(self): + """Test that detector handles recovery after degradation.""" + constraint = OverSaturationConstraint( + minimum_duration=5.0, + minimum_window_size=10, + maximum_window_seconds=30.0, + enabled=True, + ) + + # Initial degradation + for i in range(20): + concurrent = 10 + i * 2 # Increasing load + ttft = 1.0 + i * 0.2 # Increasing TTFT + constraint._add_started( + {"concurrent_requests": concurrent, "duration": float(i)} + ) + constraint._add_finished({"ttft": ttft, "duration": float(i)}) + + constraint._update_duration(25.0) + degradation_result = constraint._check_alert() + + # Add recovery period - improved performance + for i in range(40, 60): + constraint._add_started({"concurrent_requests": 5, "duration": float(i)}) + constraint._add_finished({"ttft": 0.8, "duration": float(i)}) + + constraint._update_duration(65.0) + recovery_result = constraint._check_alert() + + # Should detect degradation initially, then not alert during recovery + # (depending on window management) + assert degradation_result in [True, False] # Could go either way + # After recovery with window 
management, should be less likely to alert + if len(constraint.finished_requests) < 15: # If old data was purged + assert recovery_result is False, "Should not alert after recovery" + + +class TestOverSaturationConstraintIntegration: + """Test integration between constraint and detector with complex scenarios.""" + + def create_realistic_constraint(self) -> OverSaturationConstraint: + """Create a constraint with realistic settings.""" + return OverSaturationConstraint( + minimum_duration=10.0, + minimum_window_size=5, + maximum_window_seconds=60.0, + moe_threshold=1.5, + confidence=0.90, + enabled=True, + ) + + @pytest.mark.sanity + def test_constraint_metadata_completeness(self): + """Test that constraint provides complete metadata.""" + constraint = self.create_realistic_constraint() + start_time = time.time() + + state = SchedulerState( + node_id=0, + num_processes=1, + start_time=start_time, + processing_requests=10, + ) + + request = RequestInfo( + request_id="test-request", + status="in_progress", + scheduler_node_id=0, + scheduler_process_id=0, + scheduler_start_time=start_time, + ) + + action = constraint(state, request) + + # Verify metadata completeness + required_fields = [ + "is_over_saturated", + "concurrent_slope", + "concurrent_n", + "ttft_slope", + "ttft_n", + "ttft_violations", # Correct field name + # Note: total_started_ever, total_finished_ever, + # window sizes not in metadata + ] + + for field in required_fields: + assert field in action.metadata, f"Missing metadata field: {field}" + + @pytest.mark.sanity + def test_constraint_with_realistic_request_flow(self): + """Test constraint with realistic request flow.""" + constraint = self.create_realistic_constraint() + start_time = time.time() + actions = [] + + # Simulate 60 seconds of requests + for second in range(60): + current_time = start_time + second + + state = SchedulerState( + node_id=0, + num_processes=1, + start_time=start_time, + processing_requests=10 + second, # Gradually increasing 
load + ) + + # Mix of request statuses + for req_num in range(3): # 3 requests per second + request_id = f"req-{second}-{req_num}" + + if req_num == 0: # Completed request + timings = RequestTimings( + request_start=current_time - 2.0, + first_iteration=current_time + - 2.0 + + (second * 0.05), # Gradually slower + ) + request = RequestInfo( + request_id=request_id, + status="completed", + scheduler_node_id=0, + scheduler_process_id=0, + scheduler_start_time=start_time, + timings=timings, + ) + else: # In progress request + request = RequestInfo( + request_id=request_id, + status="in_progress", + scheduler_node_id=0, + scheduler_process_id=0, + scheduler_start_time=start_time, + ) + + action = constraint(state, request) + actions.append((second, action)) + + # Analyze results + stop_actions = [a for s, a in actions if a.request_queuing == "stop"] + + # Should eventually detect over-saturation + if len(stop_actions) > 0: + first_stop_second = min( + s for s, a in actions if a.request_queuing == "stop" + ) + assert first_stop_second >= 10, "Should not stop before minimum duration" + + @pytest.mark.sanity + def test_constraint_disabled_never_stops(self): + """Test that disabled constraint never stops regardless of load.""" + constraint = OverSaturationConstraint( + minimum_duration=0.0, + minimum_window_size=3, + enabled=False, # Disabled + ) + + # Add obviously over-saturated data + for i in range(50): + constraint._add_started( + {"concurrent_requests": i * 10, "duration": float(i)} + ) + constraint._add_finished({"ttft": i * 2.0, "duration": float(i)}) + + constraint._update_duration(60.0) + + start_time = time.time() + state = SchedulerState( + node_id=0, + num_processes=1, + start_time=start_time, + processing_requests=500, # Very high load + ) + + request = RequestInfo( + request_id="test-request", + status="in_progress", + scheduler_node_id=0, + scheduler_process_id=0, + scheduler_start_time=start_time, + ) + + action = constraint(state, request) + + # Should 
continue despite over-saturation + assert action.request_queuing == "continue" + assert action.request_processing == "continue" + assert action.metadata["is_over_saturated"] in [True, False] # Could be either + + +class TestOverSaturationConstraintPerformance: + """Test performance characteristics of the constraint.""" + + @pytest.mark.sanity + def test_detector_memory_usage(self): + """Test that detector manages memory properly.""" + constraint = OverSaturationConstraint( + minimum_duration=0.0, + maximum_window_seconds=10.0, + minimum_window_size=5, + enabled=True, + ) + + # Add many requests + for i in range(10000): + duration = float(i * 0.01) # 100 seconds total + constraint._add_started({"concurrent_requests": 10, "duration": duration}) + constraint._add_finished({"ttft": 1.0, "duration": duration}) + + if i % 1000 == 0: + constraint._update_duration(duration + 5.0) + + # Memory should be bounded due to window management + assert len(constraint.started_requests) < 2000, "Started requests not bounded" + assert len(constraint.finished_requests) < 2000, "Finished requests not bounded" + + @pytest.mark.sanity + def test_constraint_computational_efficiency(self): + """Test that constraint operations remain efficient.""" + constraint = OverSaturationConstraint( + minimum_duration=0.0, minimum_window_size=10, enabled=True + ) + + # Add baseline data + for i in range(100): + constraint._add_started({"concurrent_requests": 10, "duration": float(i)}) + constraint._add_finished({"ttft": 1.0, "duration": float(i)}) + + constraint._update_duration(120.0) + + # Time multiple check_alert calls + start_time = time.time() + for _ in range(100): + constraint._check_alert() + elapsed = time.time() - start_time + + # Should complete quickly (< 1 second for 100 calls) + assert elapsed < 1.0, f"Detection too slow: {elapsed:.3f}s for 100 calls" + + +class TestOverSaturationConstraintInitializerRobustness: + """Test robustness of the constraint initializer.""" + + @pytest.mark.smoke 
+ def test_initializer_parameter_validation(self): + """Test parameter validation in initializer.""" + # Valid parameters + initializer = OverSaturationConstraintInitializer( + enabled=True, + min_seconds=5.0, + max_window_seconds=30.0, + moe_threshold=1.5, + confidence=0.95, + ) + + constraint = initializer.create_constraint() + assert constraint.enabled is True + assert constraint.minimum_duration == 5.0 + assert constraint.maximum_window_seconds == 30.0 + + @pytest.mark.smoke + def test_initializer_with_extreme_parameters(self): + """Test initializer with extreme but valid parameters.""" + # Very permissive settings - only test parameters actually supported + initializer = OverSaturationConstraintInitializer( + enabled=True, + min_seconds=0.1, + max_window_seconds=3600.0, # 1 hour + ) + + constraint = initializer.create_constraint() + + assert constraint.minimum_duration == 0.1 + assert constraint.maximum_window_seconds == 3600.0 + # Note: moe_threshold and confidence may have default values + + @pytest.mark.smoke + def test_initializer_alias_precedence(self): + """Test alias precedence in validated_kwargs.""" + # Both the parameter and its alias are provided - the alias takes precedence + result = OverSaturationConstraintInitializer.validated_kwargs( + over_saturation={"enabled": False}, # Explicit parameter + detect_saturation={"enabled": True}, # Alias + ) + + # detect_saturation should override over_saturation + assert result == {"enabled": True} + + @pytest.mark.smoke + def test_constraint_creation_with_mock_constraint(self): + """Test constraint creation with mocked constraint for isolation.""" + constraint = OverSaturationConstraint(enabled=True) + # Set up constraint state to simulate over-saturation + constraint.ttft_slope_checker.slope = 1.5 + constraint.ttft_slope_checker.margin_of_error = 0.3 + constraint.ttft_slope_checker.n = 10 + constraint.concurrent_slope_checker.slope = 2.0 + constraint.concurrent_slope_checker.margin_of_error = 0.5 +
constraint.concurrent_slope_checker.n = 15 + constraint.ttft_violations_counter = 5 + constraint.duration = 30.0 # Set duration to pass minimum check + + start_time = time.time() + state = SchedulerState( + node_id=0, + num_processes=1, + start_time=start_time, + processing_requests=10, + ) + + request = RequestInfo( + request_id="test-request", + status="in_progress", + scheduler_node_id=0, + scheduler_process_id=0, + scheduler_start_time=start_time, + ) + + action = constraint(state, request) + + # Should provide metadata about saturation state + assert "is_over_saturated" in action.metadata + + +class TestOverSaturationEdgeCasesAndRegression: + """Test edge cases and regression scenarios.""" + + @pytest.mark.sanity + def test_detector_with_malformed_request_data(self): + """Test detector requires proper request data structure.""" + constraint = OverSaturationConstraint(minimum_duration=0.0, enabled=True) + + # Missing fields should raise KeyError + with pytest.raises(KeyError): + constraint._add_started({}) # Missing required fields + + with pytest.raises(KeyError): + constraint._add_finished({}) + + with pytest.raises(KeyError): + constraint._add_started({"concurrent_requests": 5}) # Missing duration + + with pytest.raises(KeyError): + constraint._add_finished({"ttft": 1.0}) # Missing duration + + # Valid data should work + constraint._add_started({"concurrent_requests": 5, "duration": 1.0}) + constraint._add_finished({"ttft": 1.0, "duration": 1.0}) + + constraint._update_duration(10.0) + result = constraint._check_alert() + assert result in [True, False] + + @pytest.mark.sanity + def test_constraint_with_missing_timings_data(self): + """Test constraint handles missing timings data gracefully.""" + constraint = OverSaturationConstraint( + minimum_duration=0.0, + enabled=True, + ) + + start_time = time.time() + state = SchedulerState( + node_id=0, + num_processes=1, + start_time=start_time, + processing_requests=5, + ) + + # Create request without timings 
(in_progress status) + request = RequestInfo( + request_id="test-request", + status="in_progress", # No timings expected for in_progress + scheduler_node_id=0, + scheduler_process_id=0, + scheduler_start_time=start_time, + ) + + # Should not crash + action = constraint(state, request) + assert isinstance(action, SchedulerUpdateAction) + + @pytest.mark.sanity + def test_detector_concurrent_modification_safety(self): + """Test detector behavior under concurrent-like modifications.""" + constraint = OverSaturationConstraint( + minimum_duration=0.0, minimum_window_size=3, enabled=True + ) + + # Add requests + requests = [] + for i in range(20): + req = {"concurrent_requests": i, "duration": float(i)} + constraint._add_started(req) + requests.append(req) + + # Remove some while iterating (simulating concurrent access pattern) + for i in range(0, 10, 2): # Remove every other early request + constraint._remove_started(requests[i]) + + # Should still function + constraint._update_duration(25.0) + result = constraint._check_alert() + assert result in [True, False] + + @pytest.mark.sanity + def test_slope_checker_numerical_stability(self): + """Test SlopeChecker numerical stability with challenging data.""" + checker = SlopeChecker(moe_threshold=0.1, confidence=0.95) + + # Add data that could cause numerical instability + base = 1e15 # Very large numbers + for i in range(10): + x = base + i + y = base + i * 1e-10 # Very small slope relative to magnitude + checker.add_data_point(x, y) + + # Should handle without overflow/underflow + result = checker.check_slope(10.0) + assert result in [True, False] + + if checker.slope is not None: + assert not math.isnan(checker.slope) + assert not math.isinf(checker.slope) + + @pytest.mark.sanity + def test_detector_reset_clears_all_state(self): + """Test that detector reset completely clears state.""" + constraint = OverSaturationConstraint(minimum_duration=0.0, enabled=True) + + # Add data and trigger computation + for i in range(20): + 
constraint._add_started({"concurrent_requests": i, "duration": float(i)}) + constraint._add_finished({"ttft": i * 0.1, "duration": float(i)}) + + constraint._update_duration(25.0) + constraint._check_alert() # Populate computed values + + # Verify state exists + assert len(constraint.started_requests) > 0 + assert len(constraint.finished_requests) > 0 + assert constraint.total_started_ever > 0 + assert constraint.total_finished_ever > 0 + + # Reset + constraint.reset() + + # Verify complete reset + assert len(constraint.started_requests) == 0 + assert len(constraint.finished_requests) == 0 + assert constraint.total_started_ever == 0 + assert constraint.total_finished_ever == 0 + assert constraint.ttft_violations_counter == 0 + assert constraint.duration == 0.0 + + # Slope checkers should be reset too + assert constraint.concurrent_slope_checker.n == 0 + assert constraint.ttft_slope_checker.n == 0 + + @pytest.mark.sanity + @patch("time.time") + def test_constraint_time_calculation_accuracy(self, mock_time): + """Test that constraint calculates durations accurately.""" + # Mock time to control duration calculation + start_time = 1000.0 + current_time = 1030.0 # 30 seconds later + mock_time.return_value = current_time + + constraint = OverSaturationConstraint( + minimum_duration=25.0, enabled=True + ) # Should be met + + state = SchedulerState( + node_id=0, + num_processes=1, + start_time=start_time, + processing_requests=5, + ) + + request = RequestInfo( + request_id="test-request", + status="in_progress", + scheduler_node_id=0, + scheduler_process_id=0, + scheduler_start_time=start_time, + ) + + # Call constraint - should update detector duration + constraint(state, request) + + # Verify duration was calculated correctly + assert abs(constraint.duration - 30.0) < 0.001, ( + f"Expected duration ~30.0, got {constraint.duration}" + ) + + @pytest.mark.sanity + def test_ttft_violation_counting_accuracy(self): + """Test TTFT violation counting is accurate.""" + constraint 
= OverSaturationConstraint( + minimum_duration=0.0, + minimum_ttft=2.0, # Threshold + enabled=True, + ) + + # Add requests with known TTFT values + ttft_values = [1.0, 3.0, 1.5, 4.0, 2.1, 0.5, 5.0, 1.9] + expected_violations = sum( + 1 for ttft in ttft_values if ttft > 2.0 + ) # Should be 4 + + for i, ttft in enumerate(ttft_values): + constraint._add_finished({"ttft": ttft, "duration": float(i)}) + + assert constraint.ttft_violations_counter == expected_violations, ( + f"Expected {expected_violations} violations, " + f"got {constraint.ttft_violations_counter}" + )