Replies: 2 comments
-
Beta Was this translation helpful? Give feedback.
0 replies
-
class SafeJax:
    """Save and restore an ``nnx.Module``'s state with safetensors.

    Every state leaf is stored under a flat string key formed by joining its
    path components with ".", e.g. ``("layers", 0, "kernel")`` becomes
    ``"layers.0.kernel"``.

    Rather than guessing path types back from the flat key (``"0"`` could be
    the int index 0 or the string key "0"), ``deserialize`` maps each stored
    key to the exact path it has in the target model, and refuses to load a
    checkpoint whose keys or shapes do not match that model.

    NOTE(review): dtypes are not compared, and leaves that safetensors may be
    unable to store (presumably typed PRNG key arrays such as those held by
    ``nnx.Rngs``) are not filtered out -- confirm for your models.
    """

    @staticmethod
    def _encode_path(path) -> str:
        """Join a state path into one flat key; reject components containing "."."""
        parts = [str(part) for part in path]
        for part in parts:
            # A "." inside a component would make the key split differently
            # on the way back, so fail at save time instead of load time.
            if "." in part:
                raise ValueError(
                    f"state path component {part!r} in {path!r} contains '.'; "
                    "it cannot be encoded as an unambiguous safetensors key"
                )
        return ".".join(parts)

    @staticmethod
    def _flat_leaves(model: "nnx.Module") -> dict:
        """Return ``{flat_key: (path, value)}`` for every leaf of the model's state.

        Raises:
            ValueError: if a path cannot be encoded, or two distinct paths
                encode to the same key (e.g. ``(0,)`` and ``("0",)``).
        """
        leaves = {}
        for path, leaf in nnx.to_flat_state(nnx.state(model)):
            key = SafeJax._encode_path(path)
            if key in leaves:
                raise ValueError(
                    f"state paths {leaves[key][0]!r} and {path!r} both "
                    f"encode to the key {key!r}"
                )
            # Unwrap Variable/VariableState containers when present so that
            # safetensors is handed plain arrays; plain arrays pass through.
            leaves[key] = (path, getattr(leaf, "value", leaf))
        return leaves

    @staticmethod
    def serialize(model: "nnx.Module"):
        """Encode ``model``'s full state as safetensors bytes.

        Args:
            model: module whose ``nnx.state`` is written.

        Returns:
            The bytes produced by ``safetensors.flax.save``.

        Raises:
            ValueError: if a state path cannot be encoded unambiguously.
        """
        leaves = SafeJax._flat_leaves(model)
        return safetensors.flax.save(
            {key: value for key, (_, value) in leaves.items()}
        )

    @staticmethod
    def deserialize(model: "nnx.Module", data: bytes):
        """Load safetensors ``data`` into ``model`` in place.

        The checkpoint must contain exactly the keys of ``model``'s state,
        with matching shapes; if it does not, ``ValueError`` is raised before
        ``model`` is touched.

        Args:
            model: module to update via ``nnx.update``.
            data: bytes as produced by ``serialize``.

        Raises:
            ValueError: on missing/unexpected keys or shape mismatches.
        """
        expected = SafeJax._flat_leaves(model)
        sdict = safetensors.flax.load(data)

        missing = sorted(expected.keys() - sdict.keys())
        unexpected = sorted(sdict.keys() - expected.keys())
        if missing or unexpected:
            raise ValueError(
                "checkpoint keys do not match the model: "
                f"missing={missing}, unexpected={unexpected}"
            )

        mismatched = []
        for key, (_, current) in expected.items():
            want = getattr(current, "shape", None)
            got = getattr(sdict[key], "shape", None)
            # Only compare when both sides expose a shape.
            if want is not None and got is not None and tuple(want) != tuple(got):
                mismatched.append(
                    f"{key}: model {tuple(want)} vs checkpoint {tuple(got)}"
                )
        if mismatched:
            raise ValueError(
                "checkpoint shapes do not match the model: "
                + "; ".join(mismatched)
            )

        # Rebuild each entry under the model's own path so key types
        # (int index vs str name) match the model's state exactly, instead
        # of re-parsing them from the flat string key.
        params = nnx.from_flat_state(
            [(path, sdict[key]) for key, (path, _) in expected.items()]
        )
        nnx.update(model, params)


# Any pitfalls in this approach?
Beta Was this translation helpful? Give feedback.
0 replies
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Uh oh!
There was an error while loading. Please reload this page.
-
Using just jax/flax/safetensors?
Beta Was this translation helpful? Give feedback.
All reactions