Commit 3531187

Author: Flax Authors
Merge pull request #5056 from chapman20j:just_set_mode
PiperOrigin-RevId: 826148879
2 parents 6418755 + fd0b62d

File tree: 10 files changed, +632 -0 lines changed

examples/lm1b_nnx/models_test.py

Lines changed: 96 additions & 0 deletions
@@ -291,6 +291,102 @@ def test_forward_decode(self):
     for output_nnx, output_linen in zip(outputs_nnx, outputs_linen):
       assert jnp.allclose(output_nnx, output_linen, atol=1e-5)

+  def test_forward_eval_set_mode(self):
+    _, config = get_transformer_config(
+      axis_rules=default.MeshRules(
+        embed='model',
+        mlp='data',
+        kv=None,
+        vocab=None,
+      ),
+      deterministic=True,
+      decode=False,
+    )
+    # Set dropout rates to zero to avoid creating dropout states.
+    config.dropout_rate = 0.0
+    config.attention_dropout_rate = 0.0
+
+    model_nnx = nnx.eval_shape(lambda: TransformerLM(config, rngs=nnx.Rngs(0)))
+    _, params_nnx = nnx.split(model_nnx, nnx.Param)
+
+    model_linen = TransformerLinen(config)
+
+    sample_inputs = random.randint(random.PRNGKey(0), (1, 3), 0, 20)
+    params_linen = model_linen.init(random.key(0), sample_inputs)['params']
+
+    self.transfer_params(config, params_nnx, params_linen)
+    nnx.update(model_nnx, params_nnx)
+
+    det_model = nnx.set_mode(model_nnx, deterministic=True, decode=False)
+    output_nnx = det_model(sample_inputs)
+
+    output_linen: jax.Array = model_linen.apply(
+      {'params': params_linen}, sample_inputs
+    )
+
+    assert jnp.allclose(output_nnx, output_linen, atol=1e-5)
+
+  def test_forward_decode_set_mode(self):
+    batch_size = 2
+
+    _, config = get_transformer_config(
+      axis_rules=default.MeshRules(
+        embed='model',
+        mlp='data',
+        kv=None,
+        vocab=None,
+      ),
+      deterministic=True,
+      decode=True,
+    )
+    # Set dropout rates to zero to avoid creating dropout states.
+    config.dropout_rate = 0.0
+    config.attention_dropout_rate = 0.0
+
+    model_nnx = nnx.eval_shape(lambda: TransformerLM(config, rngs=nnx.Rngs(0)))
+    for _path, m in model_nnx.iter_modules():
+      if isinstance(m, HasCache):
+        input_shape = (batch_size, config.max_len, config.emb_dim)
+        m.init_cache(input_shape, dtype=config.dtype)
+
+    _, params_nnx, cache_nnx = nnx.split(model_nnx, nnx.Param, nnx.Cache)
+
+    model_linen = TransformerLinen(config)
+
+    flax_init_inputs = random.randint(
+      random.PRNGKey(0), (batch_size, config.max_len), 0, config.vocab_size
+    )
+    ar_decode_inputs = random.randint(
+      random.PRNGKey(0), (3, batch_size, 1), 0, config.vocab_size
+    )
+    variables = model_linen.init(random.key(0), flax_init_inputs)
+    params_linen = variables['params']
+    cache_linen = variables['cache']
+
+    self.transfer_params(config, params_nnx, params_linen)
+    self.transfer_cache(config, cache_nnx, cache_linen)
+    nnx.update(model_nnx, params_nnx, cache_nnx)
+    det_model = nnx.set_mode(model_nnx, deterministic=True, decode=True)
+
+    outputs_nnx = []
+    outputs_linen = []
+
+    for inputs in ar_decode_inputs:
+      output_nnx = det_model(inputs)
+      outputs_nnx.append(output_nnx)
+
+    output_linen: jax.Array
+    for inputs in ar_decode_inputs:
+      output_linen, updates = model_linen.apply(
+        {'params': params_linen, 'cache': cache_linen},
+        inputs,
+        mutable=['cache'],
+      )
+      cache_linen = updates['cache']
+      outputs_linen.append(output_linen)
+
+    for output_nnx, output_linen in zip(outputs_nnx, outputs_linen):
+      assert jnp.allclose(output_nnx, output_linen, atol=1e-5)

 if __name__ == '__main__':
   absltest.main()
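
Both new tests drive the NNX model through ``nnx.set_mode`` instead of rebuilding it in the target mode. The essential behavior, sketched on a bare ``nnx.Dropout`` (this snippet is illustrative and not part of the test file):

  from flax import nnx

  layer = nnx.Dropout(0.5, deterministic=False)
  eval_layer = nnx.set_mode(layer, deterministic=True)
  # set_mode returns a copy with the static attribute flipped;
  # the original module is left untouched.
  assert eval_layer.deterministic
  assert not layer.deterministic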

flax/nnx/__init__.py

Lines changed: 3 additions & 0 deletions
@@ -47,6 +47,9 @@
 from .helpers import TrainState as TrainState
 from .module import M as M
 from .module import Module as Module
+from .module import set_mode as set_mode
+from .module import train_mode as train_mode
+from .module import eval_mode as eval_mode
 from .module import iter_children as iter_children, iter_modules as iter_modules
 from .graph import merge as merge
 from .graph import UpdateContext as UpdateContext
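
These re-exports make the helpers available at the top level as ``nnx.set_mode``, ``nnx.train_mode``, and ``nnx.eval_mode``. A minimal usage sketch; the ``MLP`` class below is a placeholder, not part of this commit:

  from flax import nnx

  class MLP(nnx.Module):
    def __init__(self, rngs: nnx.Rngs):
      self.linear = nnx.Linear(4, 4, rngs=rngs)
      self.dropout = nnx.Dropout(0.1, rngs=rngs)

    def __call__(self, x):
      return self.dropout(self.linear(x))

  model = MLP(nnx.Rngs(0))
  train_model = nnx.train_mode(model)  # dropout.deterministic -> False
  eval_model = nnx.eval_mode(model)    # dropout.deterministic -> True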

flax/nnx/module.py

Lines changed: 132 additions & 0 deletions
@@ -427,6 +427,138 @@ def eval(self, **attributes):
       raise_if_not_found=False,
     )

+def set_mode(
+  node: A,
+  /,
+  *,
+  only: filterlib.Filter = ...,
+  raise_if_not_found: bool = True,
+  **kwargs,
+) -> A:
+  """Creates a new node with static attributes updated according to ``**kwargs``.
+
+  The new node contains references to the jax arrays in the original node. If
+  a kwarg is not consumed by any module, this function raises a ValueError;
+  to support this, modules' ``set_mode`` methods must return any unused kwargs.
+
+  Example::
+
+    >>> from flax import nnx
+    ...
+    >>> class Block(nnx.Module):
+    ...   def __init__(self, din, dout, *, rngs: nnx.Rngs):
+    ...     self.linear = nnx.Linear(din, dout, rngs=rngs)
+    ...     self.dropout = nnx.Dropout(0.5, deterministic=False)
+    ...     self.batch_norm = nnx.BatchNorm(10, use_running_average=False, rngs=rngs)
+    ...
+    >>> block = Block(2, 5, rngs=nnx.Rngs(0))
+    >>> block.dropout.deterministic, block.batch_norm.use_running_average
+    (False, False)
+    >>> new_block = nnx.set_mode(block, deterministic=True, use_running_average=True)
+    >>> new_block.dropout.deterministic, new_block.batch_norm.use_running_average
+    (True, True)
+
+  ``Filter``s can be used to set the attributes of specific Modules::
+
+    >>> block = Block(2, 5, rngs=nnx.Rngs(0))
+    >>> new_block = nnx.set_mode(block, only=nnx.Dropout, deterministic=True)
+    >>> # Only the dropout will be modified.
+    >>> new_block.dropout.deterministic, new_block.batch_norm.use_running_average
+    (True, False)
+
+  Args:
+    node: the object to create a copy of.
+    only: a Filter selecting the Modules whose attributes are set.
+    raise_if_not_found: if True (the default), raise a ValueError when a kwarg
+      is not consumed by any module.
+    **kwargs: the attributes to set.
+  """
+  predicate = filterlib.to_predicate(only)
+
+  counts = {k: 0 for k in kwargs}
+  counts["_set_mode_calls"] = 0
+
+  def _set_mode_fn(path, node):
+    if hasattr(node, 'set_mode') and predicate(path, node):
+      counts["_set_mode_calls"] += 1
+      unused = node.set_mode(**kwargs)
+      for k in unused:
+        counts[k] += 1
+    return node
+
+  out = graph.recursive_map(_set_mode_fn, node)
+
+  if raise_if_not_found:
+    set_mode_calls = counts.pop("_set_mode_calls")
+    # A kwarg is unused when every set_mode call returned it as unused.
+    unused_keys = [k for k, v in counts.items() if v == set_mode_calls]
+    if unused_keys:
+      raise ValueError(f"Unused keys found in set_mode: {unused_keys}")
+
+  return out
+
+def train_mode(node: A, /, *, only: filterlib.Filter = ..., **kwargs) -> A:
+  """Creates a new node set to training mode.
+
+  ``train_mode`` uses ``set_mode`` to recursively set the attributes
+  ``deterministic=False`` and ``use_running_average=False`` on all nested
+  Modules that have them. It's primarily used to control the runtime behavior
+  of the ``Dropout`` and ``BatchNorm`` Modules.
+
+  Example::
+
+    >>> from flax import nnx
+    ...
+    >>> class Block(nnx.Module):
+    ...   def __init__(self, din, dout, *, rngs: nnx.Rngs):
+    ...     self.linear = nnx.Linear(din, dout, rngs=rngs)
+    ...     # initialize Dropout and BatchNorm in eval mode
+    ...     self.dropout = nnx.Dropout(0.5, deterministic=True)
+    ...     self.batch_norm = nnx.BatchNorm(10, use_running_average=True, rngs=rngs)
+    ...
+    >>> block = Block(2, 5, rngs=nnx.Rngs(0))
+    >>> block.dropout.deterministic, block.batch_norm.use_running_average
+    (True, True)
+    >>> train_block = nnx.train_mode(block)
+    >>> train_block.dropout.deterministic, train_block.batch_norm.use_running_average
+    (False, False)
+
+  Args:
+    node: the object to create a copy of.
+    only: a Filter selecting the Modules whose attributes are set.
+    **kwargs: additional attributes passed to ``set_mode``.
+  """
+  return set_mode(
+    node,
+    only=only,
+    raise_if_not_found=False,
+    deterministic=False,
+    use_running_average=False,
+    **kwargs,
+  )
+
+def eval_mode(node: A, /, *, only: filterlib.Filter = ..., **kwargs) -> A:
+  """Creates a new node set to evaluation mode.
+
+  ``eval_mode`` uses ``set_mode`` to recursively set the attributes
+  ``deterministic=True`` and ``use_running_average=True`` on all nested
+  Modules that have them. It's primarily used to control the runtime behavior
+  of the ``Dropout`` and ``BatchNorm`` Modules.
+
+  Example::
+
+    >>> from flax import nnx
+    ...
+    >>> class Block(nnx.Module):
+    ...   def __init__(self, din, dout, *, rngs: nnx.Rngs):
+    ...     self.linear = nnx.Linear(din, dout, rngs=rngs)
+    ...     self.dropout = nnx.Dropout(0.5)
+    ...     self.batch_norm = nnx.BatchNorm(10, rngs=rngs)
+    ...
+    >>> block = Block(2, 5, rngs=nnx.Rngs(0))
+    >>> block.dropout.deterministic, block.batch_norm.use_running_average
+    (False, False)
+    >>> eval_block = nnx.eval_mode(block)
+    >>> eval_block.dropout.deterministic, eval_block.batch_norm.use_running_average
+    (True, True)
+
+  Args:
+    node: the object to create a copy of.
+    only: a Filter selecting the Modules whose attributes are set.
+    **kwargs: additional attributes passed to ``set_mode``.
+  """
+  return set_mode(
+    node,
+    only=only,
+    raise_if_not_found=False,
+    deterministic=True,
+    use_running_average=True,
+    **kwargs,
+  )
+

 def first_from(*args: tp.Optional[A], error_msg: str) -> A:
   """Return the first non-None argument.

flax/nnx/nn/attention.py

Lines changed: 45 additions & 0 deletions
@@ -638,6 +638,51 @@ def init_cache(self, input_shape: Shape, dtype: Dtype = jnp.float32):
     self.cached_value = nnx.Cache(jnp.zeros(cache_shape, dtype))
     self.cache_index = nnx.Cache(jnp.array(0, dtype=jnp.int32))

+  def set_mode(
+    self,
+    deterministic: bool | None = None,
+    decode: bool | None = None,
+    batch_size: int | Shape | None = None,
+    max_length: int | None = None,
+    **kwargs,
+  ) -> dict:
+    """Method used by ``nnx.set_mode``.
+
+    Args:
+      deterministic: if True, the module is set to deterministic mode.
+      decode: if True, the module is set to decode mode.
+      batch_size: the batch size to use for the cache.
+      max_length: the max length to use for the cache.
+    """
+    if deterministic is not None:
+      self.deterministic = deterministic
+
+    if decode is not None:
+      self.decode = decode
+      if (
+        not hasattr(self, 'cached_key')
+        or not hasattr(self, 'cached_value')
+        or not hasattr(self, 'cache_index')
+      ):
+        if batch_size is None:
+          raise TypeError(
+            "'batch_size' must be provided when initializing cache."
+          )
+        if max_length is None:
+          raise TypeError(
+            "'max_length' must be provided when initializing cache."
+          )
+        if isinstance(batch_size, int):
+          batch_size = (batch_size,)
+
+        # initialize cache
+        cache_shape = (*batch_size, max_length, self.num_heads, self.head_dim)
+        self.cached_key = nnx.Cache(jnp.zeros(cache_shape, self.dtype))
+        self.cached_value = nnx.Cache(jnp.zeros(cache_shape, self.dtype))
+        self.cache_index = nnx.Cache(jnp.array(0, dtype=jnp.int32))
+    return kwargs

 # mask-making utility functions

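Unlike ``init_cache``, which takes the input shape, this hook receives the cache geometry through ``batch_size`` and ``max_length``. A hedged sketch (layer sizes arbitrary, not from this commit):

  import jax.numpy as jnp
  from flax import nnx

  attn = nnx.MultiHeadAttention(
    num_heads=2, in_features=8, decode=False, rngs=nnx.Rngs(0)
  )
  # No cache exists yet, so flipping `decode` requires the cache geometry.
  decoder = nnx.set_mode(attn, decode=True, batch_size=1, max_length=4)
  # Cache shape is (*batch, max_length, num_heads, head_dim).
  assert decoder.cached_key.value.shape == (1, 4, 2, 4)

  y = decoder(jnp.ones((1, 1, 8)))  # one autoregressive step; cache_index advances
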
flax/nnx/nn/normalization.py

Lines changed: 15 additions & 0 deletions
@@ -392,6 +392,21 @@ def __call__(
       self.epsilon,
     )

+  def set_mode(
+    self,
+    use_running_average: bool | None = None,
+    **kwargs,
+  ) -> dict:
+    """Method used by ``nnx.set_mode``.
+
+    Args:
+      use_running_average: if True, the stored batch statistics will be
+        used instead of computing the batch statistics on the input.
+    """
+    if use_running_average is not None:
+      self.use_running_average = use_running_average
+    return kwargs

 class LayerNorm(Module):
   """Layer normalization (https://arxiv.org/abs/1607.06450).

flax/nnx/nn/stochastic.py

Lines changed: 14 additions & 0 deletions
@@ -153,3 +153,17 @@ def __call__(
     mask = random.bernoulli(key, p=keep_prob, shape=broadcast_shape)
     mask = jnp.broadcast_to(mask, inputs.shape)
     return lax.select(mask, inputs / keep_prob, jnp.zeros_like(inputs))
+
+  def set_mode(
+    self,
+    deterministic: bool | None = None,
+    **kwargs,
+  ) -> dict:
+    """Method used by ``nnx.set_mode``.
+
+    Args:
+      deterministic: if True, disables dropout masking.
+    """
+    if deterministic is not None:
+      self.deterministic = deterministic
+    return kwargs
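
The contract shared by the three implementations in this commit: consume the keys you recognize, mutate the static attribute, and return the remaining kwargs so ``nnx.set_mode`` can detect keys nobody consumed. A hypothetical custom module following the same protocol:

  from flax import nnx

  class Sampler(nnx.Module):
    def __init__(self):
      self.greedy = False

    def set_mode(self, greedy: bool | None = None, **kwargs) -> dict:
      if greedy is not None:
        self.greedy = greedy
      return kwargs  # unused keys flow back to nnx.set_mode's bookkeeping

  sampler = nnx.set_mode(Sampler(), greedy=True)
  assert sampler.greedy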
