Skip to content

Commit ecbc6a3

Browse files
Move ti.allclose() to dpctl_ext.tensor
1 parent 5a897c8 commit ecbc6a3

File tree

2 files changed

+177
-0
lines changed

2 files changed

+177
-0
lines changed

dpctl_ext/tensor/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -181,6 +181,7 @@
181181
)
182182
from ._sorting import argsort, sort, top_k
183183
from ._statistical_functions import mean, std, var
184+
from ._testing import allclose
184185
from ._type_utils import can_cast, finfo, iinfo, isdtype, result_type
185186

186187
__all__ = [
@@ -189,6 +190,7 @@
189190
"acosh",
190191
"add",
191192
"all",
193+
"allclose",
192194
"angle",
193195
"any",
194196
"arange",

dpctl_ext/tensor/_testing.py

Lines changed: 175 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,175 @@
1+
# *****************************************************************************
2+
# Copyright (c) 2026, Intel Corporation
3+
# All rights reserved.
4+
#
5+
# Redistribution and use in source and binary forms, with or without
6+
# modification, are permitted provided that the following conditions are met:
7+
# - Redistributions of source code must retain the above copyright notice,
8+
# this list of conditions and the following disclaimer.
9+
# - Redistributions in binary form must reproduce the above copyright notice,
10+
# this list of conditions and the following disclaimer in the documentation
11+
# and/or other materials provided with the distribution.
12+
# - Neither the name of the copyright holder nor the names of its contributors
13+
# may be used to endorse or promote products derived from this software
14+
# without specific prior written permission.
15+
#
16+
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
17+
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
18+
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
19+
# ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
20+
# LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
21+
# CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
22+
# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
23+
# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
24+
# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
25+
# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
26+
# THE POSSIBILITY OF SUCH DAMAGE.
27+
# *****************************************************************************
28+
29+
import dpctl.tensor as dpt
30+
import dpctl.utils as du
31+
import numpy as np
32+
33+
# TODO: revert to `import dpctl.tensor...`
34+
# when dpnp fully migrates dpctl/tensor
35+
import dpctl_ext.tensor as dpt_ext
36+
37+
from ._manipulation_functions import _broadcast_shape_impl
38+
from ._type_utils import _to_device_supported_dtype
39+
40+
41+
def _allclose_complex_fp(z1, z2, atol, rtol, equal_nan):
42+
z1r = dpt.real(z1)
43+
z1i = dpt.imag(z1)
44+
z2r = dpt.real(z2)
45+
z2i = dpt.imag(z2)
46+
if equal_nan:
47+
check1 = dpt_ext.all(
48+
dpt_ext.isnan(z1r) == dpt_ext.isnan(z2r)
49+
) and dpt_ext.all(dpt_ext.isnan(z1i) == dpt_ext.isnan(z2i))
50+
else:
51+
check1 = (
52+
dpt_ext.logical_not(dpt_ext.any(dpt_ext.isnan(z1r)))
53+
and dpt_ext.logical_not(dpt_ext.any(dpt_ext.isnan(z1i)))
54+
) and (
55+
dpt_ext.logical_not(dpt_ext.any(dpt_ext.isnan(z2r)))
56+
and dpt_ext.logical_not(dpt_ext.any(dpt_ext.isnan(z2i)))
57+
)
58+
if not check1:
59+
return check1
60+
mr = dpt_ext.isinf(z1r)
61+
mi = dpt_ext.isinf(z1i)
62+
check2 = dpt_ext.all(mr == dpt_ext.isinf(z2r)) and dpt_ext.all(
63+
mi == dpt_ext.isinf(z2i)
64+
)
65+
if not check2:
66+
return check2
67+
check3 = dpt_ext.all(z1r[mr] == z2r[mr]) and dpt_ext.all(z1i[mi] == z2i[mi])
68+
if not check3:
69+
return check3
70+
mr = dpt_ext.isfinite(z1r)
71+
mi = dpt_ext.isfinite(z1i)
72+
mv1 = z1r[mr]
73+
mv2 = z2r[mr]
74+
check4 = dpt_ext.all(
75+
dpt_ext.abs(mv1 - mv2)
76+
< dpt_ext.maximum(
77+
atol, rtol * dpt_ext.maximum(dpt_ext.abs(mv1), dpt_ext.abs(mv2))
78+
)
79+
)
80+
if not check4:
81+
return check4
82+
mv1 = z1i[mi]
83+
mv2 = z2i[mi]
84+
check5 = dpt_ext.all(
85+
dpt_ext.abs(mv1 - mv2)
86+
<= dpt_ext.maximum(
87+
atol, rtol * dpt_ext.maximum(dpt_ext.abs(mv1), dpt_ext.abs(mv2))
88+
)
89+
)
90+
return check5
91+
92+
93+
def _allclose_real_fp(r1, r2, atol, rtol, equal_nan):
94+
if equal_nan:
95+
check1 = dpt_ext.all(dpt_ext.isnan(r1) == dpt_ext.isnan(r2))
96+
else:
97+
check1 = dpt_ext.logical_not(
98+
dpt_ext.any(dpt_ext.isnan(r1))
99+
) and dpt_ext.logical_not(dpt_ext.any(dpt_ext.isnan(r2)))
100+
if not check1:
101+
return check1
102+
mr = dpt_ext.isinf(r1)
103+
check2 = dpt_ext.all(mr == dpt_ext.isinf(r2))
104+
if not check2:
105+
return check2
106+
check3 = dpt_ext.all(r1[mr] == r2[mr])
107+
if not check3:
108+
return check3
109+
m = dpt_ext.isfinite(r1)
110+
mv1 = r1[m]
111+
mv2 = r2[m]
112+
check4 = dpt_ext.all(
113+
dpt_ext.abs(mv1 - mv2)
114+
<= dpt_ext.maximum(
115+
atol, rtol * dpt_ext.maximum(dpt_ext.abs(mv1), dpt_ext.abs(mv2))
116+
)
117+
)
118+
return check4
119+
120+
121+
def _allclose_others(r1, r2):
122+
return dpt_ext.all(r1 == r2)
123+
124+
125+
def allclose(a1, a2, atol=1e-8, rtol=1e-5, equal_nan=False):
126+
"""allclose(a1, a2, atol=1e-8, rtol=1e-5, equal_nan=False)
127+
128+
Returns True if two arrays are element-wise equal within tolerances.
129+
130+
The testing is based on the following elementwise comparison:
131+
132+
abs(a - b) <= max(atol, rtol * max(abs(a), abs(b)))
133+
"""
134+
if not isinstance(a1, dpt.usm_ndarray):
135+
raise TypeError(
136+
f"Expected dpctl.tensor.usm_ndarray type, got {type(a1)}."
137+
)
138+
if not isinstance(a2, dpt.usm_ndarray):
139+
raise TypeError(
140+
f"Expected dpctl.tensor.usm_ndarray type, got {type(a2)}."
141+
)
142+
atol = float(atol)
143+
rtol = float(rtol)
144+
if atol < 0.0 or rtol < 0.0:
145+
raise ValueError(
146+
"Absolute and relative tolerances must be non-negative"
147+
)
148+
equal_nan = bool(equal_nan)
149+
exec_q = du.get_execution_queue(tuple(a.sycl_queue for a in (a1, a2)))
150+
if exec_q is None:
151+
raise du.ExecutionPlacementError(
152+
"Execution placement can not be unambiguously inferred "
153+
"from input arguments."
154+
)
155+
res_sh = _broadcast_shape_impl([a1.shape, a2.shape])
156+
b1 = a1
157+
b2 = a2
158+
if b1.dtype == b2.dtype:
159+
res_dt = b1.dtype
160+
else:
161+
res_dt = np.promote_types(b1.dtype, b2.dtype)
162+
res_dt = _to_device_supported_dtype(res_dt, exec_q.sycl_device)
163+
b1 = dpt_ext.astype(b1, res_dt)
164+
b2 = dpt_ext.astype(b2, res_dt)
165+
166+
b1 = dpt_ext.broadcast_to(b1, res_sh)
167+
b2 = dpt_ext.broadcast_to(b2, res_sh)
168+
169+
k = b1.dtype.kind
170+
if k == "c":
171+
return _allclose_complex_fp(b1, b2, atol, rtol, equal_nan)
172+
elif k == "f":
173+
return _allclose_real_fp(b1, b2, atol, rtol, equal_nan)
174+
else:
175+
return _allclose_others(b1, b2)

0 commit comments

Comments
 (0)