
Commit aa9b2f6

Fixes linter errors.
Signed-off-by: rlratzel <[email protected]>
1 parent 13eaca6 · commit aa9b2f6

16 files changed (+189, -191 lines)


benchmarking/Dockerfile

Lines changed: 4 additions & 3 deletions
@@ -48,7 +48,7 @@ ENV MAMBA_ROOT_PREFIX=/opt/micromamba
 ENV PATH=$MAMBA_ROOT_PREFIX/bin:$PATH
 RUN curl -Ls https://micromamba.snakepit.net/api/micromamba/linux-64/latest | tar -xvj -C /usr/local/bin --strip-components=1 bin/micromamba && \
     micromamba shell init -s bash -r $MAMBA_ROOT_PREFIX
-
+
 # Install uv
 ENV UV_VERSION="0.8.22"
 RUN curl -LsSf https://astral.sh/uv/${UV_VERSION}/install.sh | sh
@@ -93,7 +93,7 @@ RUN git clone https://github.com/OpenGVLab/InternVideo.git && \
 
 
 ########################################################################
-# curator_benchmark image -
+# curator_benchmark image -
 #
 # use cases:
 # * Start a container standalone to run all Curator benchmarks. Datasets are downloaded automatically and reside only in the container.
@@ -118,6 +118,7 @@ FROM curator_system_base AS curator_benchmarking
 COPY --from=curator_setup_deps /opt /opt
 
 # Install Curator, which includes benchmarking tools
+# Update pyproject.toml to get the latest RAPIDS libs
 COPY . /opt/Curator
 RUN cd /opt/Curator \
     && uv sync --link-mode copy --locked --extra all --all-groups \
@@ -133,7 +134,7 @@ ARG NVIDIA_BUILD_REF
 LABEL com.nvidia.build.ref="${NVIDIA_BUILD_REF}"
 
 # Install deps for specific benchmark scripts.
-# FIXME: look into a way that script authors can install their own deps so this does not need to be updated for each new script dep.
+# TODO: look into a way that script authors can install their own deps so this does not need to be updated for each new script dep.
 RUN apt-get install -y --no-install-recommends \
     wget
 

benchmarking/README.md

Lines changed: 15 additions & 16 deletions
@@ -322,9 +322,9 @@ def main():
     # Add your custom arguments
     parser.add_argument("--input", type=str)
     parser.add_argument("--iterations", type=int, default=100)
-
+
     args = parser.parse_args()
-
+
     # Your benchmark logic here
     run_benchmark(args)
 
@@ -407,28 +407,28 @@ from nemo_curator.tasks.utils import TaskPerfUtils
 def run_benchmark(args):
     """Main benchmark logic."""
     start_time = time.time()
-
+
     # Your benchmark code here
     with Task("my_operation", TaskPerfUtils()):
         result = perform_operation(args.input)
-
+
     execution_time = time.time() - start_time
-
+
     # Write required output files
     params = {
         "input": str(args.input),
         "parameter1": args.param1,
     }
     with open(args.benchmark_results_path / "params.json", "w") as f:
         json.dump(params, f, indent=2)
-
+
     metrics = {
         "execution_time_s": execution_time,
         "items_processed": len(result),
     }
     with open(args.benchmark_results_path / "metrics.json", "w") as f:
         json.dump(metrics, f, indent=2)
-
+
     tasks = Task.get_all_tasks()
     with open(args.benchmark_results_path / "tasks.pkl", "wb") as f:
         pickle.dump(tasks, f)
@@ -439,7 +439,7 @@ def main():
     parser.add_argument("--benchmark-results-path", type=Path, required=True)
     parser.add_argument("--input", type=str, required=True)
     parser.add_argument("--param1", type=str, default="default")
-
+
     args = parser.parse_args()
     run_benchmark(args)
 
@@ -521,33 +521,33 @@ class MyCustomSink(Sink):
         self.config = config
         self.enabled = config.get("enabled", True)
         self.api_endpoint = config.get("api_endpoint")
-
+
         # Initialize any resources
         if not self.api_endpoint:
             raise ValueError("MyCustomSink: api_endpoint is required")
-
+
     def initialize(self, session_name: str, env_data: dict[str, Any]) -> None:
         """Called at session start."""
         self.session_name = session_name
         self.env_data = env_data
-
+
         if self.enabled:
             logger.info(f"MyCustomSink: Starting session {session_name}")
             # Perform initialization (e.g., create remote session)
-
+
     def process_result(self, result: dict[str, Any]) -> None:
         """Called after each entry completes."""
         if self.enabled:
             logger.info(f"MyCustomSink: Processing {result['name']}")
             # Send result to your API, database, etc.
             self._send_to_api(result)
-
+
     def finalize(self) -> None:
         """Called at session end."""
         if self.enabled:
             logger.info("MyCustomSink: Finalizing session")
             # Perform cleanup, send summary, etc.
-
+
     def _send_to_api(self, data: dict) -> None:
         """Helper method for API calls."""
         # Your implementation
@@ -895,7 +895,7 @@ entries:
   - name: benchmark_v1
     script: my_benchmark.py
     args: --input {dataset:sample_data,parquet} --algorithm v1
-
+
   - name: benchmark_v2
     script: my_benchmark.py
     args: --input {dataset:sample_data,parquet} --algorithm v2
@@ -1009,4 +1009,3 @@ benchmarking/
 Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
 
 Licensed under the Apache License, Version 2.0. See the main repository LICENSE file for details.
-

benchmarking/__init__.py

Lines changed: 0 additions & 1 deletion
@@ -1 +0,0 @@
-

benchmarking/commands.sh

Lines changed: 1 addition & 1 deletion
@@ -27,4 +27,4 @@ python ./scripts/common_crawl_benchmark.py \
     --url_limit 10 \
     --add_filename_column \
     --executor ray_data \
-    --ray_data_cast_as_actor
+    --ray_data_cast_as_actor

benchmarking/run.py

Lines changed: 49 additions & 67 deletions
@@ -13,66 +13,40 @@
 # limitations under the License.
 
 import argparse
-from pathlib import Path
-import yaml
-import os
-import time
-import sys
-from typing import Any
 import json
-import traceback
-from statistics import mean, stdev
+import os
 import pickle
 import shutil
+import sys
+import time
 import traceback
+from pathlib import Path
+from typing import Any
 
+import yaml
 from loguru import logger
 
-from nemo_curator.tasks import Task
 from nemo_curator.tasks.utils import TaskPerfUtils
+from nemo_curator.utils.file_utils import create_or_overwrite_dir
 
-# FIXME: How do we want to package this tool? Perhaps a package extra for
+# TODO: How do we want to package this tool? Perhaps a package extra for
 # nemo-curator, i.e. nemo-curator[benchmarking]?
 # For now, add this directory to PYTHONPATH to import the runner modules
 sys.path.insert(0, Path(__file__).parent)
-from runner.matrix import MatrixConfig, MatrixEntry
 from runner.datasets import DatasetResolver
-from runner.utils import get_obj_for_json
-from runner.process import run_command_with_timeout
 from runner.env_capture import dump_env
+from runner.matrix import MatrixConfig, MatrixEntry
+from runner.process import run_command_with_timeout
+from runner.utils import get_obj_for_json
 
 
 def ensure_dir(dir_path: Path) -> None:
     """Ensure dir_path and parents exists, creating them if necessary."""
     dir_path.mkdir(parents=True, exist_ok=True)
 
 
-def create_or_overwrite_dir(dir_path: Path) -> None:
-    """Create directory, removing it if it exists."""
-    if dir_path.exists():
-        shutil.rmtree(dir_path, ignore_errors=True)
-    dir_path.mkdir(parents=True, exist_ok=True)
-
-
-def aggregate_task_metrics(tasks: list[Task], prefix: str | None = None) -> dict[str, Any]:
-    """Aggregate task metrics by computing mean/std/sum."""
-    metrics = {}
-    tasks_metrics = TaskPerfUtils.collect_stage_metrics(tasks)
-    # For each of the metric compute mean/std/sum and flatten the dict
-    for stage_name, stage_data in tasks_metrics.items():
-        for metric_name, values in stage_data.items():
-            for agg_name, agg_func in [("sum", sum), ("mean", mean), ("std", stdev)]:
-                stage_key = stage_name if prefix is None else f"{prefix}_{stage_name}"
-                if len(values) > 0:
-                    metrics[f"{stage_key}_{metric_name}_{agg_name}"] = float(agg_func(values))
-                else:
-                    metrics[f"{stage_key}_{metric_name}_{agg_name}"] = 0.0
-    return metrics
-
-
 def get_entry_script_persisted_data(benchmark_results_path: Path) -> dict[str, Any]:
-    """ Read the files that are expected to be generated by the individual benchmark scripts.
-    """
+    """Read the files that are expected to be generated by the individual benchmark scripts."""
     params_json = benchmark_results_path / "params.json"
     if not params_json.exists():
         logger.warning(f"Params JSON file not found at {params_json}")
@@ -97,22 +71,23 @@ def get_entry_script_persisted_data(benchmark_results_path: Path) -> dict[str, A
     with open(tasks_pkl, "rb") as f:
         script_tasks = pickle.load(f)  # noqa: S301
     if isinstance(script_tasks, list):
-        script_metrics.update(aggregate_task_metrics(script_tasks, prefix="task"))
+        script_metrics.update(TaskPerfUtils.aggregate_task_metrics(script_tasks, prefix="task"))
     elif isinstance(script_tasks, dict):
         for pipeline_name, pipeline_tasks in script_tasks.items():
-            script_metrics.update(aggregate_task_metrics(pipeline_tasks, prefix=pipeline_name.lower()))
+            script_metrics.update(
+                TaskPerfUtils.aggregate_task_metrics(pipeline_tasks, prefix=pipeline_name.lower())
+            )
 
     return {"params": script_params, "metrics": script_metrics}
 
 
-def run_entry(  # noqa: PLR0915
+def run_entry(
     entry: MatrixEntry,
     dataset_resolver: DatasetResolver,
     session_path: Path,
     result: dict[str, Any],
 ) -> tuple[dict[str, Any], bool, dict[str, Any]]:
-
-    started_at = time.time()
+    started_at = time.time()
     session_entry_path = session_path / entry.name
 
     # scratch_path : This is the directory user can use to store scratch data; it'll be cleaned up after the entry is done
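
The two call sites above now delegate to TaskPerfUtils.aggregate_task_metrics, replacing the helper deleted in the first hunk of this file. For reference, a runnable copy of that deleted aggregation logic (the collect_stage_metrics call and the imports come from the removed code; the single-sample stdev guard is an addition here, and the library's own implementation is assumed, not verified, to behave the same way):

from statistics import mean, stdev
from typing import Any

from nemo_curator.tasks import Task
from nemo_curator.tasks.utils import TaskPerfUtils


def aggregate_task_metrics(tasks: list[Task], prefix: str | None = None) -> dict[str, Any]:
    """Flatten per-stage metric lists into {stage}_{metric}_{sum|mean|std} floats."""
    metrics: dict[str, Any] = {}
    tasks_metrics = TaskPerfUtils.collect_stage_metrics(tasks)
    for stage_name, stage_data in tasks_metrics.items():
        stage_key = stage_name if prefix is None else f"{prefix}_{stage_name}"
        for metric_name, values in stage_data.items():
            for agg_name, agg_func in [("sum", sum), ("mean", mean), ("std", stdev)]:
                key = f"{stage_key}_{metric_name}_{agg_name}"
                if agg_name == "std" and len(values) < 2:
                    # stdev() needs at least two samples; the deleted helper only guarded against empty lists.
                    metrics[key] = 0.0
                else:
                    metrics[key] = float(agg_func(values)) if values else 0.0
    return metrics
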
@@ -155,23 +130,27 @@ def run_entry( # noqa: PLR0915
         logger.warning(f"\t\t⏰ Timed out after {entry.timeout_s}s")
     logger.info(f"\t\tLogs found in {logs_path}")
 
-    result.update({
-        "cmd": cmd,
-        "started_at": started_at,
-        "ended_at": time.time(),
-        "exec_started_at": started_exec,
-        "exec_time_s": ended_exec - started_exec,
-        "exit_code": completed["returncode"],
-        "timed_out": completed["timed_out"],
-        "logs_dir": logs_path,
-        "success": success,
-    })
+    result.update(
+        {
+            "cmd": cmd,
+            "started_at": started_at,
+            "ended_at": time.time(),
+            "exec_started_at": started_exec,
+            "exec_time_s": ended_exec - started_exec,
+            "exit_code": completed["returncode"],
+            "timed_out": completed["timed_out"],
+            "logs_dir": logs_path,
+            "success": success,
+        }
+    )
     ray_data = {}
     script_persisted_data = get_entry_script_persisted_data(benchmark_results_path)
-    result.update({
-        "ray_data": ray_data,
-        "script_persisted_data": script_persisted_data,
-    })
+    result.update(
+        {
+            "ray_data": ray_data,
+            "script_persisted_data": script_persisted_data,
+        }
+    )
     Path(session_entry_path / "results.json").write_text(json.dumps(get_obj_for_json(result)))
 
     return success
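
The next hunk switches main()'s config reading to a context-managed open; the surrounding merge loop still folds every YAML document from every --config file into one flat dict, with later keys overwriting earlier ones. A minimal, self-contained illustration of that merge behavior (inline YAML text instead of files; the keys are borrowed from examples elsewhere in this change):

import yaml

# Two YAML documents in one stream, like the matrix config files run.py consumes.
config_text = """\
datasets:
  - name: sample_data
    formats:
      - type: parquet
        path: /data/sample_data/parquet
---
entries:
  - name: benchmark_v1
    script: my_benchmark.py
"""

config_dict: dict = {}
for doc in yaml.full_load_all(config_text):
    # Later documents overwrite earlier keys, matching the loop in main().
    config_dict.update(doc)

print(sorted(config_dict))  # -> ['datasets', 'entries']
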
@@ -200,10 +179,11 @@ def main() -> None:
     # and use by passing individual components the keys they need
     config_dict = {}
     for yml_file in args.config:
-        config_dicts = yaml.full_load_all(open(yml_file))
+        with open(yml_file) as f:
+            config_dicts = yaml.full_load_all(f)
         for d in config_dicts:
             config_dict.update(d)
-
+
     config = MatrixConfig.create_from_dict(config_dict)
     resolver = DatasetResolver.create_from_dicts(config_dict["datasets"])
 
@@ -216,7 +196,7 @@
     session_overall_success = True
     logger.info(f"Started session {session_name}...")
     env_data = dump_env(session_path)
-
+
     for sink in config.sinks:
         sink.initialize(session_name, env_data)
 
@@ -242,12 +222,14 @@
             error_traceback = traceback.format_exc()
             logger.error(f"\t\t❌ Entry failed with exception: {e}")
             logger.debug(f"Full traceback:\n{error_traceback}")
-            result.update({
-                "error": str(e),
-                "traceback": error_traceback,
-                "success": run_success,
-            })
-
+            result.update(
+                {
+                    "error": str(e),
+                    "traceback": error_traceback,
+                    "success": run_success,
+                }
+            )
+
         finally:
             session_overall_success &= run_success
             for sink in config.sinks:

benchmarking/runner/__init__.py

Lines changed: 0 additions & 1 deletion
@@ -1 +0,0 @@
-

benchmarking/runner/datasets.py

Lines changed: 6 additions & 4 deletions
@@ -36,20 +36,22 @@ def create_from_dicts(cls, data: list[dict]) -> DatasetResolver:
         # Check for duplicate dataset names before proceeding
         names = [d["name"] for d in data]
         if len(names) != len(set(names)):
-            duplicates = set([name for name in names if names.count(name) > 1])
-            raise ValueError(f"Duplicate dataset name(s) found: {', '.join(duplicates)}")
+            duplicates = {name for name in names if names.count(name) > 1}
+            msg = f"Duplicate dataset name(s) found: {', '.join(duplicates)}"
+            raise ValueError(msg)
 
         instance = cls()
         for dataset in data:
             formats = dataset["formats"]
-            assert isinstance(formats, list), "formats must be a list"
+            if not isinstance(formats, list):
+                msg = "formats must be a list"
+                raise TypeError(msg)
             format_map = {}
             for fmt in formats:
                 format_map[fmt["type"]] = fmt["path"]
             instance._map[dataset["name"]] = format_map
         return instance
 
-
     def resolve(self, dataset_name: str, file_format: str) -> str:
         if dataset_name not in self._map:
             msg = f"Unknown dataset: {dataset_name}"
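
The duplicate-name and format-type checks above define how DatasetResolver expects its "datasets" config to look. A usage sketch based only on what this diff shows (the dict shape, create_from_dicts, and resolve); the paths are hypothetical, and running it assumes the benchmarking/ directory is on PYTHONPATH, as run.py arranges:

from runner.datasets import DatasetResolver

# Illustrative config shaped like the "datasets" entries parsed above.
datasets_cfg = [
    {
        "name": "sample_data",
        "formats": [
            {"type": "parquet", "path": "/data/sample_data/parquet"},
            {"type": "jsonl", "path": "/data/sample_data/jsonl"},
        ],
    },
]

resolver = DatasetResolver.create_from_dicts(datasets_cfg)
print(resolver.resolve("sample_data", "parquet"))  # -> /data/sample_data/parquet
# A duplicate "name" or a non-list "formats" now raises ValueError/TypeError instead of asserting.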

benchmarking/runner/env_capture.py

Lines changed: 1 addition & 2 deletions
@@ -21,7 +21,6 @@
 from typing import Any
 
 from loguru import logger
-
 from runner.utils import get_obj_for_json
 
 
@@ -68,4 +67,4 @@ def get_env() -> dict[str, Any]:
         "python_version": platform.python_version(),
         "executable": os.getenv("_"),
         "cuda_visible_devices": os.getenv("CUDA_VISIBLE_DEVICES", ""),
-    }
+    }
