Skip to content

Commit 239f765

Browse files
authored
chore: Add numpy typing plugin (#2566)
* Add numpy typing plugin for improved type checking. - c.f. https://numpy.org/doc/2.1/reference/typing.html * Add additional docstring example.
1 parent 157fc95 commit 239f765

File tree

3 files changed

+7
-1
lines changed

3 files changed

+7
-1
lines changed

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -222,6 +222,7 @@ warn_unused_configs = true
222222
strict = true
223223
enable_error_code = ["ignore-without-code", "redundant-expr", "truthy-bool"]
224224
warn_unreachable = true
225+
plugins = "numpy.typing.mypy_plugin"
225226

226227
[[tool.mypy.overrides]]
227228
module = [

src/pyhf/tensor/manager.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,11 @@ def set_backend(
6565
6666
Example:
6767
>>> import pyhf
68+
>>> pyhf.set_backend(b"jax", precision="32b")
69+
>>> pyhf.tensorlib.name
70+
'jax'
71+
>>> pyhf.tensorlib.precision
72+
'32b'
6873
>>> pyhf.set_backend(pyhf.tensor.numpy_backend())
6974
>>> pyhf.tensorlib.name
7075
'numpy'

tests/conftest.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,7 @@ def reset_backend():
8989
def backend(request):
9090
# a better way to get the id? all the backends we have so far for testing
9191
param_ids = request._fixturedef.ids
92-
# the backend we're using: numpy, etc...
92+
# the backend we're using: numpy, jax, etc...
9393
param_id = param_ids[request.param_index]
9494
# name of function being called (with params), the original name is .originalname
9595
func_name = request._pyfuncitem.name

0 commit comments

Comments
 (0)