JAX release v0.2.27

Breaking changes:
- Support for NumPy 1.18 has been dropped, per the [deprecation policy](https://jax.readthedocs.io/en/latest/deprecation.html). Please upgrade to a supported NumPy version.
- The `host_callback` primitives have been simplified to drop the special autodiff handling for `hcb.id_tap` and `id_print`. From now on, only the primals are tapped. The old behavior can be obtained (for a limited time) by setting the `JAX_HOST_CALLBACK_AD_TRANSFORMS` environment variable or the `--jax_host_callback_ad_transforms` flag. Documentation has also been added on how to implement the old behavior using JAX custom AD APIs ({jax-issue}`#8678`).
- Sorting now matches the behavior of NumPy for `0.0` and `NaN` regardless of the bit representation. In particular, `0.0` and `-0.0` are now treated as equivalent, whereas previously `-0.0` was treated as less than `0.0`. Additionally, all `NaN` representations are now treated as equivalent and sorted to the end of the array. Previously, negative `NaN` values were sorted to the front of the array, and `NaN` values with different internal bit representations were not treated as equivalent and were sorted according to those bit patterns ({jax-issue}`#9178`).
- {func}`jax.numpy.unique` now treats `NaN` values in the same way as `np.unique` in NumPy versions 1.21 and newer: at most one `NaN` value will appear in the uniquified output ({jax-issue}`#9184`).
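The following minimal sketch illustrates the sorting and `jnp.unique` changes. It assumes a current JAX install, and later releases may have adjusted these APIs, so treat it as illustrative only. The input values are arbitrary. The sorted output should place `-0.0` and `0.0` together as equal values, with every `NaN` (including one with its sign bit set) at the end.

```python
import jax.numpy as jnp
import numpy as np

# Signed zeros and NaNs, including a NaN whose sign bit is set.
x = jnp.array([1.0, jnp.nan, -0.0, 0.0, -jnp.nan, -1.0])

# -0.0 and 0.0 are treated as equivalent; all NaN representations go to the end.
sorted_x = jnp.sort(x)
print(np.asarray(sorted_x))

# Repeated NaNs collapse so that at most one NaN appears in the output.
unique_x = jnp.unique(jnp.array([jnp.nan, 1.0, jnp.nan, 2.0]))
print(np.asarray(unique_x))
```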

Bug fixes:
- `host_callback` now supports `ad_checkpoint.checkpoint` ({jax-issue}`#8907`).

New features:
- Added `jax.block_until_ready` ({jax-issue}`#8941`).
- Added a new debugging flag/environment variable `JAX_DUMP_IR_TO=/path`. If set, JAX dumps the MHLO/HLO IR it generates for each computation to a file under the given path.
- Added `jax.ensure_compile_time_eval` to the public API ({jax-issue}`#7987`).
- jax2tf now supports a flag `jax2tf_associative_scan_reductions` to change the lowering for associative reductions, e.g., `jnp.cumsum`, to behave like JAX on CPU and GPU (to use an associative scan). See the jax2tf README for more details ({jax-issue}`#9189`).
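The minimal sketch below exercises the two new public APIs, `jax.block_until_ready` and `jax.ensure_compile_time_eval`, assuming a recent JAX release. The function names `scaled_sin` and `fill_by_table_sum` are hypothetical and exist only for illustration. The first call waits for a whole pytree of results to finish computing. The second shows that values computed inside `ensure_compile_time_eval` are concrete at trace time, so they can feed Python-level logic such as a static shape under `jit`.

```python
import jax
import jax.numpy as jnp
import numpy as np


@jax.jit
def scaled_sin(x):
    # Dispatch of jitted work is asynchronous, so this call may return
    # before the device has finished computing the result.
    return 2.0 * jnp.sin(x)


# jax.block_until_ready accepts a pytree of arrays, waits until every leaf
# is ready, and returns the same pytree (handy when timing code).
out = jax.block_until_ready(
    {"a": scaled_sin(jnp.arange(3.0)), "b": scaled_sin(jnp.ones(2))}
)


@jax.jit
def fill_by_table_sum(x):
    # Inside this context, operations on non-traced values are evaluated
    # eagerly during tracing, so `table` is a concrete array, not a tracer.
    with jax.ensure_compile_time_eval():
        table = jnp.arange(4.0) ** 2  # [0., 1., 4., 9.]
        n = int(table.sum())          # a plain Python int
    # Because n is static, it can be used as an output shape under jit.
    return jnp.full((n,), x)


result = fill_by_table_sum(3.0)
print(np.asarray(result).shape)
```

Without the `ensure_compile_time_eval` block, `int(table.sum())` would fail under `jit` because `table` would be an abstract tracer.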