Greedy evaluation of custom_jvp fails when invoked in jax.ensure_compile_time_eval() #30787

@otoomey

Description

(Copied from patrick-kidger/optimistix#158)

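Functions wrapped in jax.custom_jvp are not greedily evaluated inside jax.ensure_compile_time_eval() when called under an outer jax.jit. They return a tracer instead of a concrete value, unless the call is itself wrapped in jax.jit: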

import jax
import jax.numpy as jnp

### ----- modified excerpt from optimistix/_misc.py -----

def _asarray(dtype, x):
    return jnp.asarray(x, dtype=dtype)

# remove custom jvp to resolve issue
_asarray = jax.custom_jvp(_asarray, nondiff_argnums=(0,))

@_asarray.defjvp
def _asarray_jvp(dtype, x, tx):
    (x,) = x
    (tx,) = tx
    return _asarray(dtype, x), _asarray(dtype, tx)

def inexact_asarray(x):
    dtype = jnp.result_type(x)
    if not jnp.issubdtype(jnp.result_type(x), jnp.inexact):
        dtype = jnp.float32
    return _asarray(dtype, x)

### ----- end -----

@jax.jit
def const_eval(y0):
    with jax.ensure_compile_time_eval():
        y_stub = jax.tree.map(lambda a: jnp.zeros_like(a), y0)
        print(y_stub) # 0

        # not greedily evaluated
        out = jax.tree.map(inexact_asarray, y_stub)
        print(out) # JitTracer<float32[]>

        # works as expected when jit'd
        out = jax.jit(lambda a: jax.tree.map(inexact_asarray, a))(y_stub)
        print(out) # 0.0

const_eval(1)
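
To compare the cases side by side without reading printed tracers, the reproduction can be reduced to a small probe that records whether each result is a concrete value. This is only a sketch: `is_concrete`, `plain_cast`, `custom_cast`, `probe` and `results` are hypothetical names introduced here, not part of optimistix or JAX, and the probe assumes that checking against `jax.core.Tracer` is an adequate test for "not evaluated". Going by the report above, on jax 0.7.0 the `custom_jvp` entry would be expected to come out `False` while the others are `True`; other versions may behave differently.

```python
import jax
import jax.numpy as jnp


def is_concrete(x):
    """Return True if x is a concrete value rather than a JAX tracer."""
    return not isinstance(x, jax.core.Tracer)


def plain_cast(x):
    # Same cast as below, but without custom_jvp.
    return jnp.asarray(x, dtype=jnp.float32)


@jax.custom_jvp
def custom_cast(x):
    return jnp.asarray(x, dtype=jnp.float32)


@custom_cast.defjvp
def _custom_cast_jvp(primals, tangents):
    (x,) = primals
    (tx,) = tangents
    return custom_cast(x), custom_cast(tx)


# Filled in as a side effect while `probe` is being traced.
results = {}


@jax.jit
def probe(y):
    with jax.ensure_compile_time_eval():
        # Built only from static shape/dtype information of the traced input.
        stub = jnp.zeros(y.shape, y.dtype)
        results["stub"] = is_concrete(stub)
        results["plain"] = is_concrete(plain_cast(stub))
        results["custom_jvp"] = is_concrete(custom_cast(stub))
        results["custom_jvp_jit"] = is_concrete(jax.jit(custom_cast)(stub))
    return y


probe(1.0)
print(results)
```

Because the dictionary is populated during tracing, it reflects what happened inside `jax.ensure_compile_time_eval()` on the first call only; later calls with the same input signature reuse the cached trace and do not update it.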

System info (python version, jaxlib version, accelerator, etc.)

jax:    0.7.0
jaxlib: 0.7.0
numpy:  2.2.1
python: 3.13.1 | packaged by conda-forge | (main, Dec  5 2024, 21:09:18) [Clang 18.1.8 ]
device info: cpu-1, 1 local devices"
process_count: 1
platform: uname_result(system='Darwin', node='staff-net-etx-1774.intern.ethz.ch', release='24.5.0', version='Darwin Kernel Version 24.5.0: Tue Apr 22 19:53:27 PDT 2025; root:xnu-11417.121.6~2/RELEASE_ARM64_T6041', machine='arm64')

Labels: bug (Something isn't working)