Skip to content

Commit f1c9920

Browse files
committed
fix formatting
1 parent 5fa41f6 commit f1c9920

File tree

2 files changed

+30
-46
lines changed

2 files changed

+30
-46
lines changed

scoring/score_submissions.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -160,6 +160,7 @@ def compute_leaderboard_score(df, normalize=True):
160160
def main(_):
161161
results = {}
162162
os.makedirs(FLAGS.output_dir, exist_ok=True)
163+
logging.info(f"Scoring submissions in {FLAGS.submission_directory}")
163164

164165
# Optionally read results from filename
165166
if FLAGS.load_results_from_filename:
@@ -172,8 +173,7 @@ def main(_):
172173
print(submission)
173174
if submission in FLAGS.exclude_submissions.split(','):
174175
continue
175-
experiment_path = os.path.join(FLAGS.submission_directory,
176-
submission)
176+
experiment_path = os.path.join(FLAGS.submission_directory, submission)
177177
df = scoring_utils.get_experiment_df(experiment_path)
178178
results[submission] = df
179179
summary_df = get_submission_summary(df)
Lines changed: 28 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,16 @@
11
import json
2-
from absl import flags
2+
import os
3+
34
from absl import app
5+
from absl import flags
46
import jax
5-
import os
67

78
SUBMISSION_PATH = '/submissions_algorithms/external_tuning/shampoo_submission/submission.py'
89
TUNING_SEARCH_SPACE = '/submissions_algorithms/external_tuning/shampoo_submission/tuning_search_space.json'
910
EXPERIMENT_DIR = 'submissions/rolling_leaderboard/external_tuning/shampoo'
1011
FRAMEWORK = 'pytorch'
1112

12-
flags.DEFINE_string('submission_path',
13+
flags.DEFINE_string('submission_path',
1314
SUBMISSION_PATH,
1415
'Path to submission module.')
1516
flags.DEFINE_string('tuning_search_space',
@@ -18,63 +19,47 @@
1819
flags.DEFINE_string('experiment_dir',
1920
EXPERIMENT_DIR,
2021
'Path to experiment dir where logs will be saved.')
21-
flags.DEFINE_enum('framework',
22-
FRAMEWORK,
23-
enum_values=['jax', 'pytorch'],
24-
help='Can be either pytorch or jax.')
25-
flags.DEFINE_integer('seed',
26-
0,
27-
'RNG seed to generate study seeds from.')
28-
flags.DEFINE_enum('tuning_ruleset',
29-
'external',
30-
enum_values=['external', 'self'],
31-
help='Which tuning ruleset to score this submission on. Can be external or self.'
32-
)
22+
flags.DEFINE_enum(
23+
'framework',
24+
FRAMEWORK,
25+
enum_values=['jax', 'pytorch'],
26+
help='Can be either pytorch or jax.')
27+
flags.DEFINE_integer('seed', 0, 'RNG seed to generate study seeds from.')
28+
flags.DEFINE_enum(
29+
'tuning_ruleset',
30+
'external',
31+
enum_values=['external', 'self'],
32+
help='Which tuning ruleset to score this submission on. Can be external or self.'
33+
)
3334

3435
FLAGS = flags.FLAGS
3536

36-
3737
MIN_INT = -2**(31)
3838
MAX_INT = 2**(31) - 1
3939
NUM_TUNING_TRIALS = 5 # For external tuning ruleset
4040
NUM_STUDIES = 3
4141

4242
WORKLOADS = {
43-
"imagenet_resnet": {
44-
"dataset": "imagenet"
45-
},
46-
"imagenet_vit": {
47-
"dataset": "imagenet"
48-
},
49-
"fastmri": {
50-
"dataset": "fastmri"
51-
},
52-
"ogbg": {
53-
"dataset": "ogbg"
54-
},
55-
"wmt": {
56-
"dataset": "wmt"
57-
},
58-
"librispeech_deepspeech": {
59-
"dataset": "librispeech"
60-
},
61-
"criteo1tb": {
62-
"dataset": "criteo1tb"
63-
},
64-
"librispeech_conformer": {
65-
"dataset": "librispeech"
66-
}
43+
"imagenet_resnet": {"dataset": "imagenet"},
44+
"imagenet_vit": {"dataset": "imagenet"},
45+
"fastmri": {"dataset": "fastmri"},
46+
"ogbg": {"dataset": "ogbg"},
47+
"wmt": {"dataset": "wmt"},
48+
"librispeech_deepspeech": {"dataset": "librispeech"},
49+
"criteo1tb": {"dataset": "criteo1tb"},
50+
"librispeech_conformer": {"dataset": "librispeech"}
6751
}
6852

53+
6954
def main(_):
70-
workloads=WORKLOADS.keys()
55+
workloads = WORKLOADS.keys()
7156
key = jax.random.key(FLAGS.seed)
7257

7358
jobs = []
7459

7560
for workload in workloads:
7661
# Fold in hash(workload) mod(max(uint32))
77-
workload_key = jax.random.fold_in(key, hash(workload) % (2**32-1))
62+
workload_key = jax.random.fold_in(key, hash(workload) % (2**32 - 1))
7863
for study_index in range(NUM_STUDIES):
7964
study_key = jax.random.fold_in(workload_key, study_index)
8065
for hparam_index in range(NUM_TUNING_TRIALS):
@@ -99,7 +84,6 @@ def main(_):
9984
jobs.append(job)
10085
print(job)
10186

102-
10387
# Convert job array to dict with job indices
10488
job_dict = {}
10589
for i, job in enumerate(jobs):
@@ -110,4 +94,4 @@ def main(_):
11094

11195

11296
if __name__ == '__main__':
113-
app.run(main)
97+
app.run(main)

0 commit comments

Comments
 (0)