
Commit fe3f9f0 (parent: fb62eae)

Commit message: reformatting

File tree: 5 files changed, +40 / -36 lines

  reference_algorithms/paper_baselines/adamw/jax/submission.py
  reference_algorithms/paper_baselines/nadamw/jax/submission.py
  reference_algorithms/paper_baselines/sam/jax/submission.py
  reference_algorithms/paper_baselines/shampoo/jax/distributed_shampoo.py
  reference_algorithms/target_setting_algorithms/jax_nadamw.py

reference_algorithms/paper_baselines/adamw/jax/submission.py
10 additions, 10 deletions

@@ -55,15 +55,16 @@ def jax_cosine_warmup(step_hint: int, hyperparameters):
 
   return optimizer_state, opt_update_fn
 
+
 def train_step(workload,
-               opt_update_fn,
-               model_state,
-               optimizer_state,
-               current_param_container,
-               batch,
-               rng,
-               grad_clip,
-               label_smoothing):
+               opt_update_fn,
+               model_state,
+               optimizer_state,
+               current_param_container,
+               batch,
+               rng,
+               grad_clip,
+               label_smoothing):
 
   def _loss_fn(params):
     """Loss function used for training."""

@@ -163,8 +164,7 @@ def update_params(
       static_argnums=(0, 1),
       donate_argnums=(2, 3, 4),
       in_shardings=arg_shardings,
-      out_shardings=out_shardings
-  )
+      out_shardings=out_shardings)
   new_optimizer_state, new_params, new_model_state, loss, grad_norm = jitted_train_step(workload,
       opt_update_fn,
       model_state,
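
Note on the hunks above: only whitespace and the closing parenthesis of the jax.jit call move; static_argnums, donate_argnums and the shardings are unchanged. As a quick illustration of that pattern, here is a minimal, self-contained sketch of jitting a train step with static and donated arguments. The workload object, sharding specs and hyperparameters of the real submission are omitted; train_step, its toy loss and the optax.adamw settings below are stand-ins, not the benchmark code.

# Illustrative sketch only: jitting a train step with static and donated
# arguments, mirroring the jax.jit call reflowed in the hunk above.
import jax
import jax.numpy as jnp
import optax


def train_step(opt_update_fn, optimizer_state, current_param_container, batch):
  """Compute gradients of a toy loss and apply the optax update."""

  def _loss_fn(params):
    preds = batch['inputs'] @ params
    return jnp.mean((preds - batch['targets'])**2)

  grads = jax.grad(_loss_fn)(current_param_container)
  updates, new_optimizer_state = opt_update_fn(
      grads, optimizer_state, current_param_container)
  new_params = optax.apply_updates(current_param_container, updates)
  return new_optimizer_state, new_params


# The update function is a compile-time constant (static_argnums); the old
# optimizer state and parameters are donated so XLA can reuse their buffers.
jitted_train_step = jax.jit(
    train_step, static_argnums=(0,), donate_argnums=(1, 2))

params = jnp.zeros((3,))
tx = optax.adamw(learning_rate=1e-3)
optimizer_state = tx.init(params)
batch = {'inputs': jnp.ones((8, 3)), 'targets': jnp.ones((8,))}
optimizer_state, params = jitted_train_step(
    tx.update, optimizer_state, params, batch)

Marking the update function static lets XLA specialize the compiled step to the optimizer, while donating the state and parameter buffers avoids keeping two copies of them alive.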

reference_algorithms/paper_baselines/nadamw/jax/submission.py
4 additions, 6 deletions

@@ -123,9 +123,8 @@ def update_fn(updates, state, params=None):
     mu_hat = _update_moment(updates, mu, b1, 1)
     mu_hat = mu_hat if not debias else _bias_correction(mu_hat, b1, count)
     nu_hat = nu if not debias else _bias_correction(nu, b2, count)
-    updates = jax.tree.map(lambda m, v: m / (raise_power(v + eps_root) + eps),
-                           mu_hat,
-                           nu_hat)
+    updates = jax.tree.map(
+        lambda m, v: m / (raise_power(v + eps_root) + eps), mu_hat, nu_hat)
     return updates, ScaleByAdamState(count=count, mu=mu, nu=nu)
 
   return optax.GradientTransformation(init_fn, update_fn)

@@ -140,9 +139,8 @@ class ScaleByAdamState(NamedTuple):
 
 def _update_moment(updates, moments, decay, order):
   """Compute the exponential moving average of the `order-th` moment."""
-  return jax.tree.map(lambda g, t: (1 - decay) * (g**order) + decay * t,
-                      updates,
-                      moments)
+  return jax.tree.map(
+      lambda g, t: (1 - decay) * (g**order) + decay * t, updates, moments)
 
 
 def _bias_correction(moment, decay, count):
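
Both hunks in this file only re-wrap jax.tree.map calls; the NAdamW arithmetic is identical before and after. For readers skimming the diff, a small self-contained sketch of what the re-wrapped lines compute, with jnp.sqrt standing in for the file's raise_power and the debias flag and state class reduced to plain bias correction (all values illustrative):

# Simplified sketch of the bias-corrected moment update reflowed above.
# `jnp.sqrt` stands in for the file's `raise_power`.
import jax
import jax.numpy as jnp


def _update_moment(updates, moments, decay, order):
  """Exponential moving average of the `order`-th moment."""
  return jax.tree.map(
      lambda g, t: (1 - decay) * (g**order) + decay * t, updates, moments)


def _bias_correction(moment, decay, count):
  """Adam-style debiasing of an exponential moving average."""
  return jax.tree.map(lambda t: t / (1 - decay**count), moment)


b1, b2, eps, eps_root, count = 0.9, 0.999, 1e-8, 0.0, 1
grads = {'w': jnp.ones((2, 2))}
mu = jax.tree.map(jnp.zeros_like, grads)
nu = jax.tree.map(jnp.zeros_like, grads)

mu = _update_moment(grads, mu, b1, 1)     # first moment
nu = _update_moment(grads, nu, b2, 2)     # second moment
mu_hat = _bias_correction(mu, b1, count)
nu_hat = _bias_correction(nu, b2, count)
# The re-wrapped line: elementwise m / (sqrt(v + eps_root) + eps).
updates = jax.tree.map(
    lambda m, v: m / (jnp.sqrt(v + eps_root) + eps), mu_hat, nu_hat)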

reference_algorithms/paper_baselines/sam/jax/submission.py
4 additions, 4 deletions

@@ -67,9 +67,8 @@ def update_fn(updates, state, grad_fn_params_tuple):
     # the noised parameters in the same order as on the original gradients and
     # with the same 1e-6 epsilon that is used when clipping the gradients.
     updates = dual_vector(updates)
-    noised_params = jax.tree_util.tree_map(lambda p, u: p + rho * u,
-                                           params,
-                                           updates)
+    noised_params = jax.tree_util.tree_map(
+        lambda p, u: p + rho * u, params, updates)
     (_, (n_valid_examples, _)), updates = grad_fn(noised_params)
     # Get correct global mean grad.
     (n_valid_examples, updates) = lax.psum((n_valid_examples, updates),

@@ -81,7 +80,8 @@ def update_fn(updates, state, grad_fn_params_tuple):
         sum(jnp.sum(g**2) for g in jax.tree_util.tree_leaves(updates)))
     scaled_updates = jax.tree.map(
         lambda x: x / (updates_norm + _GRAD_CLIP_EPS) * grad_clip, updates)
-    updates = jax.lax.cond(updates_norm > grad_clip, lambda _: scaled_updates,
+    updates = jax.lax.cond(updates_norm > grad_clip,
+                           lambda _: scaled_updates,
                            lambda _: updates,
                            None)
     updates, state = base_opt_update_fn(updates, state, params)
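
The two SAM hunks are likewise pure re-wrapping. The patterns they touch are the SAM perturbation step (move the parameters a distance rho along the unit-norm gradient direction) and global-norm clipping via jax.lax.cond. A rough standalone sketch, with made-up trees and values and without the pmap/psum machinery of the real submission:

# Sketch of the two reflowed SAM patterns: parameter perturbation along the
# unit gradient direction, and norm-based clipping via jax.lax.cond.
import jax
import jax.numpy as jnp

_GRAD_CLIP_EPS = 1e-6  # same epsilon name as in the diff


def dual_vector(y):
  """Scale a gradient tree to unit global L2 norm."""
  norm = jnp.sqrt(sum(jnp.sum(g**2) for g in jax.tree_util.tree_leaves(y)))
  return jax.tree_util.tree_map(lambda g: g / (norm + _GRAD_CLIP_EPS), y)


rho, grad_clip = 0.05, 1.0
params = {'w': jnp.ones((2,)), 'b': jnp.zeros(())}
grads = {'w': jnp.array([3.0, 4.0]), 'b': jnp.array(1.0)}

# Perturb the parameters in the ascent direction (first reflowed call).
updates = dual_vector(grads)
noised_params = jax.tree_util.tree_map(
    lambda p, u: p + rho * u, params, updates)

# Clip the update tree by its global norm (second reflowed call).
updates_norm = jnp.sqrt(
    sum(jnp.sum(g**2) for g in jax.tree_util.tree_leaves(grads)))
scaled_updates = jax.tree.map(
    lambda x: x / (updates_norm + _GRAD_CLIP_EPS) * grad_clip, grads)
updates = jax.lax.cond(updates_norm > grad_clip,
                       lambda _: scaled_updates,
                       lambda _: grads,
                       None)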

reference_algorithms/paper_baselines/shampoo/jax/distributed_shampoo.py
18 additions, 10 deletions

@@ -595,8 +595,8 @@ def matrix_inverse_pth_root(
 
   if padding_start is not None:
     # Zero out padding in identity as well for convergence checks.
-    ix = (jnp.arange(matrix_size, dtype=jnp.int32)
-          < padding_start).astype(matrix.dtype)
+    ix = (jnp.arange(matrix_size, dtype=jnp.int32) < padding_start).astype(
+        matrix.dtype)
     matrix *= ix[jnp.newaxis, :]
     matrix *= ix[:, jnp.newaxis]
     identity *= ix

@@ -815,8 +815,8 @@ def matrix_inverse_pth_root_eigh(
   alpha = jnp.asarray(-1.0 / p, _MAT_INV_PTH_ROOT_DTYPE)
   identity = jnp.eye(matrix_size, dtype=_MAT_INV_PTH_ROOT_DTYPE)
   if padding_start is not None:
-    ix = (jnp.arange(matrix_size, dtype=jnp.int32)
-          < padding_start).astype(matrix.dtype)
+    ix = (jnp.arange(matrix_size, dtype=jnp.int32) < padding_start).astype(
+        matrix.dtype)
     matrix *= ix[jnp.newaxis, :]
     matrix *= ix[:, jnp.newaxis]
     identity *= ix

@@ -1809,13 +1809,17 @@ def sharded_update_fn(grads, state, params):
     ))
 
     new_stats_flat = jax.tree.map(
-        lambda g, s, p: _compute_stats(g, s, p, state.count),
+        lambda g,
+        s,
+        p: _compute_stats(g, s, p, state.count),
         grads_flat,
         stats_flat,
        params_flat)
 
     outputs = jax.tree.map(
-        lambda g, s, p: _transform_grad(g, s, p, state.count),
+        lambda g,
+        s,
+        p: _transform_grad(g, s, p, state.count),
        grads_flat,
        new_stats_flat,
        params_flat)

@@ -1919,8 +1923,8 @@ def _internal_inverse_pth_root_all():
       errors = metrics.inverse_pth_root_errors
       errors = errors.reshape((-1, 1, 1))
       predicate = jnp.logical_or(
-          jnp.isnan(errors), errors
-          >= inverse_failure_threshold).astype(new_preconditioners.dtype)
+          jnp.isnan(errors),
+          errors >= inverse_failure_threshold).astype(new_preconditioners.dtype)
       # TODO(rohananil): Check for numerical instabilities.
       new_conditional_preconditioners = (
           predicate * global_stats.preconditioners +

@@ -2438,7 +2442,9 @@ def update_fn(grads, state, params):
     stats_grads = treedef.flatten_up_to(grads_custom)
 
     new_stats_flat = jax.tree.map(
-        lambda g, s, p: _compute_stats(g, s, p, state.count),
+        lambda g,
+        s,
+        p: _compute_stats(g, s, p, state.count),
         stats_grads,
         stats_flat,
         params_flat)

@@ -2447,7 +2453,9 @@ def update_fn(grads, state, params):
         params_flat,
         state.count)
     outputs = jax.tree.map(
-        lambda g, s, p: _transform_grad(g, s, p, state.count),
+        lambda g,
+        s,
+        p: _transform_grad(g, s, p, state.count),
         grads_flat,
         new_stats_flat,
         params_flat)
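
All six Shampoo hunks are the same kind of reflow: the padding-mask expression in matrix_inverse_pth_root and matrix_inverse_pth_root_eigh, the lambdas passed to jax.tree.map, and the jnp.logical_or predicate are wrapped at different points but compute the same thing. A small sketch of the padding-mask idiom, with made-up sizes: padded rows and columns of the statistics matrix, and the matching entries of the identity used for convergence checks, are zeroed out.

# Sketch of the padding-mask idiom reflowed above: build a 0/1 mask from a
# comparison, cast it to the matrix dtype, and zero the padded rows/columns.
import jax.numpy as jnp

matrix_size, padding_start = 4, 2          # illustrative values
matrix = jnp.ones((matrix_size, matrix_size))
identity = jnp.eye(matrix_size)

ix = (jnp.arange(matrix_size, dtype=jnp.int32) < padding_start).astype(
    matrix.dtype)
matrix = matrix * ix[jnp.newaxis, :]       # zero padded columns
matrix = matrix * ix[:, jnp.newaxis]       # zero padded rows
identity = identity * ix                   # keep the convergence check consistent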

reference_algorithms/target_setting_algorithms/jax_nadamw.py
4 additions, 6 deletions

@@ -108,9 +108,8 @@ def update_fn(updates, state, params=None):
     mu_hat = _update_moment(updates, mu, b1, 1)
     mu_hat = mu_hat if not debias else _bias_correction(mu_hat, b1, count)
     nu_hat = nu if not debias else _bias_correction(nu, b2, count)
-    updates = jax.tree.map(lambda m, v: m / (raise_power(v + eps_root) + eps),
-                           mu_hat,
-                           nu_hat)
+    updates = jax.tree.map(
+        lambda m, v: m / (raise_power(v + eps_root) + eps), mu_hat, nu_hat)
     return updates, ScaleByAdamState(count=count, mu=mu, nu=nu)
 
   return optax.GradientTransformation(init_fn, update_fn)

@@ -125,9 +124,8 @@ class ScaleByAdamState(NamedTuple):
 
 def _update_moment(updates, moments, decay, order):
  """Compute the exponential moving average of the `order-th` moment."""
-  return jax.tree.map(lambda g, t: (1 - decay) * (g**order) + decay * t,
-                      updates,
-                      moments)
+  return jax.tree.map(
+      lambda g, t: (1 - decay) * (g**order) + decay * t, updates, moments)
 
 
 def _bias_correction(moment, decay, count):
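
The target-setting jax_nadamw.py receives the identical reflow as the nadamw submission above. Since both hunks end by returning optax.GradientTransformation(init_fn, update_fn), here is a hedged sketch of how such a pair is packaged and used; the state class and the plain-EMA update rule below are deliberate simplifications, not the file's ScaleByAdamState or NAdamW rule.

# Sketch of wrapping an (init_fn, update_fn) pair in an
# optax.GradientTransformation, as the reflowed code above does.
from typing import NamedTuple

import jax
import jax.numpy as jnp
import optax


class SimpleEmaState(NamedTuple):
  """Toy stand-in for the file's ScaleByAdamState."""
  mu: optax.Updates


def scale_by_simple_ema(decay: float = 0.9) -> optax.GradientTransformation:
  """Plain EMA of the gradients; a much-simplified scale_by_* transform."""

  def init_fn(params):
    return SimpleEmaState(mu=jax.tree.map(jnp.zeros_like, params))

  def update_fn(updates, state, params=None):
    del params
    mu = jax.tree.map(
        lambda g, t: (1 - decay) * g + decay * t, updates, state.mu)
    return mu, SimpleEmaState(mu=mu)

  return optax.GradientTransformation(init_fn, update_fn)


params = {'w': jnp.ones((2,))}
tx = scale_by_simple_ema()
opt_state = tx.init(params)
grads = {'w': jnp.full((2,), 2.0)}
updates, opt_state = tx.update(grads, opt_state)
params = optax.apply_updates(
    params, jax.tree.map(lambda u: -0.1 * u, updates))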
