
Commit 09aca7f

Fix formatting
1 parent 8bca401 commit 09aca7f

2 files changed: +157 -174 lines changed

algoperf/jax_utils.py

Lines changed: 74 additions & 75 deletions
@@ -1,95 +1,92 @@
 from collections.abc import Sequence
 
 import flax.linen as nn
-from flax.linen.module import compact
-from flax.linen.module import merge_param
-from flax.linen.module import Module
-from flax.typing import PRNGKey
 import jax
-from jax import lax
-from jax import random
 import jax.numpy as jnp
+from flax.linen.module import Module, compact, merge_param
+from flax.typing import PRNGKey
+from jax import lax, random
 
 
 # Custom Layers
 class Dropout(Module):
   # pylint: disable=line-too-long
   """Create a dropout layer.
-  Forked from
-  https://flax-linen.readthedocs.io/en/latest/_modules/flax/linen/stochastic.html#Dropout.
-  The reference dropout implementation is modified support changes
-  to dropout rate during training by:
-  1) adding rate argument to the __call__ method.
-  2) removing the if-else condition to check for edge cases, which
-  will trigger a recompile for jitted code.
-
-  .. note::
-    When using :meth:`Module.apply() <flax.linen.Module.apply>`, make sure
-    to include an RNG seed named ``'dropout'``. Dropout isn't necessary for
-    variable initialization.
-
-  Example usage::
-
-    >>> import flax.linen as nn
-    >>> import jax, jax.numpy as jnp
-
-    >>> class MLP(nn.Module):
-    ...   @nn.compact
-    ...   def __call__(self, x, train):
-    ...     x = nn.Dense(4)(x)
-    ...     x = nn.Dropout(0.5, deterministic=not train)(x)
-    ...     return x
-
-    >>> model = MLP()
-    >>> x = jnp.ones((1, 3))
-    >>> variables = model.init(jax.random.key(0), x, train=False) # don't use dropout
-    >>> model.apply(variables, x, train=False) # don't use dropout
-    Array([[-0.17875527, 1.6255447 , -1.2431065 , -0.02554005]], dtype=float32)
-    >>> model.apply(variables, x, train=True, rngs={'dropout': jax.random.key(1)}) # use dropout
-    Array([[-0.35751054, 3.2510893 , 0.        , 0.        ]], dtype=float32)
-
-  Attributes:
-    rate: the dropout probability. (_not_ the keep rate!)
-    broadcast_dims: dimensions that will share the same dropout mask
-    deterministic: if false the inputs are scaled by ``1 / (1 - rate)``
-      and masked, whereas if true, no mask is applied and the inputs are
-      returned as is.
-    rng_collection: the rng collection name to use when requesting an rng
-      key.
-  """
+  Forked from
+  https://flax-linen.readthedocs.io/en/latest/_modules/flax/linen/stochastic.html#Dropout.
+  The reference dropout implementation is modified support changes
+  to dropout rate during training by:
+  1) adding rate argument to the __call__ method.
+  2) removing the if-else condition to check for edge cases, which
+  will trigger a recompile for jitted code.
+
+  .. note::
+    When using :meth:`Module.apply() <flax.linen.Module.apply>`, make sure
+    to include an RNG seed named ``'dropout'``. Dropout isn't necessary for
+    variable initialization.
+
+  Example usage::
+
+    >>> import flax.linen as nn
+    >>> import jax, jax.numpy as jnp
+
+    >>> class MLP(nn.Module):
+    ...   @nn.compact
+    ...   def __call__(self, x, train):
+    ...     x = nn.Dense(4)(x)
+    ...     x = nn.Dropout(0.5, deterministic=not train)(x)
+    ...     return x
+
+    >>> model = MLP()
+    >>> x = jnp.ones((1, 3))
+    >>> variables = model.init(jax.random.key(0), x, train=False) # don't use dropout
+    >>> model.apply(variables, x, train=False) # don't use dropout
+    Array([[-0.17875527, 1.6255447 , -1.2431065 , -0.02554005]], dtype=float32)
+    >>> model.apply(variables, x, train=True, rngs={'dropout': jax.random.key(1)}) # use dropout
+    Array([[-0.35751054, 3.2510893 , 0.        , 0.        ]], dtype=float32)
+
+  Attributes:
+    rate: the dropout probability. (_not_ the keep rate!)
+    broadcast_dims: dimensions that will share the same dropout mask
+    deterministic: if false the inputs are scaled by ``1 / (1 - rate)``
+      and masked, whereas if true, no mask is applied and the inputs are
+      returned as is.
+    rng_collection: the rng collection name to use when requesting an rng
+      key.
+  """
 
   rate: float | None = None
   broadcast_dims: Sequence[int] = ()
   deterministic: bool | None = None
-  rng_collection: str = "dropout"
+  rng_collection: str = 'dropout'
   legacy: bool = False
 
   @compact
   def __call__(
-      self,
-      inputs,
-      deterministic: bool | None = None,
-      rate: float | None = None,
-      rng: PRNGKey | None = None,
+    self,
+    inputs,
+    deterministic: bool | None = None,
+    rate: float | None = None,
+    rng: PRNGKey | None = None,
   ):
     """Applies a random dropout mask to the input.
 
-    Args:
-      inputs: the inputs that should be randomly masked.
-      deterministic: if false the inputs are scaled by ``1 / (1 - rate)``
-        and masked, whereas if true, no mask is applied and the inputs are
-        returned as is.
-      rate: the dropout probability. (_not_ the keep rate!)
-      rng: an optional PRNGKey used as the random key, if not specified,
-        one will be generated using ``make_rng`` with the
-        ``rng_collection`` name.
-
-    Returns:
-      The masked inputs reweighted to preserve mean.
-    """
-    deterministic = merge_param("deterministic",
-                                self.deterministic,
-                                deterministic)
+    Args:
+      inputs: the inputs that should be randomly masked.
+      deterministic: if false the inputs are scaled by ``1 / (1 - rate)``
+        and masked, whereas if true, no mask is applied and the inputs are
+        returned as is.
+      rate: the dropout probability. (_not_ the keep rate!)
+      rng: an optional PRNGKey used as the random key, if not specified,
+        one will be generated using ``make_rng`` with the
+        ``rng_collection`` name.
+
+    Returns:
+      The masked inputs reweighted to preserve mean.
+    """
+    deterministic = merge_param(
+      'deterministic', self.deterministic, deterministic
+    )
 
     # Override self.rate if rate is passed to __call__
     if rate is None:
@@ -121,10 +118,12 @@ def __call__(
 def print_jax_model_summary(model, fake_inputs):
   """Prints a summary of the jax module."""
   tabulate_fn = nn.tabulate(
-      model,
-      jax.random.PRNGKey(0),
-      console_kwargs={
-          "force_terminal": False, "force_jupyter": False, "width": 240
-      },
+    model,
+    jax.random.PRNGKey(0),
+    console_kwargs={
+      'force_terminal': False,
+      'force_jupyter': False,
+      'width': 240,
+    },
   )
   print(tabulate_fn(fake_inputs, train=False))
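
For context, a minimal usage sketch of the forked Dropout layer (not part of this commit). It assumes the class above is importable as algoperf.jax_utils.Dropout; the TinyModel name, shapes, and rates are illustrative only. It shows why `rate` is a __call__ argument rather than only a module attribute: the dropout probability can change between training steps without retracing jitted code.

# Illustrative sketch only; TinyModel and the shapes below are hypothetical,
# not part of the commit. Assumes `Dropout` is the forked layer shown above.
import flax.linen as nn
import jax
import jax.numpy as jnp

from algoperf.jax_utils import Dropout


class TinyModel(nn.Module):

  @nn.compact
  def __call__(self, x, dropout_rate, train):
    x = nn.Dense(4)(x)
    # `rate` is supplied at call time, so it can differ from step to step.
    return Dropout(deterministic=not train)(x, rate=dropout_rate)


model = TinyModel()
x = jnp.ones((1, 3))
variables = model.init(jax.random.PRNGKey(0), x, dropout_rate=0.0, train=False)


@jax.jit
def forward(variables, x, dropout_rate, rng):
  # `dropout_rate` is a traced argument, so changing it between calls reuses
  # the same compiled function instead of triggering a recompile.
  return model.apply(
    variables, x, dropout_rate, train=True, rngs={'dropout': rng}
  )


y0 = forward(variables, x, 0.1, jax.random.PRNGKey(1))
y1 = forward(variables, x, 0.5, jax.random.PRNGKey(2))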

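Similarly, a hypothetical call to print_jax_model_summary from the second hunk; TinyMLP and the input shape are illustrative, not from the commit. Because the tabulate function is invoked with train=False, the summarized module must accept a `train` argument.

# Hypothetical usage of print_jax_model_summary; TinyMLP and the input shape
# are illustrative only.
import flax.linen as nn
import jax.numpy as jnp

from algoperf.jax_utils import print_jax_model_summary


class TinyMLP(nn.Module):

  @nn.compact
  def __call__(self, x, train):
    x = nn.Dense(8)(x)
    return nn.Dense(2)(x)


# Prints a layer-by-layer summary table for the module, using (1, 16) fake inputs.
print_jax_model_summary(TinyMLP(), jnp.ones((1, 16)))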