@renecotyfanboy it is a bit unclear what the purpose of `nnx.call((graph_def, state_1, state_2))` with multiple states would be.

The split-and-merge pattern in your example is correct. Behind the scenes, `nnx.call` does exactly that:

From `flax/flax/nnx/graph.py`, lines 2861 to 2867 in f73aea5 (the body of `nnx.call`, where `graphdef_state` is the tuple you pass in):

```python
def pure_caller(accessor: DelayedAccessor, *args, **kwargs):
  node = merge(*graphdef_state)
  method = accessor(node)
  out = method(*args, **kwargs)
  return out, split(node)

return CallableProxy(pure_caller)  # type: ignore
```

Answer selected by renecotyfanboy
Category: Q&A
3 participants