Releases: jax-ml/jax
JAX release v0.3.8
- GitHub commits.
- Changes
  - `jax.numpy.linalg.svd` on TPUs uses a qdwh-svd solver.
  - `jax.numpy.linalg.cond` on TPUs now accepts complex input.
  - `jax.numpy.linalg.pinv` on TPUs now accepts complex input.
  - `jax.numpy.linalg.matrix_rank` on TPUs now accepts complex input.
  - `jax.scipy.cluster.vq.vq` has been added.
  - `jax.experimental.maps.mesh` has been deleted. Please use `jax.experimental.maps.Mesh` instead; see https://jax.readthedocs.io/en/latest/_autosummary/jax.experimental.maps.Mesh.html#jax.experimental.maps.Mesh for more information.
  - `jax.scipy.linalg.qr` now returns a length-1 tuple rather than the raw array when `mode='r'`, in order to match the behavior of `scipy.linalg.qr` (#10452).
  - `jax.numpy.take_along_axis` now takes an optional `mode` parameter that specifies the behavior of out-of-bounds indexing. By default, invalid values (e.g., NaN) will be returned for out-of-bounds indices. In previous versions of JAX, invalid indices were clamped into range. The previous behavior can be restored by passing `mode="clip"`.
  - `jax.numpy.take` now defaults to `mode="fill"`, which returns invalid values (e.g., NaN) for out-of-bounds indices.
  - Scatter operations, such as `x.at[...].set(...)`, now have `"drop"` semantics. This has no effect on the scatter operation itself, but it means that when differentiated, the gradient of a scatter will yield zero cotangents for out-of-bounds indices. Previously, out-of-bounds indices were clamped into range for the gradient, which was not mathematically correct.
  - `jax.numpy.take_along_axis` now raises a `TypeError` if its indices are not of an integer type, matching the behavior of `numpy.take_along_axis`. Previously, non-integer indices were silently cast to integers.
  - `jax.numpy.ravel_multi_index` now raises a `TypeError` if its `dims` argument is not of an integer type, matching the behavior of `numpy.ravel_multi_index`. Previously, non-integer `dims` was silently cast to integers.
  - `jax.numpy.split` now raises a `TypeError` if its `axis` argument is not of an integer type, matching the behavior of `numpy.split`. Previously, non-integer `axis` was silently cast to integers.
  - `jax.numpy.indices` now raises a `TypeError` if its dimensions are not of an integer type, matching the behavior of `numpy.indices`. Previously, non-integer dimensions were silently cast to integers.
  - `jax.numpy.diag` now raises a `TypeError` if its `k` argument is not of an integer type, matching the behavior of `numpy.diag`. Previously, non-integer `k` was silently cast to integers.
  - Added `jax.random.orthogonal`.
- Deprecations
  - Many functions and objects available in `jax.test_util` are now deprecated and will raise a warning on import. This includes `cases_from_list`, `check_close`, `check_eq`, `device_under_test`, `format_shape_dtype_string`, `rand_uniform`, `skip_on_devices`, `with_config`, `xla_bridge`, and `_default_tolerance` (#10389). These, along with the previously deprecated `JaxTestCase`, `JaxTestLoader`, and `BufferDonationTestCase`, will be removed in a future JAX release. Most of these utilities can be replaced by calls to standard Python and NumPy testing utilities found in, e.g., `unittest`, `absl.testing`, and `numpy.testing`. JAX-specific functionality such as device checking can be replaced through the use of public APIs such as `jax.devices`. Many of the deprecated utilities will still exist in `jax._src.test_util`, but these are not public APIs and as such may be changed or removed without notice in future releases.
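The indexing changes in this release are easiest to see side by side. The following is a minimal sketch that assumes a JAX install including these changes. The arrays, indices, and the helper `scatter_sum` are illustrative and not part of the release. It compares `mode="fill"` against `mode="clip"`, shows the new `TypeError` for float indices, and checks that a dropped out-of-bounds scatter update receives a zero gradient. It also exercises the new `qr(mode='r')` return type, `jax.random.orthogonal`, and `jax.scipy.cluster.vq.vq`.

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.scipy.cluster.vq import vq
from jax.scipy.linalg import qr

x = jnp.array([10.0, 20.0, 30.0])
idx = jnp.array([0, 2, 5])  # index 5 is out of bounds for a length-3 array

# take defaults to mode="fill": the out-of-bounds read yields NaN.
taken_fill = jnp.take(x, idx)               # [10., 30., nan]
# mode="clip" restores clamping: index 5 is read as index 2.
taken_clip = jnp.take(x, idx, mode="clip")  # [10., 30., 30.]

# take_along_axis accepts the same optional mode parameter.
along_fill = jnp.take_along_axis(x, idx, axis=0)
along_clip = jnp.take_along_axis(x, idx, axis=0, mode="clip")

# Non-integer indices raise TypeError instead of being silently cast.
try:
    jnp.take_along_axis(x, jnp.array([0.0, 1.0]), axis=0)
    float_index_rejected = False
except TypeError:
    float_index_rejected = True


# Hypothetical helper: scatter two updates, one of them out of bounds.
def scatter_sum(v):
    # mode="drop" spells out the "drop" semantics explicitly.
    return jnp.zeros(3).at[jnp.array([1, 5])].set(v, mode="drop").sum()


# The dropped update contributes nothing, so its cotangent is zero.
scatter_grad = jax.grad(scatter_sum)(jnp.array([1.0, 2.0]))  # [1., 0.]

# qr(mode="r") returns a length-1 tuple, matching scipy.linalg.qr.
r_only = qr(jnp.array([[1.0, 2.0], [3.0, 4.0]]), mode="r")
(r,) = r_only

# jax.random.orthogonal draws a random 3 x 3 orthogonal matrix.
q = jax.random.orthogonal(jax.random.PRNGKey(0), 3)

# vq assigns each observation to its nearest code-book entry.
codes, dists = vq(jnp.array([[0.0, 0.0], [1.0, 1.0], [9.0, 9.0]]),
                  jnp.array([[0.0, 0.0], [10.0, 10.0]]))  # codes: [0, 0, 1]
```

Existing code that relied on the old clamping behavior of `take` or `take_along_axis` can pass `mode="clip"` as shown.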
Jaxlib v0.3.7
- Linux wheels are now built conforming to the `manylinux2014` standard, instead of `manylinux2010`.
JAX release v0.3.7
- Fixed a performance problem if the indices passed to `jax.numpy.take_along_axis` were broadcasted (#10281).
- `jax.scipy.special.expit` and `jax.scipy.special.logit` now require their arguments to be scalars or JAX arrays. They also now promote integer arguments to floating point.
- The `DeviceArray.tile()` method is deprecated, because NumPy arrays do not have a `tile()` method. As a replacement, use `jax.numpy.tile` (#10266).
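Below is a short sketch of the two user-facing changes in this release. It assumes a recent JAX install, and the input values are arbitrary. It calls `jax.numpy.tile` where code previously used the deprecated `DeviceArray.tile()` method, then passes integer arrays to `expit`, which promotes them to floating point.

```python
import jax.numpy as jnp
from jax.scipy.special import expit, logit

x = jnp.array([1, 2, 3])

# Call the jnp.tile function where code previously used x.tile(2).
tiled = jnp.tile(x, 2)  # [1, 2, 3, 1, 2, 3]

# Integer input is promoted to floating point before the sigmoid;
# expit(0) is one half.
p = expit(jnp.array([0, 2]))

# logit is the inverse of expit, so logit(0.5) is zero.
z = logit(jnp.array([0.5]))
```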
JAX release v0.3.6
- Changes
  - Upgraded the libtpu wheel to the fixed version. Fixes #10218.
JAX release v0.3.5
- Changes
  - Added `jax.random.loggamma` and improved the behavior of `jax.random.beta` and `jax.random.dirichlet` for small parameter values (#9906).
  - The private `lax_numpy` submodule is no longer exposed in the `jax.numpy` namespace (#10029).
  - Added the array creation routines `jax.numpy.frombuffer`, `jax.numpy.fromfunction`, and `jax.numpy.fromstring` (#10049).
  - `DeviceArray.copy()` now returns a `DeviceArray` rather than a `np.ndarray` (#10069).
  - Added `jax.scipy.linalg.rsf2csf`.
- Deprecations
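The new array creation routines follow their NumPy namesakes. The sketch below assumes a recent JAX install, uses arbitrary inputs, and builds a small array with each routine. It also checks that `.copy()` does not return a `np.ndarray` (current JAX releases expose `jax.Array` in place of `DeviceArray`), and it draws from `jax.random.loggamma` with small shape parameters.

```python
import numpy as np
import jax
import jax.numpy as jnp

# fromfunction builds out[i, j] = f(i, j) over the given shape.
grid = jnp.fromfunction(lambda i, j: i + j, (2, 3), dtype=jnp.int32)

# frombuffer and fromstring parse raw bytes and text, as in NumPy.
from_bytes = jnp.frombuffer(b"\x01\x02\x03", dtype="uint8")
from_text = jnp.fromstring("1 2 3", dtype="int32", sep=" ")

# .copy() on a JAX array yields another JAX array, not a np.ndarray.
grid_copy = grid.copy()

# loggamma samples log(X) for X ~ Gamma(a). For small a, a float32 Gamma
# sample can underflow to 0, while its logarithm stays representable.
log_samples = jax.random.loggamma(jax.random.PRNGKey(0), jnp.array([0.01, 0.1]))
```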
JAX release v0.3.4
Fix a bug introduced in #9923.
JAX release v0.3.3
Jax release v0.3.1
- Changes
  - `jax.test_util.JaxTestCase` and `jax.test_util.JaxTestLoader` are now deprecated. The suggested replacement is to use `parameterized.TestCase` directly. For tests that rely on custom asserts such as `JaxTestCase.assertAllClose()`, the suggested replacement is to use standard NumPy testing utilities such as `numpy.testing.assert_allclose()`, which work directly with JAX arrays (#9620).
  - `jax.test_util.JaxTestCase` now sets `jax_numpy_rank_promotion='raise'` by default (#9562). To recover the previous behavior, use the new `jax.test_util.with_config` decorator:

    ```python
    from jax import test_util as jtu

    # Re-enable implicit rank promotion for every test in this class.
    @jtu.with_config(jax_numpy_rank_promotion='allow')
    class MyTestCase(jtu.JaxTestCase):
      ...
    ```
- Added `jax.scipy.linalg.schur`, `jax.scipy.linalg.sqrtm`, `jax.scipy.signal.csd`, `jax.scipy.signal.stft`, and `jax.scipy.signal.welch`.
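This release suggests moving away from `JaxTestCase.assertAllClose()`. Instead, a test can subclass a standard test case and call `numpy.testing.assert_allclose()` on JAX arrays directly. This sketch uses a plain `unittest.TestCase` rather than `parameterized.TestCase` so that it does not depend on `absl-py`. It checks the newly added `jax.scipy.linalg.sqrtm` on an arbitrary diagonal matrix. The class and test names are hypothetical.

```python
import unittest

import numpy as np
import jax.numpy as jnp
from jax.scipy.linalg import sqrtm


class SqrtmTest(unittest.TestCase):  # hypothetical test case name
    def test_sqrtm_of_diagonal(self):
        a = jnp.array([[4.0, 0.0], [0.0, 9.0]])
        s = sqrtm(a)
        # numpy.testing accepts the JAX array directly; no JAX-specific
        # assert helper is needed.
        np.testing.assert_allclose(s, np.array([[2.0, 0.0], [0.0, 3.0]]), atol=1e-5)
```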
Jaxlib release v0.3.0
- Changes
  - Bazel 5.0.0 is now required to build jaxlib.
  - jaxlib version has been bumped to 0.3.0. Please see the design doc for the explanation.
Jax release v0.3.0
- Changes
  - jax version has been bumped to 0.3.0. Please see the design doc for the explanation.