Skip to content

Commit bebe984

Browse files
author
jax authors
committed
Merge pull request #9205 from jakevdp:einsum-tuple
PiperOrigin-RevId: 422013671
2 parents b92db58 + 77d60cf commit bebe984

File tree

1 file changed

+7
-1
lines changed

1 file changed

+7
-1
lines changed

jax/_src/numpy/lax_numpy.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4935,7 +4935,13 @@ def tensordot(a, b, axes=2, *, precision=None):
49354935
precision=precision)
49364936

49374937

4938-
@_wraps(np.einsum, lax_description=_PRECISION_DOC, skip_params=['out'])
4938+
_EINSUM_DOC = _PRECISION_DOC + """\
4939+
A tuple ``precision`` does not necessarily map to multiple arguments of ``einsum()``;
4940+
rather, the specified ``precision`` is forwarded to each ``dot_general`` call used in
4941+
the implementation.
4942+
"""
4943+
4944+
@_wraps(np.einsum, lax_description=_EINSUM_DOC, skip_params=['out'])
49394945
def einsum(*operands, out=None, optimize='optimal', precision=None,
49404946
_use_xeinsum=False):
49414947
if out is not None:

0 commit comments

Comments
 (0)