Commit 70705a7 ("merge")
1 parent 2cfa2a9

File tree: 5 files changed, +137 −17 lines

algoperf/workloads/librispeech_deepspeech/librispeech_jax/workload.py

Lines changed: 1 addition & 0 deletions
@@ -44,6 +44,7 @@ def init_model_fn(
     fake_input_batch = [np.zeros((2, *x), jnp.float32) for x in input_shape]
 
     model_init_fn = jax.jit(functools.partial(self._model.init, train=False))
+    # model_init_fn = functools.partial(self._model.init, train=False)
 
     params_rng, dropout_rng = jax.random.split(rng, 2)
     variables = model_init_fn({'params': params_rng, 'dropout': dropout_rng},
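Jitting the Flax `init` compiles it once and speeds up reuse, but shape errors then surface inside traced code, which is presumably why the eager variant is kept as a comment. A minimal standalone sketch of the two paths, using a toy Flax module rather than the workload's actual DeepSpeech model:

import functools

import flax.linen as nn
import jax
import jax.numpy as jnp


class ToyModel(nn.Module):
  """Stand-in for the real workload model."""

  @nn.compact
  def __call__(self, x, train=False):
    return nn.Dense(4)(x)


model = ToyModel()
fake_batch = jnp.zeros((2, 8), jnp.float32)
rng = jax.random.PRNGKey(0)

# Jitted init, as in the committed line: compiled once, fast on reuse,
# but failures are reported from inside traced code.
jitted_init = jax.jit(functools.partial(model.init, train=False))
variables = jitted_init(rng, fake_batch)

# Eager init, as in the commented-out line: slower, but errors come
# with ordinary Python tracebacks, which is handy while debugging.
eager_init = functools.partial(model.init, train=False)
variables_eager = eager_init(rng, fake_batch)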

pyproject.toml

Lines changed: 1 addition & 1 deletion
@@ -99,7 +99,7 @@ wmt = ["sentencepiece==0.2.0", "tensorflow-text==2.18.0"]
 
 # Frameworks
 jax_core_deps = [
-  "flax==0.8.4",
+  "flax==0.10.4",
   "optax==0.2.2",
   "chex==0.1.86",
   "ml_dtypes==0.4.1",

reference_algorithms/paper_baselines/adamw/jax/submission.py

Lines changed: 1 addition & 2 deletions
@@ -75,7 +75,6 @@ def _loss_fn(params):
         spec.ForwardPassMode.TRAIN,
         rng,
         update_batch_norm=True,)
-    jax.debug.print("logits: {logits}", logits=logits)
     loss_dict = workload.loss_fn(
         label_batch=batch['targets'],
         logits_batch=logits,
@@ -222,7 +221,7 @@ def get_batch_size(workload_name):
   elif workload_name == 'librispeech_conformer':
     return 256
   elif workload_name == 'librispeech_deepspeech':
-    return 256
+    return 16
   elif workload_name == 'ogbg':
     return 512
   elif workload_name == 'wmt':
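The deleted debug line used `jax.debug.print`, which (unlike a plain `print`) fires at run time inside jitted code rather than once at trace time; it is useful for inspecting logits but adds a host callback on every step. A standalone sketch of the pattern that was removed:

import jax
import jax.numpy as jnp


@jax.jit
def loss_fn(params, x):
  logits = x * params
  # Prints the runtime value on every call, even under jit.
  jax.debug.print("logits: {logits}", logits=logits)
  return jnp.sum(logits**2)


print(loss_fn(2.0, jnp.arange(3.0)))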
Lines changed: 118 additions & 0 deletions
@@ -0,0 +1,118 @@
+import json
+import os
+
+from absl import app
+from absl import flags
+import jax
+
+SUBMISSION_PATH = 'prize_qualification_baselines/self_tuning/jax_nadamw_full_budget.py'
+EXPERIMENT_DIR = 'submissions/rolling_leaderboard/self_tuning/baseline'
+TUNING_SEARCH_SPACE = None
+FRAMEWORK = 'jax'
+TUNING_RULESET = 'self'
+
+flags.DEFINE_string('submission_path',
+                    SUBMISSION_PATH,
+                    'Path to submission module.')
+flags.DEFINE_string('tuning_search_space',
+                    TUNING_SEARCH_SPACE,
+                    'Path to tuning search space for submission module.')
+flags.DEFINE_string('experiment_dir',
+                    EXPERIMENT_DIR,
+                    'Path to experiment dir where logs will be saved.')
+flags.DEFINE_enum(
+    'framework',
+    FRAMEWORK,
+    enum_values=['jax', 'pytorch'],
+    help='Can be either pytorch or jax.')
+flags.DEFINE_integer('seed', 0, 'RNG seed to generate study seeds from.')
+flags.DEFINE_enum(
+    'tuning_ruleset',
+    TUNING_RULESET,
+    enum_values=['external', 'self'],
+    help='Which tuning ruleset to score this submission on. Can be external or self.'
+)
+
+FLAGS = flags.FLAGS
+
+MIN_INT = -2**(31)
+MAX_INT = 2**(31) - 1
+NUM_TUNING_TRIALS = 5  # For the external tuning ruleset.
+NUM_STUDIES = 3
+
+WORKLOADS = {
+    "imagenet_resnet": {"dataset": "imagenet"},
+    "imagenet_vit": {"dataset": "imagenet"},
+    "fastmri": {"dataset": "fastmri"},
+    "ogbg": {"dataset": "ogbg"},
+    "wmt": {"dataset": "wmt"},
+    "librispeech_deepspeech": {"dataset": "librispeech"},
+    "criteo1tb": {"dataset": "criteo1tb"},
+    "librispeech_conformer": {"dataset": "librispeech"}
+}
+
+
+def main(_):
+  workloads = WORKLOADS.keys()
+  key = jax.random.key(FLAGS.seed)
+
+  jobs = []
+
+  for workload in workloads:
+    # Fold in hash(workload) mod max(uint32).
+    workload_key = jax.random.fold_in(key, hash(workload) % (2**32 - 1))
+    for study_index in range(NUM_STUDIES):
+      study_key = jax.random.fold_in(workload_key, study_index)
+      if FLAGS.tuning_ruleset == 'external':
+        for hparam_index in range(NUM_TUNING_TRIALS):
+          run_key = jax.random.fold_in(study_key, hparam_index)
+          seed = jax.random.randint(run_key, (1,), MIN_INT, MAX_INT)[0].item()
+          print(seed)
+          # Add job.
+          job = {}
+          study_dir = os.path.join(FLAGS.experiment_dir, f"study_{study_index}")
+          job['framework'] = FLAGS.framework
+          job['workload'] = workload
+          job['dataset'] = WORKLOADS[workload]['dataset']
+          job['submission_path'] = FLAGS.submission_path
+          job['experiment_dir'] = study_dir
+          job['rng_seed'] = seed
+          job['tuning_ruleset'] = FLAGS.tuning_ruleset
+          job['num_tuning_trials'] = NUM_TUNING_TRIALS
+          job['hparam_start_index'] = hparam_index
+          job['hparam_end_index'] = hparam_index + 1
+          job['tuning_search_space'] = FLAGS.tuning_search_space
+          job['tuning_ruleset'] = FLAGS.tuning_ruleset
+          jobs.append(job)
+          print(job)
+
+      else:
+        run_key = study_key
+        seed = jax.random.randint(run_key, (1,), MIN_INT, MAX_INT)[0].item()
+        print(seed)
+        # Add job.
+        job = {}
+        study_dir = os.path.join(FLAGS.experiment_dir, f"study_{study_index}")
+        job['framework'] = FLAGS.framework
+        job['workload'] = workload
+        job['dataset'] = WORKLOADS[workload]['dataset']
+        job['submission_path'] = FLAGS.submission_path
+        job['experiment_dir'] = study_dir
+        job['rng_seed'] = seed
+        job['tuning_ruleset'] = FLAGS.tuning_ruleset
+        job['num_tuning_trials'] = 1
+
+        jobs.append(job)
+        print(job)
+
+  # Convert the job list to a dict keyed by job index.
+  job_dict = {}
+  for i, job in enumerate(jobs):
+    job_dict[f"{i}"] = job
+
+  with open('config.json', 'w') as f:
+    json.dump(job_dict, f, indent=4)
+
+
+if __name__ == '__main__':
+  app.run(main)
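The script derives per-study (and, under external tuning, per-trial) seeds deterministically from `--seed` via `jax.random.fold_in`, so regenerating the config reproduces the same seeds. To use it, run the module and point a launcher at the emitted `config.json`. A sketch of a consumer, assuming the generator was saved as, say, `make_job_config.py` (a hypothetical name; the diff does not show the new file's path):

# Hypothetical consumer of the generated config. Run the generator first,
# e.g.: python make_job_config.py --tuning_ruleset=self
import json

with open('config.json') as f:
  jobs = json.load(f)

# Keys are job indices ("0", "1", ...); values are per-run settings.
for index, job in jobs.items():
  print(index, job['workload'], job['experiment_dir'], job['rng_seed'])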

submission_runner.py

Lines changed: 16 additions & 14 deletions
@@ -636,20 +636,22 @@ def score_submission_on_workload(workload: spec.Workload,
       tuning_search_space[hi] = hyperparameters
 
       with profiler.profile('Train'):
-        timing, metrics = train_once(workload, workload_name,
-                                     global_batch_size,
-                                     global_eval_batch_size,
-                                     data_dir, imagenet_v2_data_dir,
-                                     init_optimizer_state,
-                                     update_params, data_selection,
-                                     prepare_for_eval,
-                                     hyperparameters,
-                                     rng_seed,
-                                     rng,
-                                     profiler,
-                                     max_global_steps,
-                                     tuning_dir_name,
-                                     save_checkpoints=save_checkpoints,)
+        with jax.profiler.trace("/logs/tensorboard"):
+          print('profiling!')
+          timing, metrics = train_once(workload, workload_name,
+                                       global_batch_size,
+                                       global_eval_batch_size,
+                                       data_dir, imagenet_v2_data_dir,
+                                       init_optimizer_state,
+                                       update_params, data_selection,
+                                       prepare_for_eval,
+                                       hyperparameters,
+                                       rng_seed,
+                                       rng,
+                                       profiler,
+                                       max_global_steps,
+                                       tuning_dir_name,
+                                       save_checkpoints=save_checkpoints,)
       all_timings[hi] = timing
       all_metrics[hi] = metrics
       logging.info(f'Tuning trial {hi + 1}/{num_tuning_trials}')
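`jax.profiler.trace` is a context manager that writes a TensorBoard-compatible trace of everything executed inside it to the given log directory; the hard-coded `/logs/tensorboard` here implies a container or host with that path available. A minimal standalone sketch of the same pattern:

import jax
import jax.numpy as jnp

# Everything run inside the context is captured in the trace;
# inspect it with: tensorboard --logdir=/tmp/tensorboard (Profile tab).
with jax.profiler.trace("/tmp/tensorboard"):
  x = jnp.ones((1024, 1024))
  y = jnp.dot(x, x)
  y.block_until_ready()  # ensure the work finishes before the trace closes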
