Using nnx.call with multiple states #4792
Is there any syntax for using `nnx.call` with multiple states? I don't see a way to do `nnx.call((graph_def, state_1, state_2))` in the docs, and it doesn't work anyway. I am currently doing the following within a jit-compiled function:

Can I stay with this syntax, and is there any overhead in using `merge` instead of `call`?
@renecotyfanboy it is a bit unclear what the purpose of `nnx.call((graph_def, state_1, state_2))` with multiple states would be. The split-and-merge paradigm in your example is correct. Behind the scenes, `nnx.call` is doing exactly that: flax/flax/nnx/graph.py, lines 2861 to 2867 in f73aea5.