
Commit ae8ca68

Merge pull request #864 from mlcommons/dropout_jax
Dropout JAX -> dropout_support
2 parents 3ac97ae + c2f4ed0 commit ae8ca68

22 files changed: +542 −421 lines changed


algoperf/jax_utils.py

Lines changed: 130 additions & 0 deletions (new file)

from collections.abc import Sequence

import flax.linen as nn
from flax.linen.module import compact
from flax.linen.module import merge_param
from flax.linen.module import Module
from flax.typing import PRNGKey
import jax
from jax import lax
from jax import random
import jax.numpy as jnp


# Custom Layers
class Dropout(Module):
  # pylint: disable=line-too-long
  """Create a dropout layer.

  Forked from
  https://flax-linen.readthedocs.io/en/latest/_modules/flax/linen/stochastic.html#Dropout.
  The reference dropout implementation is modified to support changes
  to the dropout rate during training by:
  1) adding a rate argument to the __call__ method.
  2) removing the if-else condition to check for edge cases, which
     will trigger a recompile for jitted code.

  .. note::
    When using :meth:`Module.apply() <flax.linen.Module.apply>`, make sure
    to include an RNG seed named ``'dropout'``. Dropout isn't necessary for
    variable initialization.

  Example usage::

    >>> import flax.linen as nn
    >>> import jax, jax.numpy as jnp

    >>> class MLP(nn.Module):
    ...   @nn.compact
    ...   def __call__(self, x, train):
    ...     x = nn.Dense(4)(x)
    ...     x = nn.Dropout(0.5, deterministic=not train)(x)
    ...     return x

    >>> model = MLP()
    >>> x = jnp.ones((1, 3))
    >>> variables = model.init(jax.random.key(0), x, train=False)  # don't use dropout
    >>> model.apply(variables, x, train=False)  # don't use dropout
    Array([[-0.17875527,  1.6255447 , -1.2431065 , -0.02554005]], dtype=float32)
    >>> model.apply(variables, x, train=True, rngs={'dropout': jax.random.key(1)})  # use dropout
    Array([[-0.35751054,  3.2510893 ,  0.        ,  0.        ]], dtype=float32)

  Attributes:
    rate: the dropout probability. (_not_ the keep rate!)
    broadcast_dims: dimensions that will share the same dropout mask
    deterministic: if false the inputs are scaled by ``1 / (1 - rate)``
      and masked, whereas if true, no mask is applied and the inputs are
      returned as is.
    rng_collection: the rng collection name to use when requesting an rng
      key.
  """

  rate: float | None = None
  broadcast_dims: Sequence[int] = ()
  deterministic: bool | None = None
  rng_collection: str = "dropout"
  legacy: bool = True

  @compact
  def __call__(
      self,
      inputs,
      deterministic: bool | None = None,
      rate: float | None = None,
      rng: PRNGKey | None = None,
  ):
    """Applies a random dropout mask to the input.

    Args:
      inputs: the inputs that should be randomly masked.
      deterministic: if false the inputs are scaled by ``1 / (1 - rate)``
        and masked, whereas if true, no mask is applied and the inputs are
        returned as is.
      rate: the dropout probability. (_not_ the keep rate!)
      rng: an optional PRNGKey used as the random key, if not specified,
        one will be generated using ``make_rng`` with the
        ``rng_collection`` name.

    Returns:
      The masked inputs reweighted to preserve mean.
    """
    deterministic = merge_param("deterministic",
                                self.deterministic,
                                deterministic)

    # Override self.rate if rate is passed to __call__.
    if rate is None:
      rate = self.rate

    if self.legacy:
      if rate == 0.0:
        return inputs

      # Prevent gradient NaNs in 1.0 edge-case.
      if rate == 1.0:
        return jnp.zeros_like(inputs)

    if deterministic:
      return inputs

    keep_prob = 1.0 - rate
    if rng is None:
      rng = self.make_rng(self.rng_collection)
    broadcast_shape = list(inputs.shape)
    for dim in self.broadcast_dims:
      broadcast_shape[dim] = 1
    mask = random.bernoulli(rng, p=keep_prob, shape=broadcast_shape)
    mask = jnp.broadcast_to(mask, inputs.shape)
    return lax.select(mask, inputs / keep_prob, jnp.zeros_like(inputs))


# Utilities for debugging
def print_jax_model_summary(model, fake_inputs):
  """Prints a summary of the jax module."""
  tabulate_fn = nn.tabulate(
      model,
      jax.random.PRNGKey(0),
      console_kwargs={
          "force_terminal": False, "force_jupyter": False, "width": 240
      },
  )
  print(tabulate_fn(fake_inputs, train=False))
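
To illustrate how the new layer is meant to be called (this sketch is not part of the commit; TinyMLP is a made-up module), a caller passes the dropout rate to the forward pass and supplies a 'dropout' RNG only when training:

# Usage sketch, not part of this commit: the dropout rate is supplied at
# call time instead of being fixed as a module attribute.
import jax
import jax.numpy as jnp
import flax.linen as nn

from algoperf.jax_utils import Dropout


class TinyMLP(nn.Module):

  @nn.compact
  def __call__(self, x, train, dropout_rate=0.0):
    x = nn.Dense(4)(x)
    x = Dropout(deterministic=not train)(x, rate=dropout_rate)
    return x


model = TinyMLP()
x = jnp.ones((1, 3))
variables = model.init(jax.random.PRNGKey(0), x, train=False)

# The same parameters can be applied with different dropout rates.
for step, rate in enumerate([0.0, 0.1, 0.2]):
  y = model.apply(
      variables,
      x,
      train=True,
      dropout_rate=rate,
      rngs={'dropout': jax.random.PRNGKey(step)})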

algoperf/workloads/cifar/cifar_jax/workload.py

Lines changed: 1 addition & 7 deletions

@@ -79,14 +79,8 @@ def sync_batch_stats(
     new_model_state['batch_stats'] = avg_fn(model_state['batch_stats'])
     return new_model_state

-  def init_model_fn(
-      self,
-      rng: spec.RandomState,
-      dropout_rate: Optional[float] = None,
-      aux_dropout_rate: Optional[float] = None) -> spec.ModelInitState:
+  def init_model_fn(self, rng: spec.RandomState) -> spec.ModelInitState:
     """Dropout is unused."""
-    del dropout_rate
-    del aux_dropout_rate
     model_cls = getattr(models, 'ResNet18')
     model = model_cls(num_classes=self._num_classes, dtype=jnp.float32)
     self._model = model

algoperf/workloads/criteo1tb/criteo1tb_jax/models.py

Lines changed: 15 additions & 10 deletions

@@ -1,11 +1,14 @@
 """A JAX implementation of DLRM-Small."""
-
 from typing import Sequence

 import flax.linen as nn
 from jax import nn as jnn
 import jax.numpy as jnp

+from algoperf.jax_utils import Dropout
+
+DROPOUT_RATE = 0.0
+

 class DLRMResNet(nn.Module):
   """Define a DLRMResNet model.
@@ -23,12 +26,13 @@ class DLRMResNet(nn.Module):
   mlp_bottom_dims: Sequence[int] = (256, 256, 256)
   mlp_top_dims: Sequence[int] = (256, 256, 256, 256, 1)
   embed_dim: int = 128
-  dropout_rate: float = 0.0
+  dropout_rate: float = DROPOUT_RATE
   use_layer_norm: bool = False  # Unused.
   embedding_init_multiplier: float = None  # Unused

   @nn.compact
-  def __call__(self, x, train):
+  def __call__(self, x, train, dropout_rate=DROPOUT_RATE):
+
     bot_mlp_input, cat_features = jnp.split(x, [self.num_dense_features], 1)
     cat_features = jnp.asarray(cat_features, dtype=jnp.int32)

@@ -88,8 +92,8 @@ def scaled_init(key, shape, dtype=jnp.float_):
               stddev=jnp.sqrt(1.0 / mlp_top_dims[layer_idx])))(
                   top_mlp_input)
       x = nn.relu(x)
-      if self.dropout_rate and layer_idx == num_layers_top - 2:
-        x = nn.Dropout(rate=self.dropout_rate, deterministic=not train)(x)
+      if dropout_rate and layer_idx == num_layers_top - 2:
+        x = Dropout(dropout_rate, deterministic=not train)(x, rate=dropout_rate)
       top_mlp_input += x
     # In the DLRM model the last layer width is always 1. We can hardcode that
     # below.
@@ -151,7 +155,8 @@ class DlrmSmall(nn.Module):
   embedding_init_multiplier: float = None

   @nn.compact
-  def __call__(self, x, train):
+  def __call__(self, x, train, dropout_rate=DROPOUT_RATE):
+
     bot_mlp_input, cat_features = jnp.split(x, [self.num_dense_features], 1)
     cat_features = jnp.asarray(cat_features, dtype=jnp.int32)

@@ -210,10 +215,10 @@ def scaled_init(key, shape, dtype=jnp.float_):
        top_mlp_input = nn.relu(top_mlp_input)
      if self.use_layer_norm:
        top_mlp_input = nn.LayerNorm()(top_mlp_input)
-      if (self.dropout_rate is not None and self.dropout_rate > 0.0 and
+      if (dropout_rate is not None and dropout_rate > 0.0 and
          layer_idx == num_layers_top - 2):
-        top_mlp_input = nn.Dropout(
-            rate=self.dropout_rate, deterministic=not train)(
-                top_mlp_input)
+        top_mlp_input = Dropout(
+            dropout_rate, deterministic=not train)(
+                top_mlp_input, rate=dropout_rate)
     logits = top_mlp_input
     return logits
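
Note that these call sites pass the rate twice, once to the Dropout constructor and once to __call__; per the implementation in algoperf/jax_utils.py, a call-time rate overrides the attribute, so the two are consistent here. A small sketch of that precedence (not part of the commit):

# Sketch, not from this commit: a rate given to __call__ overrides the
# constructor attribute; if no call-time rate is given, the attribute is used.
import jax
import jax.numpy as jnp

from algoperf.jax_utils import Dropout

x = jnp.ones((4, 16))
rngs = {'dropout': jax.random.PRNGKey(0)}

layer = Dropout(rate=0.9, deterministic=False)
y_attr = layer.apply({}, x, rngs=rngs)             # uses the attribute rate (0.9)
y_call = layer.apply({}, x, rate=0.1, rngs=rngs)   # call-time rate (0.1) wins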

algoperf/workloads/criteo1tb/criteo1tb_jax/workload.py

Lines changed: 10 additions & 9 deletions

@@ -72,36 +72,34 @@ def loss_fn(
   def init_model_fn(
       self,
       rng: spec.RandomState,
-      dropout_rate: Optional[float] = None,
-      aux_dropout_rate: Optional[float] = None,
       tabulate: Optional[bool] = False,
   ) -> spec.ModelInitState:
     """Only dropout is used."""
-    del aux_dropout_rate
     if self.use_resnet:
       model_class = models.DLRMResNet
     else:
       model_class = models.DlrmSmall
+
     self._model = model_class(
         vocab_size=self.vocab_size,
         num_dense_features=self.num_dense_features,
         mlp_bottom_dims=self.mlp_bottom_dims,
         mlp_top_dims=self.mlp_top_dims,
         embed_dim=self.embed_dim,
-        dropout_rate=dropout_rate,
         use_layer_norm=self.use_layer_norm,
         embedding_init_multiplier=self.embedding_init_multiplier)

-    params_rng, dropout_rng = jax.random.split(rng)
+    params_rng, _ = jax.random.split(rng)
     init_fake_batch_size = 2
     num_categorical_features = 26
     num_dense_features = 13
     input_size = num_dense_features + num_categorical_features
     input_shape = (init_fake_batch_size, input_size)
     init_fn = functools.partial(self._model.init, train=False)
-    initial_variables = jax.jit(init_fn)(
-        {'params': params_rng, 'dropout': dropout_rng},
-        jnp.ones(input_shape, jnp.float32))
+    initial_variables = jax.jit(init_fn)({
+        'params': params_rng,
+    },
+                                         jnp.ones(input_shape, jnp.float32))
     initial_params = initial_variables['params']
     self._param_shapes = param_utils.jax_param_shapes(initial_params)
     self._param_types = param_utils.jax_param_types(self._param_shapes)
@@ -117,14 +115,17 @@ def model_fn(
       model_state: spec.ModelAuxiliaryState,
       mode: spec.ForwardPassMode,
       rng: spec.RandomState,
-      update_batch_norm: bool) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]:
+      update_batch_norm: bool,
+      dropout_rate: float = models.DROPOUT_RATE
+  ) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]:
     del model_state
     del update_batch_norm
     inputs = augmented_and_preprocessed_input_batch['inputs']
     train = mode == spec.ForwardPassMode.TRAIN
     apply_kwargs = {'train': train}
     if train:
       apply_kwargs['rngs'] = {'dropout': rng}
+      apply_kwargs['dropout_rate'] = dropout_rate
     logits_batch = self._model.apply({'params': params}, inputs, **apply_kwargs)
     return logits_batch, None
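
A hypothetical sketch of how a caller could thread a tuned dropout rate through the new model_fn argument; workload, params, batch, and hyperparameters are illustrative names provided by a surrounding training loop, and the algoperf.spec import path is assumed rather than taken from this diff:

# Hypothetical usage sketch: pass a tuned dropout rate into the forward pass.
# `workload`, `params`, `batch`, and `hyperparameters` are assumed to come
# from the surrounding training loop; only the keyword names of model_fn
# are taken from the diff above.
from algoperf import spec  # import path assumed


def forward_pass(workload, params, batch, rng, hyperparameters):
  logits, _ = workload.model_fn(
      params,
      batch,
      model_state=None,
      mode=spec.ForwardPassMode.TRAIN,
      rng=rng,
      update_batch_norm=False,
      dropout_rate=hyperparameters.dropout_rate)
  return logits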

algoperf/workloads/fastmri/fastmri_jax/models.py

Lines changed: 14 additions & 15 deletions

@@ -13,12 +13,15 @@
 github.com/facebookresearch/fastMRI/tree/main/fastmri/data
 """
 import functools
-from typing import Optional

 import flax.linen as nn
 import jax
 import jax.numpy as jnp

+from algoperf.jax_utils import Dropout
+
+DROPOUT_RATE = 0.0
+

 def _instance_norm2d(x, axes, epsilon=1e-5):
   # promote x to at least float32, this avoids half precision computation
@@ -56,16 +59,12 @@ class UNet(nn.Module):
   num_channels: int = 32
   num_pool_layers: int = 4
   out_channels = 1
-  dropout_rate: Optional[float] = 0.0  # If None, defaults to 0.0.
+  dropout_rate: float = DROPOUT_RATE
   use_tanh: bool = False
   use_layer_norm: bool = False

   @nn.compact
-  def __call__(self, x, train=True):
-    dropout_rate = self.dropout_rate
-    if dropout_rate is None:
-      dropout_rate = 0.0
-
+  def __call__(self, x, train=True, dropout_rate=DROPOUT_RATE):
     # pylint: disable=invalid-name
     _ConvBlock = functools.partial(
         ConvBlock,
@@ -138,12 +137,12 @@ class ConvBlock(nn.Module):
     dropout_rate: Dropout probability.
   """
   out_channels: int
-  dropout_rate: float
   use_tanh: bool
   use_layer_norm: bool
+  dropout_rate: float = 0.0

   @nn.compact
-  def __call__(self, x, train=True):
+  def __call__(self, x, train=True, dropout_rate=DROPOUT_RATE):
     """Forward function.
     Note: Pytorch is NCHW and jax/flax is NHWC.
     Args:
@@ -172,9 +171,9 @@ def __call__(self, x, train=True):
     x = activation_fn(x)
     # Ref code uses dropout2d which applies the same mask for the entire channel
     # Replicated by using broadcast dims to have the same filter on HW
-    x = nn.Dropout(
-        self.dropout_rate, broadcast_dims=(1, 2), deterministic=not train)(
-            x)
+    x = Dropout(
+        dropout_rate, broadcast_dims=(1, 2), deterministic=not train)(
+            x, rate=dropout_rate)
     x = nn.Conv(
         features=self.out_channels,
         kernel_size=(3, 3),
@@ -186,9 +185,9 @@ def __call__(self, x, train=True):
     else:
       x = _instance_norm2d(x, (1, 2))
     x = activation_fn(x)
-    x = nn.Dropout(
-        self.dropout_rate, broadcast_dims=(1, 2), deterministic=not train)(
-            x)
+    x = Dropout(
+        dropout_rate, broadcast_dims=(1, 2), deterministic=not train)(
+            x, rate=dropout_rate)
     return x
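
A sketch (not part of the commit) of what broadcast_dims=(1, 2) does in the FastMRI ConvBlock: a single keep/drop decision is shared across the H and W axes of an NHWC tensor, so whole feature maps are kept or dropped together, mimicking PyTorch's Dropout2d:

# Sketch, not from this commit: broadcast_dims=(1, 2) shares one dropout
# decision across H and W, so each (sample, channel) feature map is either
# kept whole or zeroed whole.
import jax
import jax.numpy as jnp

from algoperf.jax_utils import Dropout

x = jnp.ones((2, 8, 8, 4))  # NHWC
layer = Dropout(broadcast_dims=(1, 2), deterministic=False)
y = layer.apply({}, x, rate=0.5, rngs={'dropout': jax.random.PRNGKey(0)})

# Every spatial position within a (sample, channel) slice carries the same value.
flat = y.reshape(2, -1, 4)                   # (N, H*W, C)
assert bool((flat == flat[:, :1, :]).all())  # constant across H*W per (sample, channel)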