
Commit 004afbd

reformatting

1 parent fe3f9f0 · commit 004afbd

4 files changed (+16, -24 lines)

prize_qualification_baselines/external_tuning/jax_nadamw_full_budget.py

Lines changed: 4 additions & 6 deletions
@@ -123,9 +123,8 @@ def update_fn(updates, state, params=None):
     mu_hat = _update_moment(updates, mu, b1, 1)
     mu_hat = mu_hat if not debias else _bias_correction(mu_hat, b1, count)
     nu_hat = nu if not debias else _bias_correction(nu, b2, count)
-    updates = jax.tree.map(lambda m, v: m / (raise_power(v + eps_root) + eps),
-                           mu_hat,
-                           nu_hat)
+    updates = jax.tree.map(
+        lambda m, v: m / (raise_power(v + eps_root) + eps), mu_hat, nu_hat)
     return updates, ScaleByAdamState(count=count, mu=mu, nu=nu)

   return optax.GradientTransformation(init_fn, update_fn)
@@ -140,9 +139,8 @@ class ScaleByAdamState(NamedTuple):

 def _update_moment(updates, moments, decay, order):
   """Compute the exponential moving average of the `order-th` moment."""
-  return jax.tree.map(lambda g, t: (1 - decay) * (g**order) + decay * t,
-                      updates,
-                      moments)
+  return jax.tree.map(
+      lambda g, t: (1 - decay) * (g**order) + decay * t, updates, moments)


 def _bias_correction(moment, decay, count):
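Both hunks sit inside a gradient transformation that the file returns as an optax.GradientTransformation. As a rough illustration of how such a transform is typically consumed, here is a sketch under assumptions: optax.scale_by_adam stands in for the file's own transform, and nadamw_like with its defaults is hypothetical, not this repository's API.

import optax

# Sketch: an AdamW-family optimizer is a chain of moment scaling,
# decoupled weight decay, and scaling by -learning_rate.
# `optax.scale_by_adam` stands in here for the file's own transform.
def nadamw_like(learning_rate, b1=0.9, b2=0.999, eps=1e-8, weight_decay=1e-4):
  return optax.chain(
      optax.scale_by_adam(b1=b1, b2=b2, eps=eps),
      optax.add_decayed_weights(weight_decay),
      optax.scale_by_learning_rate(learning_rate))

optimizer = nadamw_like(learning_rate=1e-3)
# Usage: opt_state = optimizer.init(params)
# updates, opt_state = optimizer.update(grads, opt_state, params)
# (params must be passed to update() because of add_decayed_weights.)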

prize_qualification_baselines/external_tuning/jax_nadamw_target_setting.py

Lines changed: 4 additions & 6 deletions
@@ -123,9 +123,8 @@ def update_fn(updates, state, params=None):
     mu_hat = _update_moment(updates, mu, b1, 1)
     mu_hat = mu_hat if not debias else _bias_correction(mu_hat, b1, count)
     nu_hat = nu if not debias else _bias_correction(nu, b2, count)
-    updates = jax.tree.map(lambda m, v: m / (raise_power(v + eps_root) + eps),
-                           mu_hat,
-                           nu_hat)
+    updates = jax.tree.map(
+        lambda m, v: m / (raise_power(v + eps_root) + eps), mu_hat, nu_hat)
     return updates, ScaleByAdamState(count=count, mu=mu, nu=nu)

   return optax.GradientTransformation(init_fn, update_fn)
@@ -140,9 +139,8 @@ class ScaleByAdamState(NamedTuple):

 def _update_moment(updates, moments, decay, order):
   """Compute the exponential moving average of the `order-th` moment."""
-  return jax.tree.map(lambda g, t: (1 - decay) * (g**order) + decay * t,
-                      updates,
-                      moments)
+  return jax.tree.map(
+      lambda g, t: (1 - decay) * (g**order) + decay * t, updates, moments)


 def _bias_correction(moment, decay, count):

prize_qualification_baselines/self_tuning/jax_nadamw_full_budget.py

Lines changed: 4 additions & 6 deletions
@@ -132,9 +132,8 @@ def update_fn(updates, state, params=None):
     mu_hat = _update_moment(updates, mu, b1, 1)
     mu_hat = mu_hat if not debias else _bias_correction(mu_hat, b1, count)
     nu_hat = nu if not debias else _bias_correction(nu, b2, count)
-    updates = jax.tree.map(lambda m, v: m / (raise_power(v + eps_root) + eps),
-                           mu_hat,
-                           nu_hat)
+    updates = jax.tree.map(
+        lambda m, v: m / (raise_power(v + eps_root) + eps), mu_hat, nu_hat)
     return updates, ScaleByAdamState(count=count, mu=mu, nu=nu)

   return optax.GradientTransformation(init_fn, update_fn)
@@ -149,9 +148,8 @@ class ScaleByAdamState(NamedTuple):

 def _update_moment(updates, moments, decay, order):
   """Compute the exponential moving average of the `order-th` moment."""
-  return jax.tree.map(lambda g, t: (1 - decay) * (g**order) + decay * t,
-                      updates,
-                      moments)
+  return jax.tree.map(
+      lambda g, t: (1 - decay) * (g**order) + decay * t, updates, moments)


 def _bias_correction(moment, decay, count):

prize_qualification_baselines/self_tuning/jax_nadamw_target_setting.py

Lines changed: 4 additions & 6 deletions
@@ -132,9 +132,8 @@ def update_fn(updates, state, params=None):
     mu_hat = _update_moment(updates, mu, b1, 1)
     mu_hat = mu_hat if not debias else _bias_correction(mu_hat, b1, count)
     nu_hat = nu if not debias else _bias_correction(nu, b2, count)
-    updates = jax.tree.map(lambda m, v: m / (raise_power(v + eps_root) + eps),
-                           mu_hat,
-                           nu_hat)
+    updates = jax.tree.map(
+        lambda m, v: m / (raise_power(v + eps_root) + eps), mu_hat, nu_hat)
     return updates, ScaleByAdamState(count=count, mu=mu, nu=nu)

   return optax.GradientTransformation(init_fn, update_fn)
@@ -149,9 +148,8 @@ class ScaleByAdamState(NamedTuple):

 def _update_moment(updates, moments, decay, order):
   """Compute the exponential moving average of the `order-th` moment."""
-  return jax.tree.map(lambda g, t: (1 - decay) * (g**order) + decay * t,
-                      updates,
-                      moments)
+  return jax.tree.map(
+      lambda g, t: (1 - decay) * (g**order) + decay * t, updates, moments)


 def _bias_correction(moment, decay, count):
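For reference, the functions touched by this commit implement the bias-corrected moment estimates of NAdamW. Below is a minimal runnable sketch of what they compute, assuming a square-root raise_power and unconditional debiasing; the gradient tree and hyperparameter values are illustrative, not the repository's exact code.

import jax
import jax.numpy as jnp


def _update_moment(updates, moments, decay, order):
  """Exponential moving average of the `order`-th moment of the updates."""
  return jax.tree.map(
      lambda g, t: (1 - decay) * (g**order) + decay * t, updates, moments)


def _bias_correction(moment, decay, count):
  """Debias a moment estimate, as in Adam; approaches a no-op as count grows."""
  beta = 1 - decay**count
  return jax.tree.map(lambda t: t / beta.astype(t.dtype), moment)


# Hypothetical toy example: one bias-corrected update on a small parameter tree.
grads = {'w': jnp.array([0.1, -0.2]), 'b': jnp.array([0.05])}
mu = jax.tree.map(jnp.zeros_like, grads)  # first moment
nu = jax.tree.map(jnp.zeros_like, grads)  # second moment
b1, b2, eps, eps_root = 0.9, 0.999, 1e-8, 0.0
count = jnp.array(1, dtype=jnp.int32)

mu = _update_moment(grads, mu, b1, 1)
nu = _update_moment(grads, nu, b2, 2)
# NAdam applies the first-moment EMA once more (the Nesterov step) before
# debiasing, mirroring the `mu_hat` lines in the diff above.
mu_hat = _bias_correction(_update_moment(grads, mu, b1, 1), b1, count)
nu_hat = _bias_correction(nu, b2, count)
updates = jax.tree.map(
    lambda m, v: m / (jnp.sqrt(v + eps_root) + eps), mu_hat, nu_hat)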

0 commit comments