Skip to content

Commit c0edfbe

Browse files
committed
plot utils fixes
1 parent 5cbb368 commit c0edfbe

File tree

1 file changed

+19
-11
lines changed

1 file changed

+19
-11
lines changed

scoring/plot_utils/plot_curves.py

Lines changed: 19 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -8,9 +8,9 @@
88
import os
99
import wandb
1010

11-
flags.DEFINE_string('experiment_dir', None, 'Path to experiment dir.')
12-
flags.DEFINE_string('workload', None, 'Filter only for workload. If None include all workloads in experiment.')
13-
flags.DEFINE_string('project_name', 'visulaize-training-curves', 'Wandb project name.')
11+
flags.DEFINE_string('experiment_dir', '/home/kasimbeg/algoperf-runs-internal/experiments/pmap_ref', 'Path to experiment dir.')
12+
flags.DEFINE_string('workloads', 'librispeech_conformer_jax', 'Filter only for workload. If None include all workloads in experiment.')
13+
flags.DEFINE_string('project_name', 'visulaize-training-curves-pmap', 'Wandb project name.')
1414
flags.DEFINE_string('run_postfix', '', 'Postfix for wandb runs.')
1515

1616
FLAGS = flags.FLAGS
@@ -27,12 +27,15 @@ def main(_):
2727
experiment_dir = FLAGS.experiment_dir
2828
study_dirs = os.listdir(experiment_dir)
2929
for study_dir in study_dirs:
30-
workload_dirs = os.listdir(os.path.join(experiment_dir, study_dir))
31-
workload_dirs = [
32-
w for w in workload_dirs
33-
if os.path.isdir(os.path.join(experiment_dir, study_dir, w))
34-
]
35-
print(workload_dirs)
30+
if not FLAGS.workloads:
31+
workload_dirs = os.listdir(os.path.join(experiment_dir, study_dir))
32+
workload_dirs = [
33+
w for w in workload_dirs
34+
if os.path.isdir(os.path.join(experiment_dir, study_dir, w))
35+
]
36+
print(workload_dirs)
37+
else:
38+
workload_dirs = FLAGS.workloads.split(',')
3639
for workload in workload_dirs:
3740
data = {
3841
'workload': workload,
@@ -44,9 +47,14 @@ def main(_):
4447
if re.match(TRIAL_DIR_REGEX, t)
4548
]
4649
for trial in trial_dirs:
47-
filename = get_filename(FLAGS.trial_dir)
50+
trial_dir = os.path.join(FLAGS.experiment_dir, study_dir, workload, trial)
51+
print(trial_dir)
52+
filename = get_filename(trial_dir)
53+
if not os.path.exists(filename):
54+
continue
55+
4856
# Start a new W&B run
49-
run = wandb.init(project="visualize-training-curve", name=(f'{workload}_{study_dir}_{trial}' + FLAGS.run_postfix))
57+
run = wandb.init(project=FLAGS.project_name, name=(f'{workload}_{study_dir}_{trial}' + FLAGS.run_postfix))
5058

5159
# Log the CSV as a versioned Artifact
5260
artifact = wandb.Artifact(name="training-data", type="dataset")

0 commit comments

Comments
 (0)