Skip to content

Commit ccab11e

Browse files
fix: review suggestions
Signed-off-by: Alon Kellner <[email protected]>
1 parent 4c944de commit ccab11e

File tree

11 files changed

+156
-153
lines changed

11 files changed

+156
-153
lines changed

src/guidellm/__main__.py

Lines changed: 11 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,6 @@
2525

2626
import asyncio
2727
import codecs
28-
import json
2928
from pathlib import Path
3029

3130
import click
@@ -388,40 +387,26 @@ def benchmark():
388387
@click.option(
389388
"--over-saturation",
390389
"--detect-saturation", # alias
390+
"over_saturation",
391+
callback=cli_tools.parse_json,
391392
default=None,
392393
help=(
393394
"Enable over-saturation detection. "
394-
"Use --over-saturation=True for boolean flag, "
395-
"or a JSON dict with configuration "
395+
"Pass a JSON dict with configuration "
396396
'(e.g., \'{"enabled": true, "min_seconds": 30}\'). '
397397
"Defaults to None (disabled)."
398398
),
399-
type=click.UNPROCESSED,
399+
)
400+
@click.option(
401+
"--default-over-saturation",
402+
"over_saturation",
403+
flag_value={"enabled": True},
404+
help="Enable over-saturation detection with default settings.",
400405
)
401406
def run(**kwargs): # noqa: C901
402407
# Only set CLI args that differ from click defaults
403408
kwargs = cli_tools.set_if_not_default(click.get_current_context(), **kwargs)
404409

405-
# Handle over_saturation parsing (can be bool flag or JSON dict string)
406-
if "over_saturation" in kwargs and kwargs["over_saturation"] is not None:
407-
over_sat = kwargs["over_saturation"]
408-
if isinstance(over_sat, str):
409-
try:
410-
# Try parsing as JSON dict
411-
kwargs["over_saturation"] = json.loads(over_sat)
412-
except (json.JSONDecodeError, ValueError):
413-
# If not valid JSON, treat as bool flag
414-
kwargs["over_saturation"] = over_sat.lower() in (
415-
"true",
416-
"1",
417-
"yes",
418-
"on",
419-
)
420-
elif isinstance(over_sat, bool):
421-
# Already a bool, keep as is
422-
pass
423-
# If it's already a dict, keep as is
424-
425410
# Handle remapping for request params
426411
request_type = kwargs.pop("request_type", None)
427412
request_formatter_kwargs = kwargs.pop("request_formatter_kwargs", None)
@@ -557,8 +542,8 @@ def preprocess():
557542
"PreprocessDatasetConfig as JSON string, key=value pairs, "
558543
"or file path (.json, .yaml, .yml, .config). "
559544
"Example: 'prompt_tokens=100,output_tokens=50,prefix_tokens_max=10'"
560-
" or '{\"prompt_tokens\": 100, \"output_tokens\": 50, "
561-
"\"prefix_tokens_max\": 10}'"
545+
' or \'{"prompt_tokens": 100, "output_tokens": 50, '
546+
'"prefix_tokens_max": 10}\''
562547
),
563548
)
564549
@click.option(

src/guidellm/benchmark/entrypoints.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -323,7 +323,7 @@ async def resolve_profile(
323323
max_errors: int | None,
324324
max_error_rate: float | None,
325325
max_global_error_rate: float | None,
326-
over_saturation: bool | dict[str, Any] | None = None,
326+
over_saturation: dict[str, Any] | None = None,
327327
console: Console | None = None,
328328
) -> Profile:
329329
"""
@@ -344,7 +344,7 @@ async def resolve_profile(
344344
:param max_errors: Maximum number of errors before stopping
345345
:param max_error_rate: Maximum error rate threshold before stopping
346346
:param max_global_error_rate: Maximum global error rate threshold before stopping
347-
:param over_saturation: Over-saturation detection configuration (bool or dict)
347+
:param over_saturation: Over-saturation detection configuration (dict)
348348
:param console: Console instance for progress reporting, or None
349349
:return: Configured Profile instance ready for benchmarking
350350
:raises ValueError: If constraints are provided with a pre-configured Profile

src/guidellm/benchmark/schemas/generative/entrypoints.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -283,12 +283,12 @@ def get_default(cls: type[BenchmarkGenerativeTextArgs], field: str) -> Any:
283283
max_global_error_rate: float | None = Field(
284284
default=None, description="Maximum global error rate (0-1) before stopping"
285285
)
286-
over_saturation: bool | dict[str, Any] | None = Field(
286+
over_saturation: dict[str, Any] | None = Field(
287287
default=None,
288288
description=(
289-
"Over-saturation detection configuration. Can be a bool to enable/disable "
290-
"with defaults, or a dict with configuration parameters (enabled, "
291-
"min_seconds, max_window_seconds, moe_threshold, etc.)."
289+
"Over-saturation detection configuration. A dict with configuration "
290+
"parameters (enabled, min_seconds, max_window_seconds, "
291+
"moe_threshold, etc.)."
292292
),
293293
)
294294

src/guidellm/scheduler/constraints/factory.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,14 +10,13 @@
1010

1111
from typing import Any
1212

13-
from guidellm.utils import InfoMixin, RegistryMixin
14-
15-
from .constraint import (
13+
from guidellm.scheduler.constraints.constraint import (
1614
Constraint,
1715
ConstraintInitializer,
1816
SerializableConstraintInitializer,
1917
UnserializableConstraintInitializer,
2018
)
19+
from guidellm.utils import InfoMixin, RegistryMixin
2120

2221
__all__ = ["ConstraintsInitializerFactory"]
2322

src/guidellm/scheduler/constraints/saturation.py

Lines changed: 59 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -57,15 +57,16 @@
5757

5858
from pydantic import Field
5959

60+
from guidellm.scheduler.constraints.constraint import (
61+
Constraint,
62+
PydanticConstraintInitializer,
63+
)
64+
from guidellm.scheduler.constraints.factory import ConstraintsInitializerFactory
6065
from guidellm.scheduler.schemas import (
6166
SchedulerState,
6267
SchedulerUpdateAction,
6368
)
6469
from guidellm.schemas import RequestInfo
65-
from guidellm.settings import settings
66-
67-
from .constraint import Constraint, PydanticConstraintInitializer
68-
from .factory import ConstraintsInitializerFactory
6970

7071
__all__ = [
7172
"OverSaturationConstraint",
@@ -355,7 +356,12 @@ def reset(self) -> None:
355356
)
356357

357358
def _add_finished(self, request: dict[str, Any]) -> None:
358-
"""Add a finished request to tracking."""
359+
"""
360+
Add a finished request to tracking.
361+
362+
:param request: Dictionary containing request data with 'ttft' and
363+
'duration' keys.
364+
"""
359365
ttft = request["ttft"]
360366
duration = request["duration"]
361367
if ttft is not None:
@@ -366,7 +372,12 @@ def _add_finished(self, request: dict[str, Any]) -> None:
366372
self.ttft_slope_checker.add_data_point(duration, ttft)
367373

368374
def _remove_finished(self, request: dict[str, Any]) -> None:
369-
"""Remove a finished request from tracking."""
375+
"""
376+
Remove a finished request from tracking.
377+
378+
:param request: Dictionary containing request data with 'ttft' and
379+
'duration' keys.
380+
"""
370381
del self.finished_requests[0]
371382
ttft = request["ttft"]
372383
duration = request["duration"]
@@ -375,7 +386,12 @@ def _remove_finished(self, request: dict[str, Any]) -> None:
375386
self.ttft_slope_checker.remove_data_point(duration, ttft)
376387

377388
def _add_started(self, request: dict[str, Any]) -> None:
378-
"""Add a started request to tracking."""
389+
"""
390+
Add a started request to tracking.
391+
392+
:param request: Dictionary containing request data with
393+
'concurrent_requests' and 'duration' keys.
394+
"""
379395
concurrent = request["concurrent_requests"]
380396
duration = request["duration"]
381397
if concurrent is not None:
@@ -384,14 +400,26 @@ def _add_started(self, request: dict[str, Any]) -> None:
384400
self.concurrent_slope_checker.add_data_point(duration, concurrent)
385401

386402
def _remove_started(self, request: dict[str, Any]) -> None:
387-
"""Remove a started request from tracking."""
403+
"""
404+
Remove a started request from tracking.
405+
406+
:param request: Dictionary containing request data with
407+
'concurrent_requests' and 'duration' keys.
408+
"""
388409
del self.started_requests[0]
389410
concurrent = request["concurrent_requests"]
390411
duration = request["duration"]
391412
self.concurrent_slope_checker.remove_data_point(duration, concurrent)
392413

393414
def _update_duration(self, duration: float) -> None:
394-
"""Update duration and prune old data points."""
415+
"""
416+
Update duration and prune old data points.
417+
418+
Updates the current duration and removes data points that exceed the maximum
419+
window size (by ratio or time) to maintain bounded memory usage.
420+
421+
:param duration: Current duration in seconds since benchmark start.
422+
"""
395423
self.duration = duration
396424

397425
maximum_finished_window_size = int(
@@ -428,8 +456,7 @@ def _check_alert(self) -> bool:
428456
"""
429457
Check if over-saturation is currently detected.
430458
431-
Returns:
432-
True if over-saturation is detected, False otherwise.
459+
:return: True if over-saturation is detected, False otherwise.
433460
"""
434461
# Use duration as the maximum n value since requests from the
435462
# same second are highly correlated, this is simple and good enough
@@ -521,13 +548,13 @@ class OverSaturationConstraintInitializer(PydanticConstraintInitializer):
521548
Factory for creating OverSaturationConstraint instances from configuration.
522549
523550
Provides a Pydantic-based initializer for over-saturation detection constraints
524-
with support for flexible configuration patterns. Supports both simple boolean
525-
flags and detailed configuration dictionaries, enabling easy integration with
526-
CLI arguments, configuration files, and programmatic constraint creation.
551+
with support for flexible configuration patterns. Supports detailed configuration
552+
dictionaries, enabling easy integration with CLI arguments, configuration files,
553+
and programmatic constraint creation.
527554
528555
Example:
529556
::
530-
# Simple boolean configuration
557+
# Configuration with defaults
531558
initializer = OverSaturationConstraintInitializer(enabled=True)
532559
constraint = initializer.create_constraint()
533560
@@ -618,18 +645,18 @@ def create_constraint(self, **_kwargs) -> Constraint:
618645

619646
@classmethod
620647
def validated_kwargs(
621-
cls, over_saturation: bool | dict[str, Any] | None = None, **kwargs
648+
cls, over_saturation: dict[str, Any] | None = None, **kwargs
622649
) -> dict[str, Any]:
623650
"""
624651
Validate and process arguments for OverSaturationConstraint creation.
625652
626-
Processes flexible input formats to create validated constraint configuration.
627-
Supports boolean flags for simple enable/disable, dictionary inputs for detailed
628-
configuration, and alias parameters for compatibility. Handles parameter
629-
normalization and default value application.
653+
Processes flexible input formats to create validated constraint
654+
configuration. Supports dictionary inputs for detailed configuration, and
655+
alias parameters for compatibility. Handles parameter normalization and
656+
default value application.
630657
631-
:param over_saturation: Boolean to enable/disable with defaults, or dictionary
632-
with configuration parameters (min_seconds, max_window_seconds, etc.)
658+
:param over_saturation: Dictionary with configuration parameters
659+
(min_seconds, max_window_seconds, etc.)
633660
:param kwargs: Additional keyword arguments supporting aliases like
634661
"detect_saturation" for compatibility, or unpacked dict values when
635662
dict is passed to factory
@@ -638,7 +665,7 @@ def validated_kwargs(
638665
"""
639666
# Check for aliases in kwargs
640667
aliases = ["over_saturation", "detect_saturation"]
641-
result: bool | dict[str, Any] | None = over_saturation
668+
result: dict[str, Any] | None = over_saturation
642669

643670
for alias in aliases:
644671
alias_value = kwargs.get(alias)
@@ -664,37 +691,13 @@ def validated_kwargs(
664691
result = {key: kwargs[key] for key in constraint_keys if key in kwargs}
665692

666693
if result is None:
667-
return {}
668-
669-
if isinstance(result, bool):
670-
# When a boolean is passed, read defaults from settings
671-
return {
672-
"enabled": result,
673-
"min_seconds": kwargs.get(
674-
"min_seconds", settings.constraint_over_saturation_min_seconds
675-
),
676-
"max_window_seconds": kwargs.get(
677-
"max_window_seconds",
678-
settings.constraint_over_saturation_max_window_seconds,
679-
),
680-
}
681-
elif isinstance(result, dict):
682-
# Extract configuration from dict, reading from settings for missing values
683-
return {
684-
"enabled": result.get("enabled", True),
685-
"min_seconds": result.get(
686-
"min_seconds", settings.constraint_over_saturation_min_seconds
687-
),
688-
"max_window_seconds": result.get(
689-
"max_window_seconds",
690-
settings.constraint_over_saturation_max_window_seconds,
691-
),
692-
"moe_threshold": result.get("moe_threshold", 2.0),
693-
"minimum_ttft": result.get("minimum_ttft", 2.5),
694-
"maximum_window_ratio": result.get("maximum_window_ratio", 0.75),
695-
"minimum_window_size": result.get("minimum_window_size", 5),
696-
"confidence": result.get("confidence", 0.95),
697-
}
694+
return {"enabled": False}
695+
696+
if isinstance(result, dict):
697+
# Return dict as-is, defaults come from fields above
698+
return result
698699
else:
699-
# Convert to bool if it's truthy
700-
return {"enabled": bool(result)}
700+
# Type signature only accepts dict or None, so this should never happen
701+
raise TypeError(
702+
f"over_saturation must be a dict or None, got {type(result).__name__}"
703+
)

src/guidellm/utils/cli.py

Lines changed: 25 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -65,12 +65,36 @@ def parse_list_floats(ctx, param, value):
6565
) from err
6666

6767

68-
def parse_json(ctx, param, value): # noqa: ARG001
68+
def parse_json(ctx, param, value): # noqa: ARG001, C901, PLR0911, PLR0912
6969
if value is None or value == [None]:
7070
return None
71+
if isinstance(value, dict | list):
72+
# Already parsed (e.g., from flag_value), return as-is
73+
return value
7174
if isinstance(value, list | tuple):
7275
return [parse_json(ctx, param, val) for val in value]
7376

77+
# Handle empty strings (can occur when multiple options map to same parameter)
78+
if isinstance(value, str) and not value.strip():
79+
return None
80+
81+
# Handle string representation of dict (can occur when flag_value dict is
82+
# converted to string)
83+
if isinstance(value, str) and value.startswith("{") and value.endswith("}"):
84+
# Try to parse as JSON first
85+
try:
86+
return json.loads(value)
87+
except json.JSONDecodeError:
88+
# If JSON parsing fails, try ast.literal_eval for Python dict syntax
89+
try:
90+
import ast
91+
92+
parsed = ast.literal_eval(value)
93+
if isinstance(parsed, dict):
94+
return parsed
95+
except (ValueError, SyntaxError):
96+
pass # Fall through to normal processing
97+
7498
if "{" not in value and "}" not in value and "=" in value:
7599
# Treat it as a key=value pair if it doesn't look like JSON.
76100
result = {}

tests/e2e/test_over_saturated_benchmark.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ def server():
3636
@pytest.mark.timeout(60)
3737
def test_over_saturated_benchmark(server: VllmSimServer):
3838
"""
39-
Another example test interacting with the server.
39+
Test over-saturation detection using the --default-over-saturation flag.
4040
"""
4141
report_path = Path("tests/e2e/over_saturated_benchmarks.json")
4242
rate = 10
@@ -45,14 +45,15 @@ def test_over_saturated_benchmark(server: VllmSimServer):
4545
client = GuidellmClient(target=server.get_url(), output_path=report_path)
4646

4747
cleanup_report_file(report_path)
48-
# Start the benchmark
48+
# Start the benchmark with --default-over-saturation flag
4949
client.start_benchmark(
5050
rate=rate,
5151
max_seconds=20,
52-
over_saturation=True,
52+
over_saturation={}, # Empty dict triggers --default-over-saturation flag
5353
extra_env={
54-
"GUIDELLM__CONSTRAINT_OVER_SATURATION_MIN_SECONDS": "0",
5554
"GOMAXPROCS": "1",
55+
# Set min_seconds via env var for faster test
56+
"GUIDELLM__CONSTRAINT_OVER_SATURATION_MIN_SECONDS": "0",
5657
},
5758
)
5859

0 commit comments

Comments
 (0)