88import os
99import wandb
1010
11- flags .DEFINE_string ('experiment_dir' , None , 'Path to experiment dir.' )
12- flags .DEFINE_string ('workload ' , None , 'Filter only for workload. If None include all workloads in experiment.' )
13- flags .DEFINE_string ('project_name' , 'visulaize-training-curves' , 'Wandb project name.' )
11+ flags .DEFINE_string ('experiment_dir' , '/home/kasimbeg/algoperf-runs-internal/experiments/pmap_ref' , 'Path to experiment dir.' )
12+ flags .DEFINE_string ('workloads ' , 'librispeech_conformer_jax' , 'Filter only for workload. If None include all workloads in experiment.' )
13+ flags .DEFINE_string ('project_name' , 'visulaize-training-curves-pmap ' , 'Wandb project name.' )
1414flags .DEFINE_string ('run_postfix' , '' , 'Postfix for wandb runs.' )
1515
1616FLAGS = flags .FLAGS
@@ -27,12 +27,15 @@ def main(_):
2727 experiment_dir = FLAGS .experiment_dir
2828 study_dirs = os .listdir (experiment_dir )
2929 for study_dir in study_dirs :
30- workload_dirs = os .listdir (os .path .join (experiment_dir , study_dir ))
31- workload_dirs = [
32- w for w in workload_dirs
33- if os .path .isdir (os .path .join (experiment_dir , study_dir , w ))
34- ]
35- print (workload_dirs )
30+ if not FLAGS .workloads :
31+ workload_dirs = os .listdir (os .path .join (experiment_dir , study_dir ))
32+ workload_dirs = [
33+ w for w in workload_dirs
34+ if os .path .isdir (os .path .join (experiment_dir , study_dir , w ))
35+ ]
36+ print (workload_dirs )
37+ else :
38+ workload_dirs = FLAGS .workloads .split (',' )
3639 for workload in workload_dirs :
3740 data = {
3841 'workload' : workload ,
@@ -44,9 +47,14 @@ def main(_):
4447 if re .match (TRIAL_DIR_REGEX , t )
4548 ]
4649 for trial in trial_dirs :
47- filename = get_filename (FLAGS .trial_dir )
50+ trial_dir = os .path .join (FLAGS .experiment_dir , study_dir , workload , trial )
51+ print (trial_dir )
52+ filename = get_filename (trial_dir )
53+ if not os .path .exists (filename ):
54+ continue
55+
4856 # Start a new W&B run
49- run = wandb .init (project = "visualize-training-curve" , name = (f'{ workload } _{ study_dir } _{ trial } ' + FLAGS .run_postfix ))
57+ run = wandb .init (project = FLAGS . project_name , name = (f'{ workload } _{ study_dir } _{ trial } ' + FLAGS .run_postfix ))
5058
5159 # Log the CSV as a versioned Artifact
5260 artifact = wandb .Artifact (name = "training-data" , type = "dataset" )
0 commit comments