66import jax
77
88SUBMISSION_PATH = 'submissions_algorithms/submissions/self_tuning/schedule_free_adamw_v2/submission.py'
9+ EXPERIMENT_DIR = 'submissions/rolling_leaderboard/self_tuning/schedule_free_adamw_v2'
910TUNING_SEARCH_SPACE = None
10- EXPERIMENT_DIR = 'submissions/rolling_leaderboard/external_tuning/shampoo'
1111FRAMEWORK = 'pytorch'
12+ TUNING_RULESET = 'self'
1213
1314flags .DEFINE_string ('submission_path' ,
1415 SUBMISSION_PATH ,
2728flags .DEFINE_integer ('seed' , 0 , 'RNG seed to to generate study seeds from.' )
2829flags .DEFINE_enum (
2930 'tuning_ruleset' ,
30- 'external' ,
31+ TUNING_RULESET ,
3132 enum_values = ['external' , 'self' ],
3233 help = 'Which tuning ruleset to score this submission on. Can be external or self.'
3334)
@@ -62,12 +63,34 @@ def main(_):
6263 workload_key = jax .random .fold_in (key , hash (workload ) % (2 ** 32 - 1 ))
6364 for study_index in range (NUM_STUDIES ):
6465 study_key = jax .random .fold_in (workload_key , study_index )
65- for hparam_index in range (NUM_TUNING_TRIALS ):
66- run_key = jax .random .fold_in (study_key , hparam_index )
66+ if FLAGS .tuning_ruleset == 'external' :
67+ for hparam_index in range (NUM_TUNING_TRIALS ):
68+ run_key = jax .random .fold_in (study_key , hparam_index )
69+ seed = jax .random .randint (run_key , (1 ,), MIN_INT , MAX_INT )[0 ].item ()
70+ print (seed )
71+ # Add job
72+ job = {}
73+ study_dir = os .path .join (FLAGS .experiment_dir , f"study_{ study_index } " )
74+ job ['framework' ] = FLAGS .framework
75+ job ['workload' ] = workload
76+ job ['dataset' ] = WORKLOADS [workload ]['dataset' ]
77+ job ['submission_path' ] = FLAGS .submission_path
78+ job ['experiment_dir' ] = study_dir
79+ job ['rng_seed' ] = seed
80+ job ['tuning_ruleset' ] = FLAGS .tuning_ruleset
81+ job ['num_tuning_trials' ] = NUM_TUNING_TRIALS
82+ job ['hparam_start_index' ] = hparam_index
83+ job ['hparam_end_index' ] = hparam_index + 1
84+ job ['tuning_search_space' ] = FLAGS .tuning_search_space
85+ job ['tuning_ruleset' ] = FLAGS .tuning_ruleset
86+ jobs .append (job )
87+ print (job )
88+
89+ else :
90+ run_key = study_key
6791 seed = jax .random .randint (run_key , (1 ,), MIN_INT , MAX_INT )[0 ].item ()
6892 print (seed )
69-
70- # Add workload
93+ # Add job
7194 job = {}
7295 study_dir = os .path .join (FLAGS .experiment_dir , f"study_{ study_index } " )
7396 job ['framework' ] = FLAGS .framework
@@ -77,13 +100,7 @@ def main(_):
77100 job ['experiment_dir' ] = study_dir
78101 job ['rng_seed' ] = seed
79102 job ['tuning_ruleset' ] = FLAGS .tuning_ruleset
80- if FLAGS .tuning_ruleset == 'external' :
81- job ['num_tuning_trials' ] = NUM_TUNING_TRIALS
82- job ['hparam_start_index' ] = hparam_index
83- job ['hparam_end_index' ] = hparam_index + 1
84- job ['tuning_search_space' ] = FLAGS .tuning_search_space
85- else :
86- job ['num_tuning_trials' ] = 1
103+ job ['num_tuning_trials' ] = 1
87104
88105 jobs .append (job )
89106 print (job )
0 commit comments