Skip to content
8 changes: 8 additions & 0 deletions src/pyhf/tensor/jax_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

config.update('jax_enable_x64', True)

from jax.core import Tracer
from jax import Array
import jax.numpy as jnp
from jax.scipy.special import gammaln, xlogy
Expand All @@ -14,6 +15,10 @@
log = logging.getLogger(__name__)


def _currently_jitting():
return isinstance(jnp.array(1) + 1, Tracer)


class _BasicPoisson:
def __init__(self, rate):
self.rate = rate
Expand Down Expand Up @@ -184,6 +189,9 @@ def conditional(self, predicate, true_callable, false_callable):
return true_callable() if predicate else false_callable()

def tolist(self, tensor_in):
if _currently_jitting():
# .aval is the abstract value and has a little nicer representation
return tensor_in.aval
try:
return jnp.asarray(tensor_in).tolist()
except (TypeError, ValueError):
Expand Down
16 changes: 16 additions & 0 deletions tests/test_backends.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,3 +84,19 @@ def test_backend_array_type(backend):
def test_tensor_array_types():
# can't really assert the content of them so easily
assert pyhf.tensor.array_types


@pytest.mark.only_jax
def test_jax_data_shape_mismatch_during_jitting(backend):
"""
Validate that during JAX tracingg time the correct form
of the tracer is returned.
Issue: https://github.com/scikit-hep/pyhf/issues/1422
PR: https://github.com/scikit-hep/pyhf/pull/2580
"""
model = pyhf.simplemodels.uncorrelated_background([10], [15], [5])
with pytest.raises(
pyhf.exceptions.InvalidPdfData,
match="eval failed as data has len 1 but 2 was expected",
):
pyhf.infer.mle.fit([12.5], model)