Skip to content

Commit 9666481

Browse files
committed
Refactor pooling operation
- Pooling operation moved out of Linen - Pooling documentation duplicated in nnx
1 parent 3531187 commit 9666481

File tree

17 files changed

+743
-707
lines changed

17 files changed

+743
-707
lines changed

docs_nnx/api_reference/flax.nnx/nn/index.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ See the `NNX page <https://flax.readthedocs.io/en/latest/nnx/index.html>`__ for
1414
linear
1515
lora
1616
normalization
17+
pooling
1718
recurrent
1819
stochastic
1920

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
Pooling
2+
------------------------
3+
4+
.. automodule:: flax.nnx
5+
.. currentmodule:: flax.nnx
6+
7+
.. autofunction:: avg_pool
8+
.. autofunction:: max_pool
9+
.. autofunction:: min_pool
10+
.. autofunction:: pool

docs_nnx/guides/randomness.ipynb

Lines changed: 309 additions & 201 deletions
Large diffs are not rendered by default.

docs_nnx/guides/randomness.md

Lines changed: 135 additions & 90 deletions
Large diffs are not rendered by default.

docs_nnx/index.rst

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -97,20 +97,20 @@ Basic usage
9797
def __init__(self, din, dmid, dout, rngs: nnx.Rngs):
9898
self.linear = nnx.Linear(din, dmid, rngs=rngs)
9999
self.bn = nnx.BatchNorm(dmid, rngs=rngs)
100-
self.dropout = nnx.Dropout(0.2, rngs=rngs)
100+
self.dropout = nnx.Dropout(0.2)
101101
self.linear_out = nnx.Linear(dmid, dout, rngs=rngs)
102102

103-
def __call__(self, x):
104-
x = nnx.relu(self.dropout(self.bn(self.linear(x))))
103+
def __call__(self, x, rngs):
104+
x = nnx.relu(self.dropout(self.bn(self.linear(x)), rngs=rngs))
105105
return self.linear_out(x)
106106

107107
model = Model(2, 64, 3, rngs=nnx.Rngs(0)) # eager initialization
108108
optimizer = nnx.Optimizer(model, optax.adam(1e-3), wrt=nnx.Param)
109109

110110
@nnx.jit # automatic state management for JAX transforms
111-
def train_step(model, optimizer, x, y):
111+
def train_step(model, optimizer, x, y, rngs):
112112
def loss_fn(model):
113-
y_pred = model(x) # call methods directly
113+
y_pred = model(x, rngs) # call methods directly
114114
return ((y_pred - y) ** 2).mean()
115115

116116
loss, grads = nnx.value_and_grad(loss_fn)(model)

docs_nnx/mnist_tutorial.ipynb

Lines changed: 46 additions & 169 deletions
Large diffs are not rendered by default.

docs_nnx/mnist_tutorial.md

Lines changed: 28 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -26,15 +26,15 @@ Let’s get started!
2626

2727
If `flax` is not installed in your Python environment, use `pip` to install the package from PyPI (below, just uncomment the code in the cell if you are working from Google Colab/Jupyter Notebook):
2828

29-
```{code-cell}
29+
```{code-cell} ipython3
3030
# !pip install flax
3131
```
3232

3333
## 2. Load the MNIST dataset
3434

3535
First, you need to load the MNIST dataset and then prepare the training and testing sets via TensorFlow Datasets (TFDS). You normalize image values, shuffle the data and divide it into batches, and prefetch samples to enhance performance.
3636

37-
```{code-cell}
37+
```{code-cell} ipython3
3838
import tensorflow_datasets as tfds # TFDS to download MNIST.
3939
import tensorflow as tf # TensorFlow / `tf.data` operations.
4040
@@ -72,29 +72,30 @@ test_ds = test_ds.batch(batch_size, drop_remainder=True).prefetch(1)
7272

7373
Create a CNN for classification with Flax NNX by subclassing `nnx.Module`:
7474

75-
```{code-cell}
75+
```{code-cell} ipython3
7676
from flax import nnx # The Flax NNX API.
7777
from functools import partial
78+
from typing import Optional
7879
7980
class CNN(nnx.Module):
8081
"""A simple CNN model."""
8182
8283
def __init__(self, *, rngs: nnx.Rngs):
8384
self.conv1 = nnx.Conv(1, 32, kernel_size=(3, 3), rngs=rngs)
8485
self.batch_norm1 = nnx.BatchNorm(32, rngs=rngs)
85-
self.dropout1 = nnx.Dropout(rate=0.025, rngs=rngs)
86+
self.dropout1 = nnx.Dropout(rate=0.025)
8687
self.conv2 = nnx.Conv(32, 64, kernel_size=(3, 3), rngs=rngs)
8788
self.batch_norm2 = nnx.BatchNorm(64, rngs=rngs)
8889
self.avg_pool = partial(nnx.avg_pool, window_shape=(2, 2), strides=(2, 2))
8990
self.linear1 = nnx.Linear(3136, 256, rngs=rngs)
90-
self.dropout2 = nnx.Dropout(rate=0.025, rngs=rngs)
91+
self.dropout2 = nnx.Dropout(rate=0.025)
9192
self.linear2 = nnx.Linear(256, 10, rngs=rngs)
9293
93-
def __call__(self, x):
94-
x = self.avg_pool(nnx.relu(self.batch_norm1(self.dropout1(self.conv1(x)))))
94+
def __call__(self, x, rngs: Optional[nnx.Rngs] = None):
95+
x = self.avg_pool(nnx.relu(self.batch_norm1(self.dropout1(self.conv1(x), rngs=rngs))))
9596
x = self.avg_pool(nnx.relu(self.batch_norm2(self.conv2(x))))
9697
x = x.reshape(x.shape[0], -1) # flatten
97-
x = nnx.relu(self.dropout2(self.linear1(x)))
98+
x = nnx.relu(self.dropout2(self.linear1(x), rngs=rngs))
9899
x = self.linear2(x)
99100
return x
100101
@@ -108,18 +109,18 @@ nnx.display(model)
108109

109110
Let's put the CNN model to the test! Here, you’ll perform a forward pass with arbitrary data and print the results.
110111

111-
```{code-cell}
112+
```{code-cell} ipython3
112113
import jax.numpy as jnp # JAX NumPy
113114
114-
y = model(jnp.ones((1, 28, 28, 1)))
115+
y = model(jnp.ones((1, 28, 28, 1)), nnx.Rngs(0))
115116
y
116117
```
117118

118119
## 4. Create the optimizer and define some metrics
119120

120121
In Flax NNX, you need to create an `nnx.Optimizer` object to manage the model's parameters and apply gradients during training. `nnx.Optimizer` receives the model's reference, so that it can update its parameters, and an [Optax](https://optax.readthedocs.io/) optimizer to define the update rules. Additionally, you will define an `nnx.MultiMetric` object to keep track of the `Accuracy` and the `Average` loss.
121122

122-
```{code-cell}
123+
```{code-cell} ipython3
123124
import optax
124125
125126
learning_rate = 0.005
@@ -144,31 +145,31 @@ In addition to the `loss`, during training and testing you will also get the `lo
144145

145146
During training - the `train_step` - you will use `nnx.value_and_grad` to compute the gradients and update the model's parameters using the `optimizer` you have already defined. And during both training and testing (the `eval_step`), the `loss` and `logits` will be used to calculate the metrics.
146147

147-
```{code-cell}
148-
def loss_fn(model: CNN, batch):
149-
logits = model(batch['image'])
148+
```{code-cell} ipython3
149+
def loss_fn(model: CNN, rngs: nnx.Rngs, batch):
150+
logits = model(batch['image'], rngs)
150151
loss = optax.softmax_cross_entropy_with_integer_labels(
151152
logits=logits, labels=batch['label']
152153
).mean()
153154
return loss, logits
154155
155156
@nnx.jit
156-
def train_step(model: CNN, optimizer: nnx.Optimizer, metrics: nnx.MultiMetric, batch):
157+
def train_step(model: CNN, optimizer: nnx.Optimizer, metrics: nnx.MultiMetric, rngs: nnx.Rngs, batch):
157158
"""Train for a single step."""
158159
grad_fn = nnx.value_and_grad(loss_fn, has_aux=True)
159-
(loss, logits), grads = grad_fn(model, batch)
160+
(loss, logits), grads = grad_fn(model, rngs, batch)
160161
metrics.update(loss=loss, logits=logits, labels=batch['label']) # In-place updates.
161-
optimizer.update(grads) # In-place updates.
162+
optimizer.update(model, grads) # In-place updates.
162163
163164
@nnx.jit
164-
def eval_step(model: CNN, metrics: nnx.MultiMetric, batch):
165-
loss, logits = loss_fn(model, batch)
165+
def eval_step(model: CNN, metrics: nnx.MultiMetric, rngs: nnx.Rngs, batch):
166+
loss, logits = loss_fn(model, rngs, batch)
166167
metrics.update(loss=loss, logits=logits, labels=batch['label']) # In-place updates.
167168
```
168169

169170
In the code above, the [`nnx.jit`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/transforms.html#flax.nnx.jit) transformation decorator traces the `train_step` function for just-in-time compilation with [XLA](https://www.tensorflow.org/xla), optimizing performance on hardware accelerators, such as Google TPUs and GPUs. `nnx.jit` is a "lifted" version of the `jax.jit` transform that allows its function inputs and outputs to be Flax NNX objects. Similarly, `nnx.value_and_grad` is a lifted version of `jax.value_and_grad`. Check out [the lifted transforms guide](https://flax.readthedocs.io/en/latest/guides/transforms.html) to learn more.
170171

171-
> **Note:** The code shows how to perform several in-place updates to the model, the optimizer, and the metrics, but _state updates_ were not explicitly returned. This is because Flax NNX transformations respect _reference semantics_ for Flax NNX objects, and will propagate the state updates of the objects passed as input arguments. This is a key feature of Flax NNX that allows for a more concise and readable code. You can learn more in [Why Flax NNX](https://flax.readthedocs.io/en/latest/why.html).
172+
> **Note:** The code shows how to perform several in-place updates to the model, the optimizer, the RNG streams and the metrics, but _state updates_ were not explicitly returned. This is because Flax NNX transformations respect _reference semantics_ for Flax NNX objects, and will propagate the state updates of the objects passed as input arguments. This is a key feature of Flax NNX that allows for a more concise and readable code. You can learn more in [Why Flax NNX](https://flax.readthedocs.io/en/latest/why.html).
172173
173174

174175
## 6. Train and evaluate the model
@@ -177,7 +178,7 @@ Now, you can train the CNN model using batches of data for 10 epochs, evaluate t
177178
on the test set after each epoch, and log the training and testing metrics (the loss and
178179
the accuracy) during the process. Typically this leads to the model achieving around 99% accuracy.
179180

180-
```{code-cell}
181+
```{code-cell} ipython3
181182
from IPython.display import clear_output
182183
import matplotlib.pyplot as plt
183184
@@ -188,13 +189,15 @@ metrics_history = {
188189
'test_accuracy': [],
189190
}
190191
192+
rngs = nnx.Rngs(0)
193+
191194
for step, batch in enumerate(train_ds.as_numpy_iterator()):
192195
# Run the optimization for one step and make a stateful update to the following:
193196
# - The train state's model parameters
194197
# - The optimizer state
195198
# - The training loss and accuracy batch metrics
196199
model.train() # Switch to train mode
197-
train_step(model, optimizer, metrics, batch)
200+
train_step(model, optimizer, metrics, rngs, batch)
198201
199202
if step > 0 and (step % eval_every == 0 or step == train_steps - 1): # One training epoch has passed.
200203
# Log the training metrics.
@@ -205,7 +208,7 @@ for step, batch in enumerate(train_ds.as_numpy_iterator()):
205208
# Compute the metrics on the test set after each training epoch.
206209
model.eval() # Switch to eval mode
207210
for test_batch in test_ds.as_numpy_iterator():
208-
eval_step(model, metrics, test_batch)
211+
eval_step(model, metrics, rngs, test_batch)
209212
210213
# Log the test metrics.
211214
for metric, value in metrics.compute().items():
@@ -229,7 +232,7 @@ for step, batch in enumerate(train_ds.as_numpy_iterator()):
229232

230233
Create a `jit`-compiled model inference function (with `nnx.jit`) - `pred_step` - to generate predictions on the test set using the learned model parameters. This will enable you to visualize test images alongside their predicted labels for a qualitative assessment of model performance.
231234

232-
```{code-cell}
235+
```{code-cell} ipython3
233236
model.eval() # Switch to evaluation mode.
234237
235238
@nnx.jit
@@ -240,7 +243,7 @@ def pred_step(model: CNN, batch):
240243

241244
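We call `.eval()` before inference so that `Dropout` is disabled and `BatchNorm` uses its stored running statistics, which makes the output deterministic. Note that `.eval()` does not suppress gradient computation; in JAX, gradients are only computed when a transform such as `nnx.grad` is explicitly applied.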
242245

243-
```{code-cell}
246+
```{code-cell} ipython3
244247
test_batch = test_ds.as_numpy_iterator().next()
245248
pred = pred_step(model, test_batch)
246249

docs_nnx/nnx_basics.ipynb

Lines changed: 67 additions & 86 deletions
Large diffs are not rendered by default.

docs_nnx/nnx_basics.md

Lines changed: 25 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -90,31 +90,27 @@ to handle them, as demonstrated in later sections of this guide.
9090

9191
Flax `Module`s can be used to compose other Modules in a nested structure. These can be assigned directly as attributes, or inside an attribute of any (nested) pytree type, such as a `list`, `dict`, `tuple`, and so on.
9292

93-
The example below shows how to define a simple `MLP` by subclassing `Module`. The model consists of two `Linear` layers, a `Dropout` layer, and a `BatchNorm` layer:
93+
The example below shows how to define a simple `MLP` by subclassing `Module`. The model consists of two `Linear` layers, a `Dropout` layer, and a `BatchNorm` layer. Note that we need to pass the `__call__` method the RNG state that we want the `Dropout` layer to use.
9494

9595
```{code-cell} ipython3
9696
class MLP(nnx.Module):
9797
def __init__(self, din: int, dmid: int, dout: int, *, rngs: nnx.Rngs):
9898
self.linear1 = Linear(din, dmid, rngs=rngs)
99-
self.dropout = nnx.Dropout(rate=0.1, rngs=rngs)
99+
self.dropout = nnx.Dropout(rate=0.1)
100100
self.bn = nnx.BatchNorm(dmid, rngs=rngs)
101101
self.linear2 = Linear(dmid, dout, rngs=rngs)
102102
103-
def __call__(self, x: jax.Array):
104-
x = nnx.gelu(self.dropout(self.bn(self.linear1(x))))
103+
def __call__(self, x: jax.Array, rngs: nnx.Rngs):
104+
x = nnx.gelu(self.dropout(self.bn(self.linear1(x)), rngs=rngs))
105105
return self.linear2(x)
106106
107107
model = MLP(2, 16, 5, rngs=nnx.Rngs(0))
108108
109-
y = model(x=jnp.ones((3, 2)))
109+
y = model(x=jnp.ones((3, 2)), rngs=nnx.Rngs(1))
110110
111111
nnx.display(model)
112112
```
113113

114-
In Flax, `Dropout` is a stateful module that stores an `Rngs` object, so that it can generate new masks during the forward pass without the need for the user to pass a new key each time.
115-
116-
+++
117-
118114
### Model surgery
119115

120116
Flax `Module`s are mutable by default. This means that their structure can be changed at any time, which makes [model surgery](https://flax.readthedocs.io/en/latest/guides/surgery.html) quite easy, as any sub-Module attribute can be replaced with anything else, such as new Modules, existing shared Modules, Modules of different types, and so on. Moreover, `Variable`s can also be modified or replaced/shared.
@@ -140,7 +136,7 @@ model = MLP(2, 32, 5, rngs=rngs)
140136
model.linear1 = LoraLinear(model.linear1, 4, rngs=rngs)
141137
model.linear2 = LoraLinear(model.linear2, 4, rngs=rngs)
142138
143-
y = model(x=jnp.ones((3, 2)))
139+
y = model(x=jnp.ones((3, 2)), rngs=rngs)
144140
145141
nnx.display(model)
146142
```
@@ -161,18 +157,18 @@ model = MLP(2, 16, 10, rngs=nnx.Rngs(0))
161157
optimizer = nnx.Optimizer(model, optax.adam(1e-3), wrt=nnx.Param)
162158
163159
@nnx.jit # Automatic state management
164-
def train_step(model, optimizer, x, y):
165-
def loss_fn(model: MLP):
166-
y_pred = model(x)
160+
def train_step(model, optimizer, x, y, rngs):
161+
def loss_fn(model: MLP, rngs: nnx.Rngs):
162+
y_pred = model(x, rngs)
167163
return jnp.mean((y_pred - y) ** 2)
168164
169-
loss, grads = nnx.value_and_grad(loss_fn)(model)
165+
loss, grads = nnx.value_and_grad(loss_fn)(model, rngs)
170166
optimizer.update(model, grads) # In place updates.
171167
172168
return loss
173169
174170
x, y = jnp.ones((5, 2)), jnp.ones((5, 10))
175-
loss = train_step(model, optimizer, x, y)
171+
loss = train_step(model, optimizer, x, y, rngs)
176172
177173
print(f'{loss = }')
178174
print(f'{optimizer.step.value = }')
@@ -194,23 +190,27 @@ In the code below notice the following:
194190
1. The custom `create_model` function takes in an `nnx.Rngs` object and returns an `MLP` object. Because the `Rngs` object is forked with `split=5` and `nnx.vmap` is applied over `create_model`, a stack of 5 `MLP` objects is created.
195191
2. The `nnx.scan` is used to iteratively apply each `MLP` in the stack to the input `x`.
196192
3. `nnx.scan` (consciously) deviates from `jax.lax.scan` and instead mimics `nnx.vmap`, which is more expressive. `nnx.scan` allows specifying multiple inputs, the scan axes of each input/output, and the position of the carry.
197-
4. `State` updates for the `BatchNorm` and `Dropout` layers are automatically propagated by nnx.scan.
193+
4. `State` updates for `BatchNorm` layers are automatically propagated by `nnx.scan`.
194+
5. The `rngs` object is split into separate streams for each layer using the `fork` method.
198195

199196
```{code-cell} ipython3
200197
@nnx.vmap(in_axes=0, out_axes=0)
201-
def create_model(key: jax.Array):
202-
return MLP(10, 32, 10, rngs=nnx.Rngs(key))
198+
def create_model(rngs):
199+
return MLP(10, 32, 10, rngs=rngs)
203200
204-
keys = jax.random.split(jax.random.key(0), 5)
205-
model = create_model(keys)
206-
207-
@nnx.scan(in_axes=(0, nnx.Carry), out_axes=nnx.Carry)
208-
def forward(model: MLP, x):
209-
x = model(x)
201+
@nnx.scan(in_axes=(0, 0, nnx.Carry), out_axes=nnx.Carry)
202+
def forward(model: MLP, rngs: nnx.Rngs, x):
203+
x = model(x, rngs)
210204
return x
205+
206+
param_rngs = nnx.Rngs(0).fork(split=5)
207+
model = create_model(param_rngs)
208+
```
211209

210+
```{code-cell} ipython3
212211
x = jnp.ones((3, 10))
213-
y = forward(model, x)
212+
dropout_rngs = nnx.Rngs(1).fork(split=5)
213+
y = forward(model, dropout_rngs, x)
214214
215215
print(f'{y.shape = }')
216216
nnx.display(model)

flax/core/nn/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@
3535
swish as swish,
3636
tanh as tanh,
3737
)
38-
from flax.linen.pooling import (avg_pool as avg_pool, max_pool as max_pool)
38+
from flax.pooling import (avg_pool as avg_pool, max_pool as max_pool)
3939
from .attention import (
4040
dot_product_attention as dot_product_attention,
4141
multi_head_dot_product_attention as multi_head_dot_product_attention,

0 commit comments

Comments
 (0)