import json
import os

from absl import app
from absl import flags
import jax
78SUBMISSION_PATH = '/submissions_algorithms/external_tuning/shampoo_submission/submission.py'
89TUNING_SEARCH_SPACE = '/submissions_algorithms/external_tuning/shampoo_submission/tuning_search_space.json'
910EXPERIMENT_DIR = 'submissions/rolling_leaderboard/external_tuning/shampoo'
1011FRAMEWORK = 'pytorch'
1112
12- flags .DEFINE_string ('submission_path' ,
13+ flags .DEFINE_string ('submission_path' ,
1314 SUBMISSION_PATH ,
1415 'Path to submission module.' )
1516flags .DEFINE_string ('tuning_search_space' ,
1819flags .DEFINE_string ('experiment_dir' ,
1920 EXPERIMENT_DIR ,
2021 'Path to experiment dir where logs will be saved.' )
21- flags .DEFINE_enum ('framework' ,
22- FRAMEWORK ,
23- enum_values = [ 'jax' , 'pytorch' ] ,
24- help = 'Can be either pytorch or jax.' )
25- flags . DEFINE_integer ( 'seed' ,
26- 0 ,
27- 'RNG seed to to generate study seeds from.' )
28- flags . DEFINE_enum ( 'tuning_ruleset' ,
29- 'external' ,
30- enum_values = ['external' , 'self' ],
31- help = 'Which tuning ruleset to score this submission on. Can be external or self.'
32- )
22+ flags .DEFINE_enum (
23+ 'framework' ,
24+ FRAMEWORK ,
25+ enum_values = [ ' jax' , 'pytorch' ],
26+ help = 'Can be either pytorch or jax.' )
27+ flags . DEFINE_integer ( 'seed' , 0 , 'RNG seed to to generate study seeds from.' )
28+ flags . DEFINE_enum (
29+ 'tuning_ruleset' ,
30+ 'external' ,
31+ enum_values = ['external' , 'self' ],
32+ help = 'Which tuning ruleset to score this submission on. Can be external or self.'
33+ )
3334
3435FLAGS = flags .FLAGS
3536
36-
3737MIN_INT = - 2 ** (31 )
3838MAX_INT = 2 ** (31 ) - 1
3939NUM_TUNING_TRIALS = 5 # For external tuning ruleset
4040NUM_STUDIES = 3
4141
# Workload name -> metadata; "dataset" is the dataset each workload uses.
WORKLOADS = {
    "imagenet_resnet": {"dataset": "imagenet"},
    "imagenet_vit": {"dataset": "imagenet"},
    "fastmri": {"dataset": "fastmri"},
    "ogbg": {"dataset": "ogbg"},
    "wmt": {"dataset": "wmt"},
    "librispeech_deepspeech": {"dataset": "librispeech"},
    "criteo1tb": {"dataset": "criteo1tb"},
    "librispeech_conformer": {"dataset": "librispeech"}
}
6852
53+
6954def main (_ ):
70- workloads = WORKLOADS .keys ()
55+ workloads = WORKLOADS .keys ()
7156 key = jax .random .key (FLAGS .seed )
7257
7358 jobs = []
7459
7560 for workload in workloads :
7661 # Fold in hash(workload) mod(max(uint32))
77- workload_key = jax .random .fold_in (key , hash (workload ) % (2 ** 32 - 1 ))
62+ workload_key = jax .random .fold_in (key , hash (workload ) % (2 ** 32 - 1 ))
7863 for study_index in range (NUM_STUDIES ):
7964 study_key = jax .random .fold_in (workload_key , study_index )
8065 for hparam_index in range (NUM_TUNING_TRIALS ):
@@ -99,7 +84,6 @@ def main(_):
9984 jobs .append (job )
10085 print (job )
10186
102-
10387 # Convert job array to dict with job indices
10488 job_dict = {}
10589 for i , job in enumerate (jobs ):
@@ -110,4 +94,4 @@ def main(_):
11094
11195
if __name__ == '__main__':
  # absl's app.run parses flags before invoking main.
  app.run(main)