Skip to content

Commit 66f5ed3

Browse files
committed
fix formatting
1 parent 6c7d695 commit 66f5ed3

File tree

1 file changed

+17
-16
lines changed

1 file changed

+17
-16
lines changed

tests/test_jax_utils.py

Lines changed: 17 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -3,18 +3,19 @@
33
Run it as: pytest <path to this module>
44
"""
55

6+
from functools import partial
67
import os
78

89
from absl.testing import absltest
910
from absl.testing import parameterized
11+
import flax.linen as nn
1012
import jax
1113
import jax.numpy as jnp
12-
import flax.linen as nn
14+
from jax.tree_util import tree_leaves
15+
from jax.tree_util import tree_map
16+
from jax.tree_util import tree_structure
1317

14-
from jax.tree_util import tree_structure, tree_leaves, tree_map
1518
from algoperf.jax_utils import Dropout
16-
from functools import partial
17-
1819

1920
SEED = 1996
2021
DEFAULT_DROPOUT = 0.5
@@ -213,25 +214,25 @@ def test_jitted_updates(self, dropout_rate, mode):
213214
jitted_custom_apply = jax.jit(
214215
partial(cust_model.apply), static_argnames=['train'])
215216

216-
217217
def multiple_fwd_passes_custom_layer():
218-
for d in [i * 0.1 * dropout_rate for i in range(0, 11)]:
219-
y2 = jitted_custom_apply(
220-
initial_variables_custom,
221-
x,
222-
train=train,
223-
dropout_rate=d,
224-
rngs={"dropout": dropout_rng},
225-
)
226-
return y2
218+
for d in [i * 0.1 * dropout_rate for i in range(0, 11)]:
219+
y2 = jitted_custom_apply(
220+
initial_variables_custom,
221+
x,
222+
train=train,
223+
dropout_rate=d,
224+
rngs={"dropout": dropout_rng},
225+
)
226+
return y2
227227

228228
def multiple_fwd_passes_original_layer():
229-
for d in [i * 0.1 * dropout_rate for i in range(0, 11)]:
230-
y1 = jitted_original_apply(
229+
for d in [i * 0.1 * dropout_rate for i in range(0, 11)]:
230+
y1 = jitted_original_apply(
231231
initial_variables_original,
232232
x,
233233
train=train,
234234
rngs={"dropout": dropout_rng})
235235

236+
236237
if __name__ == "__main__":
237238
absltest.main()

0 commit comments

Comments
 (0)