nnx.Sequential and nnx.split Problems #4645
Hello, I recently tried less general architectures (compared to custom primitive ones), and I've been playing with `nnx.Sequential` to avoid writing an `nnx.Module` subclass for each variation. In the forward pass everything works as expected; however, I have a problem when extracting the model graph and parameters with `nnx.split` (in order to get a gradient):

```python
import jax
import jax.numpy as jnp
from jax.tree_util import Partial
from flax import nnx

x = jnp.zeros((33, 3))

@Partial(jax.jit, inline=True)
def conca(*args):
    # Concatenate all inputs along the last (feature) axis.
    return jnp.concatenate(args, axis=-1)

# Linear maps (33, 3) -> (33, 12); Partial(conca, x) then prepends the
# original x, so the output has shape (33, 3 + 12) = (33, 15).
myNN = nnx.Sequential(nnx.Linear(3, 12, rngs=nnx.Rngs(default=0)),
                      Partial(conca, x))
Y = myNN(x)
print(Y.shape)  # (33, 15)

def loss(state, graphdef, rng_state, input):
    # Rebuild the model from its graph definition and states, then apply it.
    model = nnx.merge(graphdef, rng_state, state)
    return (model(input) ** 2).mean()

def Gjacob(f):
    """General gradient w.r.t. the arguments: project the Jacobian onto
    the canonical directions (a VJP with a cotangent of all ones)."""
    def Gacob(*x):
        y, vjp_fn = jax.vjp(f, *x)
        return vjp_fn(jnp.ones_like(y))  # cotangent of ones
    return Gacob
```
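For context, the intended use is roughly the following sketch (binding the non-differentiated arguments through `Partial` keywords is just one illustrative choice; `graphdef`, `rng_state`, and `state` are the outputs of the `nnx.split` call below):

```python
# Hypothetical usage: differentiate the loss w.r.t. the parameter state
# only, holding graphdef / rng_state / input fixed via Partial.
loss_of_state = Partial(loss, graphdef=graphdef, rng_state=rng_state, input=x)
(grad_state,) = Gjacob(loss_of_state)(state)  # VJP with a cotangent of ones
```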
But the `nnx.split` call itself fails:

```
>>> graphdef, rng_state, state = nnx.split(myNN, nnx.RngState, ...)
    raise ValueError(
        f'Arrays leaves are not supported, at {path_str!r}: {value}'
    )
ValueError: Arrays leaves are not supported, at 'layers/1/0/0'
```

Note that this problem does not occur when the model contains only `Linear` (`nnx.Module`) operators. Maybe a moderator knows, @cgarciae?

Cheers,
Hey @DiagRisker, this should be resolved by #4612 soon. The issue is that currently NNX doesn't allow JAX Arrays as leaves; the workaround is to wrap `x` in an `nnx.Variable`, e.g. `Partial(conca, nnx.Variable(x))`. After #4612 gets merged you'll be able to use JAX Arrays directly.
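Applied to the example above, the suggested workaround would look roughly like this. This is only a sketch: it assumes an `nnx.Variable` argument can flow through `jax.jit` here, and the explicit `.value` unwrapping inside `conca` is an assumption that may not be needed on every Flax version:

```python
# Sketch of the suggested workaround: store x as an nnx.Variable so that
# nnx.split sees a state leaf it knows how to handle, not a raw jax.Array.
x_var = nnx.Variable(x)

@Partial(jax.jit, inline=True)
def conca(*args):
    # Unwrap any nnx.Variable inputs to plain arrays before concatenating.
    # (Assumption: jnp.concatenate may not accept Variables directly.)
    arrays = [a.value if isinstance(a, nnx.Variable) else a for a in args]
    return jnp.concatenate(arrays, axis=-1)

myNN = nnx.Sequential(nnx.Linear(3, 12, rngs=nnx.Rngs(default=0)),
                      Partial(conca, x_var))
graphdef, rng_state, state = nnx.split(myNN, nnx.RngState, ...)  # no ValueError
```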