From 44ef21e3140b4e51fa5941d16d92ea779117a043 Mon Sep 17 00:00:00 2001 From: pfackeldey Date: Fri, 21 Mar 2025 17:19:22 -0400 Subject: [PATCH 01/11] fix jax backend tolist for tracers in logging --- src/pyhf/tensor/jax_backend.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/src/pyhf/tensor/jax_backend.py b/src/pyhf/tensor/jax_backend.py index 2a85006c04..bf08a75648 100644 --- a/src/pyhf/tensor/jax_backend.py +++ b/src/pyhf/tensor/jax_backend.py @@ -2,6 +2,7 @@ config.update('jax_enable_x64', True) +import jax from jax import Array import jax.numpy as jnp from jax.scipy.special import gammaln, xlogy @@ -14,6 +15,10 @@ log = logging.getLogger(__name__) +def currently_jitting(): + return isinstance(jnp.array(1) + 1, jax.core.Tracer) + + class _BasicPoisson: def __init__(self, rate): self.rate = rate @@ -184,6 +189,9 @@ def conditional(self, predicate, true_callable, false_callable): return true_callable() if predicate else false_callable() def tolist(self, tensor_in): + if currently_jitting(): + # .aval is the abstract value and has a little nicer representation + return tensor_in.aval try: return jnp.asarray(tensor_in).tolist() except (TypeError, ValueError): From 68fced7cd023f6e4396c174fa608d52c67a74aec Mon Sep 17 00:00:00 2001 From: pfackeldey Date: Wed, 15 Oct 2025 10:06:18 +0200 Subject: [PATCH 02/11] prefix currently_jitting with an underscore --- src/pyhf/tensor/jax_backend.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/pyhf/tensor/jax_backend.py b/src/pyhf/tensor/jax_backend.py index bf08a75648..f088669c6b 100644 --- a/src/pyhf/tensor/jax_backend.py +++ b/src/pyhf/tensor/jax_backend.py @@ -15,7 +15,7 @@ log = logging.getLogger(__name__) -def currently_jitting(): +def _currently_jitting(): return isinstance(jnp.array(1) + 1, jax.core.Tracer) @@ -189,7 +189,7 @@ def conditional(self, predicate, true_callable, false_callable): return true_callable() if predicate else false_callable() def tolist(self, tensor_in): - if currently_jitting(): + if _currently_jitting(): # .aval is the abstract value and has a little nicer representation return tensor_in.aval try: From ed2d895d9fcf126afdf4cd9b22635dbeed6869bd Mon Sep 17 00:00:00 2001 From: pfackeldey Date: Wed, 15 Oct 2025 10:14:32 +0200 Subject: [PATCH 03/11] add test --- tests/test_backends.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/tests/test_backends.py b/tests/test_backends.py index 518a8a759b..8d73e855bb 100644 --- a/tests/test_backends.py +++ b/tests/test_backends.py @@ -84,3 +84,15 @@ def test_backend_array_type(backend): def test_tensor_array_types(): # can't really assert the content of them so easily assert pyhf.tensor.array_types + + +def test_jax_data_shape_mismatch_during_jitting(): + # Issue: https://github.com/scikit-hep/pyhf/issues/1422 + # PR: https://github.com/scikit-hep/pyhf/pull/2580 + pyhf.set_backend("jax") + model = pyhf.simplemodels.uncorrelated_background([10], [15], [5]) + with pytest.raises( + pyhf.exceptions.InvalidPdfData, + match="eval failed as data has len 1 but 2 was expected", + ): + pyhf.infer.mle.fit([12.5], model) From a09a4226c9f2ad3bd35329b55b1790fe8a712e14 Mon Sep 17 00:00:00 2001 From: Matthew Feickert Date: Wed, 15 Oct 2025 11:27:13 +0200 Subject: [PATCH 04/11] style: Scope imports to a finer level --- src/pyhf/tensor/jax_backend.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/pyhf/tensor/jax_backend.py b/src/pyhf/tensor/jax_backend.py index f088669c6b..168b4a819d 100644 --- a/src/pyhf/tensor/jax_backend.py +++ b/src/pyhf/tensor/jax_backend.py @@ -2,7 +2,7 @@ config.update('jax_enable_x64', True) -import jax +from jax.core import Tracer from jax import Array import jax.numpy as jnp from jax.scipy.special import gammaln, xlogy @@ -16,7 +16,7 @@ def _currently_jitting(): - return isinstance(jnp.array(1) + 1, jax.core.Tracer) + return isinstance(jnp.array(1) + 1, Tracer) class _BasicPoisson: From c1b3176b0409bd680d47923a4b865b1de422578c Mon Sep 17 00:00:00 2001 From: Matthew Feickert Date: Wed, 15 Oct 2025 11:31:50 +0200 Subject: [PATCH 05/11] test: Use pytest.mark functionality to have only the jax backend run --- tests/test_backends.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/tests/test_backends.py b/tests/test_backends.py index 8d73e855bb..e3d01483ec 100644 --- a/tests/test_backends.py +++ b/tests/test_backends.py @@ -86,10 +86,9 @@ def test_tensor_array_types(): assert pyhf.tensor.array_types +@pytest.mark.only_jax def test_jax_data_shape_mismatch_during_jitting(): # Issue: https://github.com/scikit-hep/pyhf/issues/1422 - # PR: https://github.com/scikit-hep/pyhf/pull/2580 - pyhf.set_backend("jax") model = pyhf.simplemodels.uncorrelated_background([10], [15], [5]) with pytest.raises( pyhf.exceptions.InvalidPdfData, From c1a9d0793234e3f4d6929f99c48e98c10064c976 Mon Sep 17 00:00:00 2001 From: Matthew Feickert Date: Wed, 15 Oct 2025 11:45:42 +0200 Subject: [PATCH 06/11] test: Scope test to JAX backend --- tests/test_backends.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/tests/test_backends.py b/tests/test_backends.py index e3d01483ec..9b5990fc3a 100644 --- a/tests/test_backends.py +++ b/tests/test_backends.py @@ -87,8 +87,13 @@ def test_tensor_array_types(): @pytest.mark.only_jax -def test_jax_data_shape_mismatch_during_jitting(): - # Issue: https://github.com/scikit-hep/pyhf/issues/1422 +def test_jax_data_shape_mismatch_during_jitting(backend): + """ + Validate that during JAX tracingg time the correct form + of the tracer is returned. + Issue: https://github.com/scikit-hep/pyhf/issues/1422 + PR: https://github.com/scikit-hep/pyhf/pull/2580 + """ model = pyhf.simplemodels.uncorrelated_background([10], [15], [5]) with pytest.raises( pyhf.exceptions.InvalidPdfData, From 39bcecb81d3d8bf94544d3816867061b96bcab12 Mon Sep 17 00:00:00 2001 From: Matthew Feickert Date: Wed, 15 Oct 2025 12:02:09 +0200 Subject: [PATCH 07/11] fix: fixup typo --- tests/test_backends.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_backends.py b/tests/test_backends.py index 9b5990fc3a..7b199f9fe4 100644 --- a/tests/test_backends.py +++ b/tests/test_backends.py @@ -89,7 +89,7 @@ def test_tensor_array_types(): @pytest.mark.only_jax def test_jax_data_shape_mismatch_during_jitting(backend): """ - Validate that during JAX tracingg time the correct form + Validate that during JAX tracing time the correct form of the tracer is returned. Issue: https://github.com/scikit-hep/pyhf/issues/1422 PR: https://github.com/scikit-hep/pyhf/pull/2580 From ddded5501c8fe8c79e3b1de62bb2754a02da3dd0 Mon Sep 17 00:00:00 2001 From: pfackeldey Date: Wed, 15 Oct 2025 14:06:34 +0200 Subject: [PATCH 08/11] add small doc string & simplify '_currently_jitting' --- src/pyhf/tensor/jax_backend.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/pyhf/tensor/jax_backend.py b/src/pyhf/tensor/jax_backend.py index 168b4a819d..7d47b17b47 100644 --- a/src/pyhf/tensor/jax_backend.py +++ b/src/pyhf/tensor/jax_backend.py @@ -16,7 +16,8 @@ def _currently_jitting(): - return isinstance(jnp.array(1) + 1, Tracer) + """JAX turns arrays into Tracers during jit-compilation, so we can check for that""" + return isinstance(jnp.array(1), Tracer) class _BasicPoisson: From f2fe00fd6f22a9f65b99d27d47e172105d80c8cc Mon Sep 17 00:00:00 2001 From: pfackeldey Date: Wed, 15 Oct 2025 14:10:52 +0200 Subject: [PATCH 09/11] improve note for the test --- tests/test_backends.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/tests/test_backends.py b/tests/test_backends.py index 7b199f9fe4..3f3719d9a8 100644 --- a/tests/test_backends.py +++ b/tests/test_backends.py @@ -89,8 +89,10 @@ def test_tensor_array_types(): @pytest.mark.only_jax def test_jax_data_shape_mismatch_during_jitting(backend): """ - Validate that during JAX tracing time the correct form - of the tracer is returned. + Validate that during JAX tracing time pyhf doesn't try + to convert the data to a list, which is not possible with tracers, + for a shape mismatch. + Instead, we return the tracer itself for a proper error message. Issue: https://github.com/scikit-hep/pyhf/issues/1422 PR: https://github.com/scikit-hep/pyhf/pull/2580 """ From ac5e9e61402ec29205259bc85a22f6b5355e06b9 Mon Sep 17 00:00:00 2001 From: Matthew Feickert Date: Wed, 15 Oct 2025 14:37:48 +0200 Subject: [PATCH 10/11] style: Use passive voice --- src/pyhf/tensor/jax_backend.py | 4 +++- tests/test_backends.py | 2 +- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/src/pyhf/tensor/jax_backend.py b/src/pyhf/tensor/jax_backend.py index 7d47b17b47..e616302446 100644 --- a/src/pyhf/tensor/jax_backend.py +++ b/src/pyhf/tensor/jax_backend.py @@ -16,7 +16,9 @@ def _currently_jitting(): - """JAX turns arrays into Tracers during jit-compilation, so we can check for that""" + """ + JAX turns arrays into Tracers during jit-compilation, so check for that. + """ return isinstance(jnp.array(1), Tracer) diff --git a/tests/test_backends.py b/tests/test_backends.py index 3f3719d9a8..dbde580cc1 100644 --- a/tests/test_backends.py +++ b/tests/test_backends.py @@ -92,7 +92,7 @@ def test_jax_data_shape_mismatch_during_jitting(backend): Validate that during JAX tracing time pyhf doesn't try to convert the data to a list, which is not possible with tracers, for a shape mismatch. - Instead, we return the tracer itself for a proper error message. + Instead, return the tracer itself for a proper error message. Issue: https://github.com/scikit-hep/pyhf/issues/1422 PR: https://github.com/scikit-hep/pyhf/pull/2580 """ From f875bbfbf3c8b754a2cea3fa837a27d6f6ff9563 Mon Sep 17 00:00:00 2001 From: Matthew Feickert Date: Wed, 15 Oct 2025 14:38:39 +0200 Subject: [PATCH 11/11] docs: Add Peter Fackeldey to contributors list --- docs/contributors.rst | 1 + 1 file changed, 1 insertion(+) diff --git a/docs/contributors.rst b/docs/contributors.rst index 9554373417..6e134e381f 100644 --- a/docs/contributors.rst +++ b/docs/contributors.rst @@ -36,3 +36,4 @@ Contributors include: - Lorenz Gaertner - Melissa Weber Mendonça - Matthias Bussonnier +- Peter Fackeldey