JAX v0.6.0
Breaking changes
- jax.numpy.array no longer accepts None. This behavior had been
  deprecated since November 2023 and is now removed.
- Removed the config.jax_data_dependent_tracing_fallback config option,
  which was added temporarily in v0.4.36 to allow users to opt out of the
  new "stackless" tracing machinery.
- Removed the config.jax_eager_pmap config option.
- Calling the lower and trace AOT APIs on the result of jax.jit is now
  disallowed if other wrappers have been applied on top of it. Previously
  this worked, but the wrappers were silently ignored. The workaround is to
  apply jax.jit last among the wrappers, and likewise for jax.pmap.
  See #27873.
- The cuda12_pip extra for jax has been removed; use pip install jax[cuda12]
  instead.
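As a minimal sketch of the workaround above, assuming JAX 0.6.0 or later: jax.vmap is applied first and jax.jit last, so the AOT calls (lower, then compile) are made on the object returned by jax.jit itself. The function f and the input shape are illustrative assumptions, not part of the release notes.

```python
import jax
import jax.numpy as jnp


def f(x):
    # Illustrative elementwise function (an assumption for this sketch).
    return jnp.sin(x) * 2.0


# Apply other transformations first and jax.jit last, so the object
# returned by jax.jit is the outermost wrapper.
batched_jitted = jax.jit(jax.vmap(f))

# vmap maps over the leading axis: three rows of shape (2,).
xs = jnp.arange(6.0).reshape(3, 2)

# The AOT APIs are called directly on the jax.jit result.
lowered = batched_jitted.lower(xs)
compiled = lowered.compile()
out = compiled(xs)
```

Reversing the order, as in jax.vmap(jax.jit(f)), still works for ordinary calls; the change described above only affects calling the AOT APIs on such a wrapped result.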
Changes
- The minimum cuDNN version is v9.8.
- JAX is now built using CUDA 12.8. All versions of CUDA 12.1 or newer remain
  supported.
- JAX package extras now use dashes instead of underscores, in line with
  PEP 685. For instance, if you previously used pip install jax[cuda12_local]
  to install JAX, run pip install jax[cuda12-local] instead.
- jax.jit now requires fun to be passed by position and all additional
  arguments to be passed by keyword. Doing otherwise results in a
  DeprecationWarning in v0.6.X and an error starting in v0.7.X.
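A minimal sketch of the jax.jit calling convention described above, assuming JAX 0.6.0 or later. The power functions are hypothetical examples; static_argnames and static_argnums are existing jax.jit options, used here only to illustrate an option passed by keyword.

```python
import functools

import jax
import jax.numpy as jnp


def power(x, n):
    # Hypothetical helper: n drives a Python loop, so it must be static under jit.
    result = jnp.ones_like(x)
    for _ in range(n):
        result = result * x
    return result


# fun is passed by position; every other option is passed by keyword.
power_jit = jax.jit(power, static_argnames="n")


# Decorator form: functools.partial supplies the options by keyword,
# and the decorated function is then passed to jax.jit by position.
@functools.partial(jax.jit, static_argnums=(1,))
def power_deco(x, n):
    return x ** n


# Forms affected by the change (not executed here):
#   jax.jit(fun=power)          # fun passed by keyword
#   any option after fun passed positionally instead of by keyword

x = jnp.array([1.0, 2.0, 3.0])
print(power_jit(x, n=3))
print(power_deco(x, 3))
```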
Deprecations
- jax.tree_util.build_tree is deprecated. Use jax.tree.unflatten
  instead.
- Implemented host callback handlers for CPU and GPU devices using XLA's FFI
  and removed the existing CPU/GPU handlers that used XLA's custom call.
- All APIs in jax.lib.xla_extension are now deprecated.
- jax.interpreters.mlir.hlo and jax.interpreters.mlir.func_dialect,
  which were accidental exports, have been removed. If needed, they are
  available from jax.extend.mlir.
- jax.interpreters.mlir.custom_call is deprecated. Use the APIs provided by
  jax.ffi instead.
- The deprecated use of jax.ffi.ffi_call with inline arguments is no
  longer supported. jax.ffi.ffi_call now unconditionally returns a
  callable.
- The following exports in jax.lib.xla_client are deprecated:
  get_topology_for_devices, heap_profile, mlir_api_version, Client,
  CompileOptions, DeviceAssignment, Frame, HloSharding, OpSharding,
  Traceback.
- The following internal APIs in jax.util are deprecated:
  HashableFunction, as_hashable_function, cache, safe_map, safe_zip,
  split_dict, split_list, split_list_checked, split_merge, subvals,
  toposort, unzip2, wrap_name, and wraps.
- jax.dlpack.to_dlpack has been deprecated. You can usually pass a JAX
  Array directly to the from_dlpack function of another framework. If you
  need the functionality of to_dlpack, use the __dlpack__ method of an
  array.
- jax.lax.infeed, jax.lax.infeed_p, jax.lax.outfeed, and
  jax.lax.outfeed_p are deprecated and will be removed in JAX v0.7.0.
- Several previously-deprecated APIs have been removed, including:
  - From jax.lib.xla_client: ArrayImpl, FftType, PaddingType,
    PrimitiveType, XlaBuilder, dtype_to_etype,
    ops, register_custom_call_target, shape_from_pyval, Shape,
    XlaComputation.
  - From jax.lib.xla_extension: ArrayImpl, XlaRuntimeError.
  - From jax: jax.treedef_is_leaf, jax.tree_flatten, jax.tree_map,
    jax.tree_leaves, jax.tree_structure, jax.tree_transpose, and
    jax.tree_unflatten. Replacements can be found in jax.tree or
    jax.tree_util.
  - From jax.core: AxisSize, ClosedJaxpr, EvalTrace, InDBIdx, InputType,
    Jaxpr, JaxprEqn, Literal, MapPrimitive, OpaqueTraceState, OutDBIdx,
    Primitive, Token, TRACER_LEAK_DEBUGGER_WARNING, Var, concrete_aval,
    dedup_referents, escaped_tracer_error, extend_axis_env_nd, full_lower,
    get_referent, jaxpr_as_fun, join_effects, lattice_join,
    leaked_tracer_error, maybe_find_leaked_tracers, raise_to_shaped,
    raise_to_shaped_mappings, reset_trace_state, str_eqn_compact,
    substitute_vars_in_output_ty, typecompat, and used_axis_names_jaxpr.
    Most have no public replacement, though a few are available at
    jax.extend.core.
  - The vectorized argument to jax.pure_callback and
    jax.ffi.ffi_call. Use the vmap_method parameter instead.
  - From