Releases: jax-ml/jax
JAX v0.8.1
New features:
- jax.jit now supports the decorator factory pattern; i.e., instead of writing
    @functools.partial(jax.jit, static_argnames=['n'])
    def f(x, n): ...
you may write
    @jax.jit(static_argnames=['n'])
    def f(x, n): ...
Changes:
- jax.lax.linalg.eigh now accepts an implementation argument to
select between QR (CPU/GPU), Jacobi (GPU/TPU), and QDWH (TPU)
implementations. The EighImplementation enum is publicly exported from
jax.lax.linalg.
- jax.lax.linalg.svd now implements an algorithm that uses the polar
decomposition on CUDA GPUs. This is also an alias for the existing algorithm
on TPUs.
Bug fixes:
- Fixed a bug introduced in JAX 0.7.2 where eigh failed for large matrices on
GPU (#33062).
Deprecations:
- jax.sharding.PmapSharding is now deprecated. Please use
jax.NamedSharding instead.
- jax.device_put_replicated is now deprecated. Please use jax.device_put
with the appropriate sharding instead.
- jax.device_put_sharded is now deprecated. Please use jax.device_put with
the appropriate sharding instead.
- The default axis_types of jax.make_mesh will change in JAX v0.9.0 to
jax.sharding.AxisType.Explicit. Leaving axis_types unspecified will raise a
DeprecationWarning.
- jax.cloud_tpu_init and its contents were deprecated. There is no reason for a
user to import or use the contents of this module; JAX handles this for you
automatically if needed.
JAX v0.8.0
Breaking changes:
- JAX is changing the default jax.pmap implementation to one implemented in
terms of jax.jit and jax.shard_map. jax.pmap is in maintenance mode
and we encourage all new code to use jax.shard_map directly. See the
migration guide for more information.
- The auto= parameter of jax.experimental.shard_map.shard_map has been
removed. This means that jax.experimental.shard_map.shard_map no longer
supports nesting. If you want to nest shard_map calls, please use
jax.shard_map.
- JAX no longer allows passing objects that support __jax_array__ directly
to, e.g., jit-ed functions. Call jax.numpy.asarray on them first.
- jax.numpy.cov now returns NaN for empty arrays (#32305),
and matches NumPy 2.2 behavior for single-row design matrices (#32308).
- JAX no longer accepts Array values where a dtype value is expected. Call
.dtype on these values first.
- The deprecated function jax.interpreters.mlir.custom_call was
removed.
- The jax.util, jax.extend.ffi, and jax.experimental.host_callback
modules have been removed. All public APIs within these modules were
deprecated and removed in v0.7.0 or earlier.
- The deprecated symbol jax.custom_derivatives.custom_jvp_call_jaxpr_p
was removed.
- jax.experimental.multihost_utils.process_allgather raises an error when
the input is a jax.Array that is not fully addressable and tiled=False. To fix
this, pass tiled=True to your process_allgather invocation.
- From jax.experimental.compilation_cache, the deprecated symbols
is_initialized and initialize_cache were removed.
- The deprecated function jax.interpreters.xla.canonicalize_dtype
was removed.
- jaxlib.hlo_helpers has been removed. Use jax.ffi instead.
- The option jax_cpu_enable_gloo_collectives has been removed. Use
jax_cpu_collectives_implementation instead.
- The previously deprecated interpolation argument to
jax.numpy.percentile and jax.numpy.quantile has been
removed; use method instead.
- The JAX-internal for_loop primitive was removed. Its functionality,
reading from and writing to refs in the loop body, is now directly
supported by jax.lax.fori_loop. If you need help updating your
code, please file a bug.
- jax.numpy.trim_zeros now errors for non-1D input.
- The where argument to jax.numpy.sum and other reductions is now
required to be boolean. Non-boolean values have resulted in a
DeprecationWarning since JAX v0.5.0.
- The deprecated functions in jax.dlpack, jax.errors,
jax.lib.xla_bridge, jax.lib.xla_client, and
jax.lib.xla_extension were removed.
- jax.interpreters.mlir.dense_bool_array was removed. Use MLIR APIs to
construct attributes instead.
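Two of the removals above have direct spellings that work on current releases. A minimal sketch with made-up data: method takes over from the removed interpolation argument, and the where mask must be a boolean array (an integer mask such as (x > 0).astype(int) is no longer accepted).

```python
import jax.numpy as jnp

x = jnp.array([-1.0, 2.0, 3.0, -4.0, 5.0])

# Formerly spelled with interpolation='linear'.
median = jnp.percentile(x, 50, method='linear')

# The mask passed to where must have a boolean dtype.
positive_total = jnp.sum(x, where=x > 0)
```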
Changes
- jax.numpy.linalg.eig now returns a namedtuple (with attributes
eigenvalues and eigenvectors) instead of a plain tuple.
- jax.grad and jax.vjp will now always round primals to
float32 if float64 mode is not enabled.
- jax.dlpack.from_dlpack now accepts arrays with non-default layouts,
for example, transposed.
- The default nonsymmetric eigendecomposition on NVIDIA GPUs now uses
cusolver. The magma and LAPACK implementations are still available via the
new implementation argument to jax.lax.linalg.eig
(#27265). The use_magma argument is now deprecated in favor
of implementation.
- jax.numpy.trim_zeros now follows NumPy 2.2 in supporting
multi-dimensional inputs.
Deprecations
- jax.experimental.enable_x64 and jax.experimental.disable_x64
are deprecated in favor of the new non-experimental context manager
jax.enable_x64.
- jax.experimental.shard_map.shard_map is deprecated; going forward use
jax.shard_map.
- jax.experimental.pjit.pjit is deprecated; going forward use
jax.jit.
JAX v0.7.2
Breaking changes:
- jax.dlpack.from_dlpack no longer accepts a DLPack capsule. This
behavior was deprecated and is now removed. The function must be called
with an array implementing __dlpack__ and __dlpack_device__.
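A sketch of the supported calling convention, using a NumPy array as the producer (NumPy arrays implement __dlpack__ and __dlpack_device__). The reverse direction follows the v0.6.0 advice below to hand a JAX Array to the other library's own from_dlpack.

```python
import numpy as np

import jax

host = np.arange(4, dtype=np.float32)

# Pass the producer object itself, not host.__dlpack__() (a raw capsule,
# which is no longer accepted).
on_jax = jax.dlpack.from_dlpack(host)

# Other direction: give the JAX Array to the consumer's from_dlpack.
back = np.from_dlpack(on_jax)
```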
Changes
- The minimum supported NumPy version is now 2.0. Since SciPy 1.13 is required
for NumPy 2.0 support, the minimum supported SciPy version is now 1.13.
- JAX now represents constants in its internal jaxpr representation as a
LiteralArray, which is a private JAX type that duck types as a
numpy.ndarray. This type may be exposed to users via custom_jvp rules,
for example, and may break code that uses isinstance(x, np.ndarray). If
this breaks your code, you may convert these arrays to classic NumPy arrays
using np.asarray(x).
Bug fixes
- arr.view(dtype=None) now returns the array unchanged, matching NumPy's
semantics. Previously it returned the array with a float dtype.
- jax.random.randint now produces a less-biased distribution for 8-bit and
16-bit integer types (#27742). To restore the previous biased
behavior, you may temporarily set the jax_safer_randint configuration to
False, but note this is a temporary config that will be removed in a
future release.
Deprecations:
- The parameters enable_xla and native_serialization for jax2tf.convert
are deprecated and will be removed in a future version of JAX. These were
used for jax2tf with non-native serialization, which has now been removed.
- Setting the config state jax_pmap_no_rank_reduction to False is
deprecated. By default, jax_pmap_no_rank_reduction will be set to True
and jax.pmap shards will not have their rank reduced, keeping the same
rank as their enclosing array.
JAX v0.7.1
New features
- JAX now ships Python 3.14 and 3.14t wheels.
- JAX now ships Python 3.13t and 3.14t wheels on Mac. Previously we only
offered free-threading builds on Linux.
Changes
- Exposed jax.set_mesh, which acts as a global setter and a context manager.
Removed jax.sharding.use_mesh in favor of jax.set_mesh.
- JAX is now built using CUDA 12.9. All versions of CUDA 12.1 or newer remain
supported.
- jax.lax.dot now implements the general dot product via the optional
dimension_numbers argument.
Deprecations:
- jax.lax.zeros_like_array is deprecated. Please use
jax.numpy.zeros_like instead.
- Attempting to import jax.experimental.host_callback now results in
a DeprecationWarning, and will result in an ImportError starting in JAX
v0.8.0. Its APIs have raised NotImplementedError since JAX version 0.4.35.
- In jax.lax.dot, passing the precision and preferred_element_type
arguments by position is deprecated. Pass them by explicit keyword instead.
- Several dozen internal APIs have been deprecated from
jax.interpreters.ad, jax.interpreters.batching, and jax.interpreters.partial_eval;
they are used rarely if ever outside JAX itself, and most are deprecated without any
public replacement.
JAX v0.7.0
New features:
- Added jax.P, which is an alias for jax.sharding.PartitionSpec.
- Added jax.tree.reduce_associative.
Breaking changes:
- JAX is migrating from GSPMD to Shardy by default. See the
migration guide for more information.
- JAX autodiff is switching to using direct linearization by default (instead of
implementing linearization via JVP and partial eval).
See the migration guide for more information.
- jax.stages.OutInfo has been replaced with jax.ShapeDtypeStruct.
- jax.jit now requires fun to be passed by position, and additional
arguments to be passed by keyword. Doing otherwise will result in an error
starting in v0.7.x. This raised a DeprecationWarning in v0.6.x.
- The minimum Python version is now 3.11. 3.11 will remain the minimum
supported version until July 2026.
- Layout API renames:
  - Layout, .layout, .input_layouts and .output_layouts have been
  renamed to Format, .format, .input_formats and .output_formats.
  - DeviceLocalLayout and .device_local_layout have been renamed to Layout
  and .layout.
- The jax.experimental.shard module has been deleted and all the APIs have been
moved to the jax.sharding endpoint. So use jax.sharding.reshard,
jax.sharding.auto_axes and jax.sharding.explicit_axes instead of their
experimental endpoints.
- lax.infeed and lax.outfeed were removed, after being deprecated in
JAX 0.6. The transfer_to_infeed and transfer_from_outfeed methods were
also removed from the Device objects.
- The jax.extend.core.primitives.pjit_p primitive has been renamed to
jit_p, and its name attribute has changed from "pjit" to "jit".
This affects the string representations of jaxprs. The same primitive is no
longer exported from the jax.experimental.pjit module.
- The (undocumented) function jax.extend.backend.add_clear_backends_callback
has been removed. Users should use jax.extend.backend.register_backend_cache
instead.
Deprecations:
- jax.dlpack.SUPPORTED_DTYPES is deprecated; please use the new
jax.dlpack.is_supported_dtype function.
- jax.scipy.special.sph_harm has been deprecated following a similar
deprecation in SciPy; use jax.scipy.special.sph_harm_y instead.
- From jax.interpreters.xla, the previously deprecated symbols
abstractify and pytype_aval_mappings have been removed.
- jax.interpreters.xla.canonicalize_dtype is deprecated. For
canonicalizing dtypes, prefer jax.dtypes.canonicalize_dtype.
For checking whether an object is a valid jax input, prefer
jax.core.valid_jaxtype.
- From jax.core, the previously deprecated symbols AxisName,
ConcretizationTypeError, axis_frame, call_p, closed_call_p,
get_type, trace_state_clean, typematch, and typecheck have been
removed.
- From jax.lib.xla_client, the previously deprecated symbols
DeviceAssignment, get_topology_for_devices, and mlir_api_version
have been removed.
- jax.extend.ffi was removed after being deprecated in v0.5.0.
Use jax.ffi instead.
- jax.lib.xla_bridge.get_compile_options is deprecated, and replaced by
jax.extend.backend.get_compile_options.
JAX v0.6.2
New features:
- Added jax.tree.broadcast, which implements a pytree prefix broadcasting helper.
Changes
- The minimum NumPy version is 1.26 and the minimum SciPy version is 1.12.
JAX v0.6.1
New features:
- Added jax.lax.axis_size, which returns the size of the mapped axis
given its name.
Changes
- Additional checking for the versions of CUDA package dependencies was
re-enabled, having been accidentally disabled in a previous release.
- JAX nightly packages are now published to Artifact Registry. To install
these packages, see the JAX installation guide.
- jax.sharding.PartitionSpec no longer inherits from a tuple.
- jax.ShapeDtypeStruct is now immutable. Please use the .update method to
update your ShapeDtypeStruct instead of doing in-place updates.
Deprecations
- jax.custom_derivatives.custom_jvp_call_jaxpr_p is deprecated, and will be
removed in JAX v0.7.0.
JAX v0.6.0
Breaking changes
- jax.numpy.array no longer accepts None. This behavior had been
deprecated since November 2023 and is now removed.
- Removed the config.jax_data_dependent_tracing_fallback config option,
which was added temporarily in v0.4.36 to allow users to opt out of the
new "stackless" tracing machinery.
- Removed the config.jax_eager_pmap config option.
- Disallow the calling of lower and trace AOT APIs on the result
of jax.jit if there have been subsequent wrappers applied.
Previously this worked, but silently ignored the wrappers.
The workaround is to apply jax.jit last among the wrappers,
and similarly for jax.pmap.
See #27873.
- The cuda12_pip extra for jax has been removed; use pip install jax[cuda12]
instead.
Changes
- The minimum CuDNN version is v9.8.
- JAX is now built using CUDA 12.8. All versions of CUDA 12.1 or newer remain
supported.
- JAX package extras are now updated to use dash instead of underscore to
align with PEP 685. For instance, if you were previously using pip install jax[cuda12_local]
to install JAX, run pip install jax[cuda12-local] instead.
- jax.jit now requires fun to be passed by position, and additional
arguments to be passed by keyword. Doing otherwise will result in a
DeprecationWarning in v0.6.X, and an error starting in v0.7.X.
Deprecations
- jax.tree_util.build_tree is deprecated. Use jax.tree.unflatten
instead.
- Implemented host callback handlers for CPU and GPU devices using XLA's FFI
and removed existing CPU/GPU handlers using XLA's custom call.
- All APIs in jax.lib.xla_extension are now deprecated.
- jax.interpreters.mlir.hlo and jax.interpreters.mlir.func_dialect,
which were accidental exports, have been removed. If needed, they are
available from jax.extend.mlir.
- jax.interpreters.mlir.custom_call is deprecated. The APIs provided by
jax.ffi should be used instead.
- The deprecated use of jax.ffi.ffi_call with inline arguments is no
longer supported. jax.ffi.ffi_call now unconditionally returns a
callable.
- The following exports in jax.lib.xla_client are deprecated:
get_topology_for_devices, heap_profile, mlir_api_version, Client,
CompileOptions, DeviceAssignment, Frame, HloSharding, OpSharding,
Traceback.
- The following internal APIs in jax.util are deprecated:
HashableFunction, as_hashable_function, cache, safe_map, safe_zip,
split_dict, split_list, split_list_checked, split_merge, subvals,
toposort, unzip2, wrap_name, and wraps.
- jax.dlpack.to_dlpack has been deprecated. You can usually pass a JAX
Array directly to the from_dlpack function of another framework. If you
need the functionality of to_dlpack, use the __dlpack__ attribute of an
array.
- jax.lax.infeed, jax.lax.infeed_p, jax.lax.outfeed, and
jax.lax.outfeed_p are deprecated and will be removed in JAX v0.7.0.
- Several previously deprecated APIs have been removed, including:
  - From jax.lib.xla_client: ArrayImpl, FftType, PaddingType,
  PrimitiveType, XlaBuilder, dtype_to_etype,
  ops, register_custom_call_target, shape_from_pyval, Shape,
  XlaComputation.
  - From jax.lib.xla_extension: ArrayImpl, XlaRuntimeError.
  - From jax: jax.treedef_is_leaf, jax.tree_flatten, jax.tree_map,
  jax.tree_leaves, jax.tree_structure, jax.tree_transpose, and
  jax.tree_unflatten. Replacements can be found in jax.tree or
  jax.tree_util.
  - From jax.core: AxisSize, ClosedJaxpr, EvalTrace, InDBIdx, InputType,
  Jaxpr, JaxprEqn, Literal, MapPrimitive, OpaqueTraceState, OutDBIdx,
  Primitive, Token, TRACER_LEAK_DEBUGGER_WARNING, Var, concrete_aval,
  dedup_referents, escaped_tracer_error, extend_axis_env_nd, full_lower,
  get_referent, jaxpr_as_fun, join_effects, lattice_join,
  leaked_tracer_error, maybe_find_leaked_tracers, raise_to_shaped,
  raise_to_shaped_mappings, reset_trace_state, str_eqn_compact,
  substitute_vars_in_output_ty, typecompat, and used_axis_names_jaxpr. Most
  have no public replacement, though a few are available at jax.extend.core.
  - The vectorized argument to jax.pure_callback and
  jax.ffi.ffi_call. Use the vmap_method parameter instead.
JAX v0.5.3
New Features
- Added an allow_negative_indices option to jax.lax.dynamic_slice,
jax.lax.dynamic_update_slice and related functions. The default is
true, matching the current behavior. If set to false, JAX does not need to
emit code clamping negative indices, which improves code size.
- Added a replace option to jax.random.categorical to enable sampling
without replacement.
JAX v0.5.2
Patch release of 0.5.1
Bug fixes
- Fixes TPU metric logging and tpu-info, which were broken in 0.5.1.