
Intermediate value capture API via JAX's hijax.Box #4924

@IvyZX

JAX has a new hijax.Box mechanism that can insert arbitrary values during the forward and backward passes. This could be a good alternative to Flax's current sow and perturb APIs.

The feature is not complete yet (e.g., it does not yet work with vmap/scan, and there are some other corner cases). I'm going to do some prototyping first.
