docs_nnx/mnist_tutorial.md (28 additions, 25 deletions)
@@ -26,15 +26,15 @@ Let’s get started!
If `flax` is not installed in your Python environment, use `pip` to install the package from PyPI (below, just uncomment the code in the cell if you are working from Google Colab/Jupyter Notebook):

-```{code-cell}
+```{code-cell} ipython3
# !pip install flax
```

## 2. Load the MNIST dataset

First, you need to load the MNIST dataset and then prepare the training and testing sets via Tensorflow Datasets (TFDS). You normalize image values, shuffle the data and divide it into batches, and prefetch samples to enhance performance.

-```{code-cell}
+```{code-cell} ipython3
import tensorflow_datasets as tfds # TFDS to download MNIST.
import tensorflow as tf # TensorFlow / `tf.data` operations.
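# (Illustrative sketch, not part of this diff: the rest of this cell in the tutorial
# builds the tf.data pipeline described above. The batch size, shuffle buffer, and
# seed below are assumed values, not necessarily the tutorial's.)
tf.random.set_seed(0)  # Seed `tf.data` shuffling for reproducibility.

train_ds: tf.data.Dataset = tfds.load('mnist', split='train')
test_ds: tf.data.Dataset = tfds.load('mnist', split='test')

def normalize(sample):
  # Scale pixel values to [0, 1] and keep the integer labels.
  return {'image': tf.cast(sample['image'], tf.float32) / 255.0, 'label': sample['label']}

train_ds = train_ds.map(normalize).shuffle(1024).batch(32, drop_remainder=True).prefetch(1)
test_ds = test_ds.map(normalize).batch(32, drop_remainder=True).prefetch(1)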

    x = self.avg_pool(nnx.relu(self.batch_norm1(self.dropout1(self.conv1(x), rngs=rngs))))
    x = self.avg_pool(nnx.relu(self.batch_norm2(self.conv2(x))))
    x = x.reshape(x.shape[0], -1) # flatten
-    x = nnx.relu(self.dropout2(self.linear1(x)))
+    x = nnx.relu(self.dropout2(self.linear1(x), rngs=rngs))
    x = self.linear2(x)
    return x
@@ -108,18 +109,18 @@ nnx.display(model)

Let's put the CNN model to the test! Here, you’ll perform a forward pass with arbitrary data and print the results.

-```{code-cell}
+```{code-cell} ipython3
import jax.numpy as jnp # JAX NumPy

-y = model(jnp.ones((1, 28, 28, 1)))
+y = model(jnp.ones((1, 28, 28, 1)), nnx.Rngs(0))
y
```

## 4. Create the optimizer and define some metrics

In Flax NNX, you need to create an `nnx.Optimizer` object to manage the model's parameters and apply gradients during training. `nnx.Optimizer` receives the model's reference, so that it can update its parameters, and an [Optax](https://optax.readthedocs.io/) optimizer to define the update rules. Additionally, you will define an `nnx.MultiMetric` object to keep track of the `Accuracy` and the `Average` loss.

-```{code-cell}
+```{code-cell} ipython3
import optax

learning_rate = 0.005
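# (Illustrative sketch, not part of this diff: the remainder of this cell creates the
# optimizer and metrics described above. `momentum` and the metric names are assumed
# values, and the exact `nnx.Optimizer` constructor signature depends on the Flax version.)
momentum = 0.9

optimizer = nnx.Optimizer(model, optax.adamw(learning_rate, momentum))
metrics = nnx.MultiMetric(
  accuracy=nnx.metrics.Accuracy(),
  loss=nnx.metrics.Average('loss'),
)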
@@ -144,31 +145,31 @@ In addition to the `loss`, during training and testing you will also get the `lo

During training - the `train_step` - you will use `nnx.value_and_grad` to compute the gradients and update the model's parameters using the `optimizer` you have already defined. And during both training and testing (the `eval_step`), the `loss` and `logits` will be used to calculate the metrics.

-```{code-cell}
-def loss_fn(model: CNN, batch):
-  logits = model(batch['image'])
+```{code-cell} ipython3
+def loss_fn(model: CNN, rngs: nnx.Rngs, batch):
+  logits = model(batch['image'], rngs)
  loss = optax.softmax_cross_entropy_with_integer_labels(
In the code above, the [`nnx.jit`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/transforms.html#flax.nnx.jit) transformation decorator traces the `train_step` function for just-in-time compilation with [XLA](https://www.tensorflow.org/xla), optimizing performance on hardware accelerators, such as Google TPUs and GPUs. `nnx.jit` is a "lifted" version of the `jax.jit` transform that allows its function inputs and outputs to be Flax NNX objects. Similarly, `nnx.value_and_grad` is a lifted version of `jax.value_and_grad`. Check out [the lifted transforms guide](https://flax.readthedocs.io/en/latest/guides/transforms.html) to learn more.

-> **Note:** The code shows how to perform several in-place updates to the model, the optimizer, and the metrics, but _state updates_ were not explicitly returned. This is because Flax NNX transformations respect _reference semantics_ for Flax NNX objects, and will propagate the state updates of the objects passed as input arguments. This is a key feature of Flax NNX that allows for a more concise and readable code. You can learn more in [Why Flax NNX](https://flax.readthedocs.io/en/latest/why.html).
+> **Note:** The code shows how to perform several in-place updates to the model, the optimizer, the RNG streams, and the metrics, but _state updates_ were not explicitly returned. This is because Flax NNX transformations respect _reference semantics_ for Flax NNX objects, and will propagate the state updates of the objects passed as input arguments. This is a key feature of Flax NNX that allows for more concise and readable code. You can learn more in [Why Flax NNX](https://flax.readthedocs.io/en/latest/why.html).
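
For reference, the updated steps described above might look roughly like the sketch below. It is not shown in full in this diff; it assumes `loss_fn` returns `(loss, logits)` as in the full tutorial, and the exact `optimizer.update` signature depends on the Flax version.

```python
@nnx.jit
def train_step(model: CNN, optimizer: nnx.Optimizer, metrics: nnx.MultiMetric,
               rngs: nnx.Rngs, batch):
  """Train for a single step: compute gradients, update parameters, record metrics."""
  grad_fn = nnx.value_and_grad(loss_fn, has_aux=True)
  (loss, logits), grads = grad_fn(model, rngs, batch)
  metrics.update(loss=loss, logits=logits, labels=batch['label'])  # In-place metric update.
  optimizer.update(model, grads)  # In-place parameter update.

@nnx.jit
def eval_step(model: CNN, metrics: nnx.MultiMetric, rngs: nnx.Rngs, batch):
  loss, logits = loss_fn(model, rngs, batch)
  metrics.update(loss=loss, logits=logits, labels=batch['label'])
```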
## 6. Train and evaluate the model
@@ -177,7 +178,7 @@ Now, you can train the CNN model using batches of data for 10 epochs, evaluate t
on the test set after each epoch, and log the training and testing metrics (the loss and
the accuracy) during the process. Typically this leads to the model achieving around 99% accuracy.

-```{code-cell}
+```{code-cell} ipython3
from IPython.display import clear_output
import matplotlib.pyplot as plt
@@ -188,13 +189,15 @@ metrics_history = {
  'test_accuracy': [],
}

+rngs = nnx.Rngs(0)
+
for step, batch in enumerate(train_ds.as_numpy_iterator()):
  # Run the optimization for one step and make a stateful update to the following:

  if step > 0 and (step % eval_every == 0 or step == train_steps - 1): # One training epoch has passed.
    # Log the training metrics.
@@ -205,7 +208,7 @@ for step, batch in enumerate(train_ds.as_numpy_iterator()):
    # Compute the metrics on the test set after each training epoch.
    model.eval() # Switch to eval mode
    for test_batch in test_ds.as_numpy_iterator():
-      eval_step(model, metrics, test_batch)
+      eval_step(model, metrics, rngs, test_batch)

    # Log the test metrics.
    for metric, value in metrics.compute().items():
@@ -229,7 +232,7 @@ for step, batch in enumerate(train_ds.as_numpy_iterator()):

Create a `jit`-compiled model inference function (with `nnx.jit`) - `pred_step` - to generate predictions on the test set using the learned model parameters. This will enable you to visualize test images alongside their predicted labels for a qualitative assessment of model performance.

We call `.eval()` before inference so that `Dropout` is disabled and `BatchNorm` uses its stored running statistics, which makes the predictions deterministic. Note that `.eval()` only switches module behavior; it does not affect gradient computation.
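
A rough sketch of such a `pred_step` follows; the exact signature in the updated tutorial may differ, and the `rngs` argument here mirrors the other steps even though it goes unused once `model.eval()` disables `Dropout`.

```python
@nnx.jit
def pred_step(model: CNN, rngs: nnx.Rngs, batch):
  logits = model(batch['image'], rngs)
  return logits.argmax(axis=1)  # Predicted class for each test image.

model.eval()  # Dropout off; BatchNorm uses its stored running statistics.
test_batch = next(test_ds.as_numpy_iterator())
predicted_labels = pred_step(model, rngs, test_batch)
```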

docs_nnx/nnx_basics.md (25 additions, 25 deletions)
@@ -90,31 +90,27 @@ to handle them, as demonstrated in later sections of this guide.
Flax `Module`s can be used to compose other Modules in a nested structure. These can be assigned directly as attributes, or inside an attribute of any (nested) pytree type, such as a `list`, `dict`, `tuple`, and so on.

-The example below shows how to define a simple `MLP` by subclassing `Module`. The model consists of two `Linear` layers, a `Dropout` layer, and a `BatchNorm` layer:
+The example below shows how to define a simple `MLP` by subclassing `Module`. The model consists of two `Linear` layers, a `Dropout` layer, and a `BatchNorm` layer. Note that we need to pass the `__call__` method the RNG state that we want the `Dropout` layer to use.

-    x = nnx.gelu(self.dropout(self.bn(self.linear1(x))))
+  def __call__(self, x: jax.Array, rngs: nnx.Rngs):
+    x = nnx.gelu(self.dropout(self.bn(self.linear1(x)), rngs=rngs))
    return self.linear2(x)

model = MLP(2, 16, 5, rngs=nnx.Rngs(0))

-y = model(x=jnp.ones((3, 2)))
+y = model(x=jnp.ones((3, 2)), rngs=nnx.Rngs(1))

nnx.display(model)
```

-In Flax, `Dropout` is a stateful module that stores an `Rngs` object, so that it can generate new masks during the forward pass without the need for the user to pass a new key each time.
-
-+++
-

### Model surgery
Flax `Module`s are mutable by default. This means that their structure can be changed at any time, which makes [model surgery](https://flax.readthedocs.io/en/latest/guides/surgery.html) quite easy, as any sub-Module attribute can be replaced with anything else, such as new Modules, existing shared Modules, Modules of different types, and so on. Moreover, `Variable`s can also be modified or replaced/shared.
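
As a minimal illustration (assuming the `MLP(2, 16, 5)` defined above, whose `linear1` maps 2 to 16 features; the shared-attribute name is made up), such surgery can be as simple as assigning to an attribute:

```python
model = MLP(2, 16, 5, rngs=nnx.Rngs(0))

# Replace a sub-Module with a freshly initialized one.
model.linear1 = nnx.Linear(2, 16, rngs=nnx.Rngs(1))

# Share an existing sub-Module by assigning the same object to another attribute.
model.linear_shared = model.linear2
```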
@@ -140,7 +136,7 @@ model = MLP(2, 32, 5, rngs=rngs)
  optimizer.update(model, grads) # In place updates.

  return loss

x, y = jnp.ones((5, 2)), jnp.ones((5, 10))
-loss = train_step(model, optimizer, x, y)
+loss = train_step(model, optimizer, x, y, rngs)

print(f'{loss = }')
print(f'{optimizer.step.value = }')
@@ -194,23 +190,27 @@ In the code below notice the following:
1. The custom `create_model` function takes in a key and returns an `MLP` object. Since you create five keys and use `nnx.vmap` over `create_model`, a stack of 5 `MLP` objects is created (see the sketch after this list).
2. `nnx.scan` is used to iteratively apply each `MLP` in the stack to the input `x`.
3. `nnx.scan` (consciously) deviates from `jax.lax.scan` and instead mimics `nnx.vmap`, which is more expressive. `nnx.scan` allows specifying multiple inputs, the scan axes of each input/output, and the position of the carry.
-4.`State` updates for the `BatchNorm` and `Dropout` layers are automatically propagated by nnx.scan.
+4. `State` updates for `BatchNorm` layers are automatically propagated by `nnx.scan`.
+5. The `rngs` object is split into separate streams for each layer using the `fork` method.
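
For reference, the pattern in items 1-3 looks roughly like the sketch below. This is an illustration only: it assumes an `MLP` whose `__call__` takes just `x` (it omits the `Dropout`/`rngs` threading from items 4-5), and it uses `MLP(10, 32, 10)` so the stacked layers can be chained.

```python
import jax
import jax.numpy as jnp
from flax import nnx

def create_model(key: jax.Array):
  return MLP(10, 32, 10, rngs=nnx.Rngs(key))

# Item 1: vmap over create_model turns 5 keys into a stack of 5 MLPs
# (every parameter gains a leading axis of size 5).
create_model = nnx.vmap(create_model, in_axes=0, out_axes=0)
model = create_model(jax.random.split(jax.random.key(0), 5))

# Items 2-3: scan over the stacked MLPs, carrying the activations through each one.
def forward(model: MLP, x):
  return model(x)

forward = nnx.scan(forward, in_axes=(0, nnx.Carry), out_axes=nnx.Carry)

y = forward(model, jnp.ones((3, 10)))
```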