@@ -427,6 +427,138 @@ def eval(self, **attributes):
427427 raise_if_not_found = False ,
428428 )
429429
430+ def set_mode (node : A , / , * , only : filterlib .Filter = ..., raise_if_not_found : bool = True , ** kwargs ) -> A :
431+ """Creates a new node with static attributes updated according to ``**kwargs``.
432+
433+ The new node contains references to jax arrays in the original node. If a
434+ kwarg is not found in any module, this method raises a ValueError. ``set_mode``
435+ class methods should return any unused kwargs.
436+
437+ Example::
438+ >>> from flax import nnx
439+ ...
440+ >>> class Block(nnx.Module):
441+ ... def __init__(self, din, dout, *, rngs: nnx.Rngs):
442+ ... self.linear = nnx.Linear(din, dout, rngs=rngs)
443+ ... self.dropout = nnx.Dropout(0.5, deterministic=False)
444+ ... self.batch_norm = nnx.BatchNorm(10, use_running_average=False, rngs=rngs)
445+ ...
446+ >>> block = Block(2, 5, rngs=nnx.Rngs(0))
447+ >>> block.dropout.deterministic, block.batch_norm.use_running_average
448+ (False, False)
449+ >>> new_block = nnx.set_mode(block, deterministic=True, use_running_average=True)
450+ >>> new_block.dropout.deterministic, new_block.batch_norm.use_running_average
451+ (True, True)
452+
453+ ``Filter``'s can be used to set the attributes of specific Modules::
454+ >>> block = Block(2, 5, rngs=nnx.Rngs(0))
455+ >>> new_block = nnx.set_mode(block, only=nnx.Dropout, deterministic=True)
456+ >>> # Only the dropout will be modified
457+ >>> new_block.dropout.deterministic, new_block.batch_norm.use_running_average
458+ (True, False)
459+
460+ Args:
461+ node: the object to create a copy of.
462+ only: Filters to select the Modules to set the attributes of.
463+ **kwargs: The attributes to set.
464+ """
465+ predicate = filterlib .to_predicate (only )
466+
467+ counts = {k : 0 for k in kwargs }
468+ counts ["_set_mode_calls" ] = 0
469+
470+ def _set_mode_fn (path , node ):
471+ if hasattr (node , 'set_mode' ) and predicate (path , node ):
472+ counts ["_set_mode_calls" ] += 1
473+ unused = node .set_mode (** kwargs )
474+ for k in unused :
475+ counts [k ] += 1
476+ return node
477+
478+ out = graph .recursive_map (_set_mode_fn , node )
479+
480+ if raise_if_not_found :
481+ set_mode_calls = counts .pop ("_set_mode_calls" )
482+ unused_keys = [k for k , v in counts .items () if v == set_mode_calls ]
483+ if unused_keys :
484+ raise ValueError (f"Unused keys found in set_mode: { unused_keys } " )
485+
486+ return out
487+
488+ def train_mode (node : A , / , * , only : filterlib .Filter = ..., ** kwargs ) -> A :
489+ """Creates a new node set to training mode.
490+
491+ ``train_mode`` uses ``set_mode`` to recursively set attributes ``deterministic=False``
492+ and ``use_running_average=False`` of all nested Modules that have these attributes.
493+ Its primarily used to control the runtime behavior of the ``Dropout`` and ``BatchNorm``
494+ Modules.
495+
496+ Example::
497+ >>> from flax import nnx
498+ ...
499+ >>> class Block(nnx.Module):
500+ ... def __init__(self, din, dout, *, rngs: nnx.Rngs):
501+ ... self.linear = nnx.Linear(din, dout, rngs=rngs)
502+ ... # initialize Dropout and BatchNorm in eval mode
503+ ... self.dropout = nnx.Dropout(0.5, deterministic=True)
504+ ... self.batch_norm = nnx.BatchNorm(10, use_running_average=True, rngs=rngs)
505+ ...
506+ >>> block = Block(2, 5, rngs=nnx.Rngs(0))
507+ >>> block.dropout.deterministic, block.batch_norm.use_running_average
508+ (True, True)
509+ >>> train_block = nnx.train_mode(block)
510+ >>> train_block.dropout.deterministic, train_block.batch_norm.use_running_average
511+ (False, False)
512+
513+ Args:
514+ **kwargs: additional attributes passed to ``set_attributes``.
515+ """
516+ return set_mode (
517+ node ,
518+ only = only ,
519+ raise_if_not_found = False ,
520+ deterministic = False ,
521+ use_running_average = False ,
522+ ** kwargs ,
523+ )
524+
525+ def eval_mode (node : A , / , * , only : filterlib .Filter = ..., ** kwargs ) -> A :
526+ """Creates a new node set to evaluation mode.
527+
528+ ``eval_mode`` uses ``set_mode`` to recursively set attributes ``deterministic=True``
529+ and ``use_running_average=True`` of all nested Modules that have these attributes.
530+ Its primarily used to control the runtime behavior of the ``Dropout`` and ``BatchNorm``
531+ Modules.
532+
533+ Example::
534+ >>> from flax import nnx
535+ ...
536+ >>> class Block(nnx.Module):
537+ ... def __init__(self, din, dout, *, rngs: nnx.Rngs):
538+ ... self.linear = nnx.Linear(din, dout, rngs=rngs)
539+ ... self.dropout = nnx.Dropout(0.5)
540+ ... self.batch_norm = nnx.BatchNorm(10, rngs=rngs)
541+ ...
542+ >>> block = Block(2, 5, rngs=nnx.Rngs(0))
543+ >>> block.dropout.deterministic, block.batch_norm.use_running_average
544+ (False, False)
545+ >>> eval_block = nnx.eval_mode(block)
546+ >>> eval_block.dropout.deterministic, eval_block.batch_norm.use_running_average
547+ (True, True)
548+
549+ Args:
550+ **kwargs: additional attributes passed to ``set_mode``.
551+ """
552+ return set_mode (
553+ node ,
554+ only = only ,
555+ raise_if_not_found = False ,
556+ deterministic = True ,
557+ use_running_average = True ,
558+ ** kwargs ,
559+ )
560+
561+
430562
431563def first_from (* args : tp .Optional [A ], error_msg : str ) -> A :
432564 """Return the first non-None argument.
0 commit comments