Skip to content

Commit eff5343

Browse files
committed
add self-tuning option to make_job_config.py
1 parent 9c8929f commit eff5343

File tree

1 file changed

+30
-13
lines changed

1 file changed

+30
-13
lines changed

scoring/utils/slurm/make_job_config.py

Lines changed: 30 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,10 @@
66
import jax
77

88
SUBMISSION_PATH = 'submissions_algorithms/submissions/self_tuning/schedule_free_adamw_v2/submission.py'
9+
EXPERIMENT_DIR = 'submissions/rolling_leaderboard/self_tuning/schedule_free_adamw_v2'
910
TUNING_SEARCH_SPACE = None
10-
EXPERIMENT_DIR = 'submissions/rolling_leaderboard/external_tuning/shampoo'
1111
FRAMEWORK = 'pytorch'
12+
TUNING_RULESET = 'self'
1213

1314
flags.DEFINE_string('submission_path',
1415
SUBMISSION_PATH,
@@ -27,7 +28,7 @@
2728
flags.DEFINE_integer('seed', 0, 'RNG seed to generate study seeds from.')
2829
flags.DEFINE_enum(
2930
'tuning_ruleset',
30-
'external',
31+
TUNING_RULESET,
3132
enum_values=['external', 'self'],
3233
help='Which tuning ruleset to score this submission on. Can be external or self.'
3334
)
@@ -62,12 +63,34 @@ def main(_):
6263
workload_key = jax.random.fold_in(key, hash(workload) % (2**32 - 1))
6364
for study_index in range(NUM_STUDIES):
6465
study_key = jax.random.fold_in(workload_key, study_index)
65-
for hparam_index in range(NUM_TUNING_TRIALS):
66-
run_key = jax.random.fold_in(study_key, hparam_index)
66+
if FLAGS.tuning_ruleset == 'external':
67+
for hparam_index in range(NUM_TUNING_TRIALS):
68+
run_key = jax.random.fold_in(study_key, hparam_index)
69+
seed = jax.random.randint(run_key, (1,), MIN_INT, MAX_INT)[0].item()
70+
print(seed)
71+
# Add job
72+
job = {}
73+
study_dir = os.path.join(FLAGS.experiment_dir, f"study_{study_index}")
74+
job['framework'] = FLAGS.framework
75+
job['workload'] = workload
76+
job['dataset'] = WORKLOADS[workload]['dataset']
77+
job['submission_path'] = FLAGS.submission_path
78+
job['experiment_dir'] = study_dir
79+
job['rng_seed'] = seed
80+
job['tuning_ruleset'] = FLAGS.tuning_ruleset
81+
job['num_tuning_trials'] = NUM_TUNING_TRIALS
82+
job['hparam_start_index'] = hparam_index
83+
job['hparam_end_index'] = hparam_index + 1
84+
job['tuning_search_space'] = FLAGS.tuning_search_space
85+
job['tuning_ruleset'] = FLAGS.tuning_ruleset
86+
jobs.append(job)
87+
print(job)
88+
89+
else:
90+
run_key = study_key
6791
seed = jax.random.randint(run_key, (1,), MIN_INT, MAX_INT)[0].item()
6892
print(seed)
69-
70-
# Add workload
93+
# Add job
7194
job = {}
7295
study_dir = os.path.join(FLAGS.experiment_dir, f"study_{study_index}")
7396
job['framework'] = FLAGS.framework
@@ -77,13 +100,7 @@ def main(_):
77100
job['experiment_dir'] = study_dir
78101
job['rng_seed'] = seed
79102
job['tuning_ruleset'] = FLAGS.tuning_ruleset
80-
if FLAGS.tuning_ruleset == 'external':
81-
job['num_tuning_trials'] = NUM_TUNING_TRIALS
82-
job['hparam_start_index'] = hparam_index
83-
job['hparam_end_index'] = hparam_index + 1
84-
job['tuning_search_space'] = FLAGS.tuning_search_space
85-
else:
86-
job['num_tuning_trials'] = 1
103+
job['num_tuning_trials'] = 1
87104

88105
jobs.append(job)
89106
print(job)

0 commit comments

Comments
 (0)