I'm trying to use orbax to checkpoint arbitrary penzai models. Their parameters are of type pz.ParameterValue, which in turn contain pz.NamedArray instances, and those instances carry named axes. So I tried implementing a class derived from type_handlers.TypeHandler. However, I can't see where in this workflow the axis names would be stored. There is a TypeHandler.metadata method, but it seems to be called only during restore. And TypeHandler.serialize doesn't seem to offer a way to customize anything except at a very low level (the TensorStore level).
Am I missing something? It would be nice to be able to use orbax with penzai models.
On the penzai side, the authors say that orbax can be used, but there is no example of saving and loading an arbitrary model.
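One possible workaround, sketched without a custom TypeHandler, assumes (unverified here) that penzai's NamedArray keeps its axis names as static pytree metadata rather than as array leaves. If that holds, flattening a model yields only plain arrays, which an ordinary orbax PyTree checkpoint can store, while the axis names live in the pytree structure (treedef). Restoring the saved leaves into a freshly built template model with the same structure then recovers the names. `ToyNamedArray`, `checkpointable_leaves`, `restore_like`, and `axis_name_record` are hypothetical stand-ins, not penzai or orbax APIs, and no orbax calls are shown.

```python
import dataclasses
import json

import jax
import jax.numpy as jnp
import numpy as np


@jax.tree_util.register_pytree_node_class
@dataclasses.dataclass
class ToyNamedArray:
    """Hypothetical stand-in for pz.nx.NamedArray (not the real class)."""

    named_axes: tuple  # static metadata, e.g. (("in", 2), ("out", 3))
    data_array: jax.Array  # the only array leaf

    def tree_flatten(self):
        # The array is a child (leaf); the axis names go into aux data,
        # i.e. into the treedef rather than into the leaves.
        return (self.data_array,), self.named_axes

    @classmethod
    def tree_unflatten(cls, named_axes, children):
        return cls(named_axes=named_axes, data_array=children[0])


def checkpointable_leaves(tree):
    """Plain NumPy arrays only: this is what a checkpointer would write."""
    return [np.asarray(leaf) for leaf in jax.tree_util.tree_leaves(tree)]


def restore_like(template, saved_leaves):
    """Rebuild a tree from saved leaves, taking axis names from the template."""
    treedef = jax.tree_util.tree_structure(template)
    return jax.tree_util.tree_unflatten(
        treedef, [jnp.asarray(x) for x in saved_leaves]
    )


def axis_name_record(tree):
    """Optional JSON record of axis names, e.g. to save next to the arrays."""
    is_named = lambda x: isinstance(x, ToyNamedArray)
    named = [
        x for x in jax.tree_util.tree_leaves(tree, is_leaf=is_named) if is_named(x)
    ]
    return json.dumps([[name for name, _ in x.named_axes] for x in named])


model = {"w": ToyNamedArray((("in", 2), ("out", 3)), jnp.ones((2, 3)))}
leaves = checkpointable_leaves(model)  # no axis names in here
restored = restore_like(model, leaves)
print(restored["w"].named_axes)
```

The design choice here is to keep the checkpoint itself purely array-valued and rely on rebuilding the model (with the same configuration) as the restore template. The JSON record is only a sanity check, so a mismatch between saved and template axis names can be detected.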