Skip to content
1 change: 1 addition & 0 deletions docs/contributors.rst
Original file line number Diff line number Diff line change
Expand Up @@ -36,3 +36,4 @@ Contributors include:
- Lorenz Gaertner
- Melissa Weber Mendonça
- Matthias Bussonnier
- Peter Fackeldey
11 changes: 11 additions & 0 deletions src/pyhf/tensor/jax_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

config.update('jax_enable_x64', True)

from jax.core import Tracer
from jax import Array
import jax.numpy as jnp
from jax.scipy.special import gammaln, xlogy
Expand All @@ -14,6 +15,13 @@
log = logging.getLogger(__name__)


def _currently_jitting():
"""
JAX turns arrays into Tracers during jit-compilation, so check for that.
"""
return isinstance(jnp.array(1), Tracer)


class _BasicPoisson:
def __init__(self, rate):
self.rate = rate
Expand Down Expand Up @@ -184,6 +192,9 @@ def conditional(self, predicate, true_callable, false_callable):
return true_callable() if predicate else false_callable()

def tolist(self, tensor_in):
if _currently_jitting():
# .aval is the abstract value and has a little nicer representation
return tensor_in.aval
try:
return jnp.asarray(tensor_in).tolist()
except (TypeError, ValueError):
Expand Down
18 changes: 18 additions & 0 deletions tests/test_backends.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,3 +84,21 @@ def test_backend_array_type(backend):
def test_tensor_array_types():
# can't really assert the content of them so easily
assert pyhf.tensor.array_types


@pytest.mark.only_jax
def test_jax_data_shape_mismatch_during_jitting(backend):
"""
Validate that during JAX tracing time pyhf doesn't try
to convert the data to a list, which is not possible with tracers,
for a shape mismatch.
Instead, return the tracer itself for a proper error message.
Issue: https://github.com/scikit-hep/pyhf/issues/1422
PR: https://github.com/scikit-hep/pyhf/pull/2580
"""
model = pyhf.simplemodels.uncorrelated_background([10], [15], [5])
with pytest.raises(
pyhf.exceptions.InvalidPdfData,
match="eval failed as data has len 1 but 2 was expected",
):
pyhf.infer.mle.fit([12.5], model)