Skip to content

Commit 0d583ca

Browse files
committed
Fix unboxing of numpy boolean Scalars in C-backend
1 parent d8501d1 commit 0d583ca

File tree

3 files changed

+26
-7
lines changed

3 files changed

+26
-7
lines changed

pytensor/scalar/basic.py

Lines changed: 18 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -484,20 +484,32 @@ def c_extract(self, name, sub, check_input=True, **kwargs):
484484
def c_sync(self, name, sub):
485485
specs = self.dtype_specs()
486486
fail = sub["fail"]
487-
dtype = specs[1]
488-
cls = specs[2]
487+
(np_dtype, _c_dtype, _cls_name) = specs
488+
np_dtype_num = np.dtype(np_dtype).num
489+
489490
return f"""
490491
Py_XDECREF(py_{name});
491-
py_{name} = PyArrayScalar_New({cls});
492+
493+
PyArray_Descr* {name}_descr = PyArray_DescrFromType({np_dtype_num}); // {np_dtype}
494+
if (!{name}_descr) {{
495+
PyErr_Format(PyExc_RuntimeError, "Could not get descriptor for {np_dtype_num}={np_dtype}");
496+
{fail}
497+
}}
498+
499+
// PyArray_Scalar creates a new scalar object by copying data from the pointer &{name}
500+
py_{name} = PyArray_Scalar(&{name}, {name}_descr, NULL);
501+
502+
// Clean up the descriptor reference (PyArray_DescrFromType returns a new ref)
503+
Py_DECREF({name}_descr);
504+
492505
if (!py_{name})
493506
{{
494507
Py_XINCREF(Py_None);
495508
py_{name} = Py_None;
496509
PyErr_Format(PyExc_MemoryError,
497-
"Instantiation of new Python scalar failed ({dtype})");
510+
"Instantiation of new Python NumPy scalar failed ({np_dtype_num}={np_dtype})");
498511
{fail}
499512
}}
500-
PyArrayScalar_ASSIGN(py_{name}, {cls}, {name});
501513
"""
502514

503515
def c_cleanup(self, name, sub):
@@ -762,7 +774,7 @@ def c_init_code(self, **kwargs):
762774
return ["import_array();"]
763775

764776
def c_code_cache_version(self):
765-
return (14, np.__version__)
777+
return (15, np.__version__)
766778

767779
def get_shape_info(self, obj):
768780
return obj.itemsize

pytensor/tensor/signal/conv.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -256,7 +256,7 @@ def perform(self, node, inputs, outputs):
256256
in1, in2, full_mode = inputs
257257

258258
# TODO: Why is .item() needed?
259-
mode: Literal["full", "valid", "same"] = "full" if full_mode.item() else "valid"
259+
mode: Literal["full", "valid", "same"] = "full" if full_mode else "valid"
260260
outputs[0][0] = scipy_convolve(in1, in2, mode=mode, method=self.method)
261261

262262

tests/tensor/test_basic.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2209,6 +2209,13 @@ def test_ScalarFromTensor(cast_policy):
22092209
scalar_from_tensor(vector())
22102210

22112211

2212+
def test_bool_scalar_from_tensor():
2213+
x = scalar("x", dtype="bool")
2214+
fn = function([x], scalar_from_tensor(x))
2215+
assert fn(np.array(True, dtype=bool))
2216+
assert not fn(np.array(False, dtype=bool))
2217+
2218+
22122219
def test_op_cache():
22132220
# TODO: What is this actually testing?
22142221
# trigger bug in ticket #162

0 commit comments

Comments
 (0)