Fixed a performance problem if the indices passed to jax.numpy.take_along_axis were broadcasted (#10281).
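A minimal sketch of what "broadcasted indices" means here. The indices have the same number of dimensions as the array, but a non-axis dimension of size 1 broadcasts against the array's corresponding dimension. The array and index values are illustrative only.

```python
import jax.numpy as jnp

# A 3x4 array: [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]]
x = jnp.arange(12).reshape(3, 4)

# Indices of shape (1, 2): the leading size-1 dimension broadcasts
# against the 3 rows of x, so every row gathers columns 3 and 0.
idx = jnp.array([[3, 0]])

out = jnp.take_along_axis(x, idx, axis=1)
# out has shape (3, 2): [[3, 0], [7, 4], [11, 8]]
```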
jax.scipy.special.expit and jax.scipy.special.logit now require their arguments to be scalars or JAX arrays. They also now promote integer arguments to floating point.
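A short sketch of both changes, assuming a JAX release that includes them. Integer input to `expit` comes back as floating point, and a plain Python list is converted with `jnp.asarray` before being passed to `logit`. The input values are illustrative.

```python
import jax.numpy as jnp
from jax.scipy.special import expit, logit

# Integer JAX array input: the result is promoted to a floating dtype.
p = expit(jnp.array([-1, 0, 1]))

# A plain Python list is not a scalar or JAX array, so convert it first.
probs = [0.1, 0.5, 0.9]
logits = logit(jnp.asarray(probs))

# expit is the inverse of logit, so this round-trips (up to rounding).
roundtrip = expit(logits)
```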
The DeviceArray.tile() method is deprecated because numpy arrays do not have a tile() method. Use jax.numpy.tile instead (#10266).
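A minimal migration sketch: replace the method call `x.tile(reps)` with the function call `jnp.tile(x, reps)`, which follows `numpy.tile` semantics. The array values are illustrative.

```python
import jax.numpy as jnp

x = jnp.array([1, 2, 3])

# Deprecated method form: x.tile(2)
# Replacement: the module-level function.
row = jnp.tile(x, 2)        # repeat along the last axis: [1, 2, 3, 1, 2, 3]
grid = jnp.tile(x, (2, 1))  # x is treated as shape (1, 3), giving shape (2, 3)
```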