|
3 | 3 | Run it as: pytest <path to this module> |
4 | 4 | """ |
5 | 5 |
|
| 6 | +from functools import partial |
6 | 7 | import os |
7 | 8 |
|
8 | 9 | from absl.testing import absltest |
9 | 10 | from absl.testing import parameterized |
| 11 | +import flax.linen as nn |
10 | 12 | import jax |
11 | 13 | import jax.numpy as jnp |
12 | | -import flax.linen as nn |
| 14 | +from jax.tree_util import tree_leaves |
| 15 | +from jax.tree_util import tree_map |
| 16 | +from jax.tree_util import tree_structure |
13 | 17 |
|
14 | | -from jax.tree_util import tree_structure, tree_leaves, tree_map |
15 | 18 | from algoperf.jax_utils import Dropout |
16 | | -from functools import partial |
17 | | - |
18 | 19 |
|
19 | 20 | SEED = 1996 |
20 | 21 | DEFAULT_DROPOUT = 0.5 |
@@ -213,25 +214,25 @@ def test_jitted_updates(self, dropout_rate, mode): |
213 | 214 | jitted_custom_apply = jax.jit( |
214 | 215 | partial(cust_model.apply), static_argnames=['train']) |
215 | 216 |
|
216 | | - |
217 | 217 | def multiple_fwd_passes_custom_layer(): |
218 | | - for d in [i * 0.1 * dropout_rate for i in range(0, 11)]: |
219 | | - y2 = jitted_custom_apply( |
220 | | - initial_variables_custom, |
221 | | - x, |
222 | | - train=train, |
223 | | - dropout_rate=d, |
224 | | - rngs={"dropout": dropout_rng}, |
225 | | - ) |
226 | | - return y2 |
| 218 | + for d in [i * 0.1 * dropout_rate for i in range(0, 11)]: |
| 219 | + y2 = jitted_custom_apply( |
| 220 | + initial_variables_custom, |
| 221 | + x, |
| 222 | + train=train, |
| 223 | + dropout_rate=d, |
| 224 | + rngs={"dropout": dropout_rng}, |
| 225 | + ) |
| 226 | + return y2 |
227 | 227 |
|
228 | 228 | def multiple_fwd_passes_original_layer(): |
229 | | - for d in [i * 0.1 * dropout_rate for i in range(0, 11)]: |
230 | | - y1 = jitted_original_apply( |
| 229 | + for d in [i * 0.1 * dropout_rate for i in range(0, 11)]: |
| 230 | + y1 = jitted_original_apply( |
231 | 231 | initial_variables_original, |
232 | 232 | x, |
233 | 233 | train=train, |
234 | 234 | rngs={"dropout": dropout_rng}) |
235 | 235 |
|
| 236 | + |
236 | 237 | if __name__ == "__main__": |
237 | 238 | absltest.main() |
0 commit comments