Closed
Labels: bug (Something isn't working)
Description
(Copied from patrick-kidger/optimistix#158)
Functions that define a `custom_jvp` are not eagerly evaluated inside `jax.ensure_compile_time_eval()`. They are traced instead, unless the call is wrapped in `jax.jit`:
```python
import jax
import jax.numpy as jnp

### ----- modified excerpt from optimistix/_misc.py -----
def _asarray(dtype, x):
    return jnp.asarray(x, dtype=dtype)

# Removing this custom_jvp wrapper resolves the issue.
_asarray = jax.custom_jvp(_asarray, nondiff_argnums=(0,))

@_asarray.defjvp
def _asarray_jvp(dtype, x, tx):
    # `x` and `tx` arrive as tuples of primals and tangents.
    (x,) = x
    (tx,) = tx
    return _asarray(dtype, x), _asarray(dtype, tx)

def inexact_asarray(x):
    dtype = jnp.result_type(x)
    if not jnp.issubdtype(jnp.result_type(x), jnp.inexact):
        dtype = jnp.float32
    return _asarray(dtype, x)
### ----- end -----

@jax.jit
def const_eval(y0):
    with jax.ensure_compile_time_eval():
        y_stub = jax.tree.map(lambda a: jnp.zeros_like(a), y0)
        print(y_stub)  # 0

        # Not eagerly evaluated:
        out = jax.tree.map(inexact_asarray, y_stub)
        print(out)  # JitTracer<float32[]>

        # Works as expected when wrapped in jax.jit:
        out = jax.jit(lambda a: jax.tree.map(inexact_asarray, a))(y_stub)
        print(out)  # 0.0

const_eval(1)
```

System info (python version, jaxlib version, accelerator, etc.)
```
jax: 0.7.0
jaxlib: 0.7.0
numpy: 2.2.1
python: 3.13.1 | packaged by conda-forge | (main, Dec 5 2024, 21:09:18) [Clang 18.1.8 ]
device info: cpu-1, 1 local devices"
process_count: 1
platform: uname_result(system='Darwin', node='staff-net-etx-1774.intern.ethz.ch', release='24.5.0', version='Darwin Kernel Version 24.5.0: Tue Apr 22 19:53:27 PDT 2025; root:xnu-11417.121.6~2/RELEASE_ARM64_T6041', machine='arm64')
```
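The last case in the repro suggests a possible workaround until this is fixed: route the `custom_jvp` function through `jax.jit` so that it runs with concrete inputs. This is only a sketch, assuming the jit-wrapped call stays eager under `ensure_compile_time_eval` in your JAX version (as observed with jax 0.7.0 above). The names `double`, `eager_eval`, and `results` are hypothetical; `double` merely stands in for optimistix's `_asarray`.

```python
import jax
import jax.numpy as jnp

# Values recorded at trace time, used to check that they were concrete.
results = []


@jax.custom_jvp
def double(x):
    # Hypothetical stand-in for `_asarray`: any function with a custom JVP.
    return 2.0 * x


@double.defjvp
def _double_jvp(primals, tangents):
    (x,), (tx,) = primals, tangents
    return double(x), 2.0 * tx


def eager_eval(f, *args):
    # Hypothetical helper: call `f` through `jax.jit`, which the report
    # shows is evaluated eagerly under ensure_compile_time_eval.
    return jax.jit(f)(*args)


@jax.jit
def const_eval(y0):
    with jax.ensure_compile_time_eval():
        # Build the stub from the static shape and dtype so it is concrete.
        y_stub = jnp.zeros(jnp.shape(y0), dtype=jnp.result_type(y0))
        out = eager_eval(double, y_stub)
        # float() raises on a tracer, so this succeeds only if `out` is concrete.
        results.append(float(out))
    return y0


const_eval(1)
print(results)
```

Calling `double(y_stub)` directly at the same point would, per this report, produce a tracer instead, and the `float()` conversion would raise.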