 # limitations under the License.
 
 import argparse
-from pathlib import Path
-import yaml
-import os
-import time
-import sys
-from typing import Any
 import json
-import traceback
-from statistics import mean, stdev
+import os
 import pickle
 import shutil
+import sys
+import time
 import traceback
+from pathlib import Path
+from typing import Any
 
+import yaml
 from loguru import logger
 
-from nemo_curator.tasks import Task
 from nemo_curator.tasks.utils import TaskPerfUtils
+from nemo_curator.utils.file_utils import create_or_overwrite_dir
 
-# FIXME: How do we want to package this tool? Perhaps a package extra for
+# TODO: How do we want to package this tool? Perhaps a package extra for
 # nemo-curator, i.e. nemo-curator[benchmarking]?
 # For now, add this directory to PYTHONPATH to import the runner modules
 sys.path.insert(0, Path(__file__).parent)
-from runner.matrix import MatrixConfig, MatrixEntry
 from runner.datasets import DatasetResolver
-from runner.utils import get_obj_for_json
-from runner.process import run_command_with_timeout
 from runner.env_capture import dump_env
+from runner.matrix import MatrixConfig, MatrixEntry
+from runner.process import run_command_with_timeout
+from runner.utils import get_obj_for_json
 
 
 def ensure_dir(dir_path: Path) -> None:
     """Ensure dir_path and parents exists, creating them if necessary."""
     dir_path.mkdir(parents=True, exist_ok=True)
 
 
-def create_or_overwrite_dir(dir_path: Path) -> None:
-    """Create directory, removing it if it exists."""
-    if dir_path.exists():
-        shutil.rmtree(dir_path, ignore_errors=True)
-    dir_path.mkdir(parents=True, exist_ok=True)
-
-
-def aggregate_task_metrics(tasks: list[Task], prefix: str | None = None) -> dict[str, Any]:
-    """Aggregate task metrics by computing mean/std/sum."""
-    metrics = {}
-    tasks_metrics = TaskPerfUtils.collect_stage_metrics(tasks)
-    # For each of the metric compute mean/std/sum and flatten the dict
-    for stage_name, stage_data in tasks_metrics.items():
-        for metric_name, values in stage_data.items():
-            for agg_name, agg_func in [("sum", sum), ("mean", mean), ("std", stdev)]:
-                stage_key = stage_name if prefix is None else f"{prefix}_{stage_name}"
-                if len(values) > 0:
-                    metrics[f"{stage_key}_{metric_name}_{agg_name}"] = float(agg_func(values))
-                else:
-                    metrics[f"{stage_key}_{metric_name}_{agg_name}"] = 0.0
-    return metrics
-
-
 def get_entry_script_persisted_data(benchmark_results_path: Path) -> dict[str, Any]:
-    """ Read the files that are expected to be generated by the individual benchmark scripts.
-    """
+    """Read the files that are expected to be generated by the individual benchmark scripts."""
     params_json = benchmark_results_path / "params.json"
     if not params_json.exists():
         logger.warning(f"Params JSON file not found at {params_json}")
@@ -97,22 +71,23 @@ def get_entry_script_persisted_data(benchmark_results_path: Path) -> dict[str, Any]:
     with open(tasks_pkl, "rb") as f:
         script_tasks = pickle.load(f)  # noqa: S301
     if isinstance(script_tasks, list):
-        script_metrics.update(aggregate_task_metrics(script_tasks, prefix="task"))
+        script_metrics.update(TaskPerfUtils.aggregate_task_metrics(script_tasks, prefix="task"))
     elif isinstance(script_tasks, dict):
         for pipeline_name, pipeline_tasks in script_tasks.items():
-            script_metrics.update(aggregate_task_metrics(pipeline_tasks, prefix=pipeline_name.lower()))
+            script_metrics.update(
+                TaskPerfUtils.aggregate_task_metrics(pipeline_tasks, prefix=pipeline_name.lower())
+            )
 
     return {"params": script_params, "metrics": script_metrics}
 
 
-def run_entry(  # noqa: PLR0915
+def run_entry(
     entry: MatrixEntry,
     dataset_resolver: DatasetResolver,
     session_path: Path,
     result: dict[str, Any],
 ) -> tuple[dict[str, Any], bool, dict[str, Any]]:
-
-    started_at = time.time()
+    started_at = time.time()
     session_entry_path = session_path / entry.name
 
     # scratch_path: This is the directory user can use to store scratch data; it'll be cleaned up after the entry is done
@@ -155,23 +130,27 @@ def run_entry(  # noqa: PLR0915
         logger.warning(f"\t\t⏰ Timed out after {entry.timeout_s}s")
     logger.info(f"\t\tLogs found in {logs_path}")
 
-    result.update({
-        "cmd": cmd,
-        "started_at": started_at,
-        "ended_at": time.time(),
-        "exec_started_at": started_exec,
-        "exec_time_s": ended_exec - started_exec,
-        "exit_code": completed["returncode"],
-        "timed_out": completed["timed_out"],
-        "logs_dir": logs_path,
-        "success": success,
-    })
+    result.update(
+        {
+            "cmd": cmd,
+            "started_at": started_at,
+            "ended_at": time.time(),
+            "exec_started_at": started_exec,
+            "exec_time_s": ended_exec - started_exec,
+            "exit_code": completed["returncode"],
+            "timed_out": completed["timed_out"],
+            "logs_dir": logs_path,
+            "success": success,
+        }
+    )
     ray_data = {}
     script_persisted_data = get_entry_script_persisted_data(benchmark_results_path)
-    result.update({
-        "ray_data": ray_data,
-        "script_persisted_data": script_persisted_data,
-    })
+    result.update(
+        {
+            "ray_data": ray_data,
+            "script_persisted_data": script_persisted_data,
+        }
+    )
     Path(session_entry_path / "results.json").write_text(json.dumps(get_obj_for_json(result)))
 
     return success
@@ -200,10 +179,11 @@ def main() -> None:
     # and use by passing individual components the keys they need
     config_dict = {}
     for yml_file in args.config:
-        config_dicts = yaml.full_load_all(open(yml_file))
+        with open(yml_file) as f:
+            config_dicts = list(yaml.full_load_all(f))  # materialize the lazy documents before the file closes
         for d in config_dicts:
             config_dict.update(d)
-
+
     config = MatrixConfig.create_from_dict(config_dict)
     resolver = DatasetResolver.create_from_dicts(config_dict["datasets"])
 
@@ -216,7 +196,7 @@ def main() -> None:
     session_overall_success = True
     logger.info(f"Started session {session_name}...")
     env_data = dump_env(session_path)
-
+
     for sink in config.sinks:
         sink.initialize(session_name, env_data)
 
@@ -242,12 +222,14 @@ def main() -> None:
             error_traceback = traceback.format_exc()
             logger.error(f"\t\t❌ Entry failed with exception: {e}")
             logger.debug(f"Full traceback:\n{error_traceback}")
-            result.update({
-                "error": str(e),
-                "traceback": error_traceback,
-                "success": run_success,
-            })
-
+            result.update(
+                {
+                    "error": str(e),
+                    "traceback": error_traceback,
+                    "success": run_success,
+                }
+            )
+
         finally:
             session_overall_success &= run_success
             for sink in config.sinks:
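
Note on the relocated helpers: the local create_or_overwrite_dir and aggregate_task_metrics removed above are replaced by imports from nemo_curator.utils.file_utils and by TaskPerfUtils.aggregate_task_metrics. Below is a minimal sketch of how the relocated aggregation call is used, assuming it preserves the removed local implementation's behavior (mean/std/sum per stage metric, flattened into "<prefix>_<stage>_<metric>_<agg>" keys, with 0.0 when a metric has no recorded values); the tasks.pkl path is illustrative.

    import pickle
    from pathlib import Path

    from nemo_curator.tasks.utils import TaskPerfUtils

    # Illustrative path; in this runner each benchmark script persists its
    # completed tasks to <benchmark_results_path>/tasks.pkl.
    with Path("tasks.pkl").open("rb") as f:
        script_tasks = pickle.load(f)  # noqa: S301

    # Assuming behavior equivalent to the removed local helper, this flattens
    # per-stage metrics into keys like "task_<stage>_<metric>_mean" (plus _sum
    # and _std variants).
    metrics = TaskPerfUtils.aggregate_task_metrics(script_tasks, prefix="task")
    print(metrics)

When tasks.pkl holds a dict of pipelines, the runner applies the same call once per pipeline, using the lowercased pipeline name as the prefix.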