|
| 1 | +# ***************************************************************************** |
| 2 | +# Copyright (c) 2026, Intel Corporation |
| 3 | +# All rights reserved. |
| 4 | +# |
| 5 | +# Redistribution and use in source and binary forms, with or without |
| 6 | +# modification, are permitted provided that the following conditions are met: |
| 7 | +# - Redistributions of source code must retain the above copyright notice, |
| 8 | +# this list of conditions and the following disclaimer. |
| 9 | +# - Redistributions in binary form must reproduce the above copyright notice, |
| 10 | +# this list of conditions and the following disclaimer in the documentation |
| 11 | +# and/or other materials provided with the distribution. |
| 12 | +# - Neither the name of the copyright holder nor the names of its contributors |
| 13 | +# may be used to endorse or promote products derived from this software |
| 14 | +# without specific prior written permission. |
| 15 | +# |
| 16 | +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" |
| 17 | +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE |
| 18 | +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE |
| 19 | +# ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE |
| 20 | +# LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR |
| 21 | +# CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF |
| 22 | +# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS |
| 23 | +# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN |
| 24 | +# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) |
| 25 | +# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF |
| 26 | +# THE POSSIBILITY OF SUCH DAMAGE. |
| 27 | +# ***************************************************************************** |
| 28 | + |
| 29 | +import dpctl.tensor as dpt |
| 30 | +import dpctl.utils as du |
| 31 | +import numpy as np |
| 32 | + |
| 33 | +# TODO: revert to `import dpctl.tensor...` |
| 34 | +# when dpnp fully migrates dpctl/tensor |
| 35 | +import dpctl_ext.tensor as dpt_ext |
| 36 | + |
| 37 | +from ._manipulation_functions import _broadcast_shape_impl |
| 38 | +from ._type_utils import _to_device_supported_dtype |
| 39 | + |
| 40 | + |
| 41 | +def _allclose_complex_fp(z1, z2, atol, rtol, equal_nan): |
| 42 | + z1r = dpt.real(z1) |
| 43 | + z1i = dpt.imag(z1) |
| 44 | + z2r = dpt.real(z2) |
| 45 | + z2i = dpt.imag(z2) |
| 46 | + if equal_nan: |
| 47 | + check1 = dpt_ext.all( |
| 48 | + dpt_ext.isnan(z1r) == dpt_ext.isnan(z2r) |
| 49 | + ) and dpt_ext.all(dpt_ext.isnan(z1i) == dpt_ext.isnan(z2i)) |
| 50 | + else: |
| 51 | + check1 = ( |
| 52 | + dpt_ext.logical_not(dpt_ext.any(dpt_ext.isnan(z1r))) |
| 53 | + and dpt_ext.logical_not(dpt_ext.any(dpt_ext.isnan(z1i))) |
| 54 | + ) and ( |
| 55 | + dpt_ext.logical_not(dpt_ext.any(dpt_ext.isnan(z2r))) |
| 56 | + and dpt_ext.logical_not(dpt_ext.any(dpt_ext.isnan(z2i))) |
| 57 | + ) |
| 58 | + if not check1: |
| 59 | + return check1 |
| 60 | + mr = dpt_ext.isinf(z1r) |
| 61 | + mi = dpt_ext.isinf(z1i) |
| 62 | + check2 = dpt_ext.all(mr == dpt_ext.isinf(z2r)) and dpt_ext.all( |
| 63 | + mi == dpt_ext.isinf(z2i) |
| 64 | + ) |
| 65 | + if not check2: |
| 66 | + return check2 |
| 67 | + check3 = dpt_ext.all(z1r[mr] == z2r[mr]) and dpt_ext.all(z1i[mi] == z2i[mi]) |
| 68 | + if not check3: |
| 69 | + return check3 |
| 70 | + mr = dpt_ext.isfinite(z1r) |
| 71 | + mi = dpt_ext.isfinite(z1i) |
| 72 | + mv1 = z1r[mr] |
| 73 | + mv2 = z2r[mr] |
| 74 | + check4 = dpt_ext.all( |
| 75 | + dpt_ext.abs(mv1 - mv2) |
| 76 | + < dpt_ext.maximum( |
| 77 | + atol, rtol * dpt_ext.maximum(dpt_ext.abs(mv1), dpt_ext.abs(mv2)) |
| 78 | + ) |
| 79 | + ) |
| 80 | + if not check4: |
| 81 | + return check4 |
| 82 | + mv1 = z1i[mi] |
| 83 | + mv2 = z2i[mi] |
| 84 | + check5 = dpt_ext.all( |
| 85 | + dpt_ext.abs(mv1 - mv2) |
| 86 | + <= dpt_ext.maximum( |
| 87 | + atol, rtol * dpt_ext.maximum(dpt_ext.abs(mv1), dpt_ext.abs(mv2)) |
| 88 | + ) |
| 89 | + ) |
| 90 | + return check5 |
| 91 | + |
| 92 | + |
| 93 | +def _allclose_real_fp(r1, r2, atol, rtol, equal_nan): |
| 94 | + if equal_nan: |
| 95 | + check1 = dpt_ext.all(dpt_ext.isnan(r1) == dpt_ext.isnan(r2)) |
| 96 | + else: |
| 97 | + check1 = dpt_ext.logical_not( |
| 98 | + dpt_ext.any(dpt_ext.isnan(r1)) |
| 99 | + ) and dpt_ext.logical_not(dpt_ext.any(dpt_ext.isnan(r2))) |
| 100 | + if not check1: |
| 101 | + return check1 |
| 102 | + mr = dpt_ext.isinf(r1) |
| 103 | + check2 = dpt_ext.all(mr == dpt_ext.isinf(r2)) |
| 104 | + if not check2: |
| 105 | + return check2 |
| 106 | + check3 = dpt_ext.all(r1[mr] == r2[mr]) |
| 107 | + if not check3: |
| 108 | + return check3 |
| 109 | + m = dpt_ext.isfinite(r1) |
| 110 | + mv1 = r1[m] |
| 111 | + mv2 = r2[m] |
| 112 | + check4 = dpt_ext.all( |
| 113 | + dpt_ext.abs(mv1 - mv2) |
| 114 | + <= dpt_ext.maximum( |
| 115 | + atol, rtol * dpt_ext.maximum(dpt_ext.abs(mv1), dpt_ext.abs(mv2)) |
| 116 | + ) |
| 117 | + ) |
| 118 | + return check4 |
| 119 | + |
| 120 | + |
| 121 | +def _allclose_others(r1, r2): |
| 122 | + return dpt_ext.all(r1 == r2) |
| 123 | + |
| 124 | + |
| 125 | +def allclose(a1, a2, atol=1e-8, rtol=1e-5, equal_nan=False): |
| 126 | + """allclose(a1, a2, atol=1e-8, rtol=1e-5, equal_nan=False) |
| 127 | +
|
| 128 | + Returns True if two arrays are element-wise equal within tolerances. |
| 129 | +
|
| 130 | + The testing is based on the following elementwise comparison: |
| 131 | +
|
| 132 | + abs(a - b) <= max(atol, rtol * max(abs(a), abs(b))) |
| 133 | + """ |
| 134 | + if not isinstance(a1, dpt.usm_ndarray): |
| 135 | + raise TypeError( |
| 136 | + f"Expected dpctl.tensor.usm_ndarray type, got {type(a1)}." |
| 137 | + ) |
| 138 | + if not isinstance(a2, dpt.usm_ndarray): |
| 139 | + raise TypeError( |
| 140 | + f"Expected dpctl.tensor.usm_ndarray type, got {type(a2)}." |
| 141 | + ) |
| 142 | + atol = float(atol) |
| 143 | + rtol = float(rtol) |
| 144 | + if atol < 0.0 or rtol < 0.0: |
| 145 | + raise ValueError( |
| 146 | + "Absolute and relative tolerances must be non-negative" |
| 147 | + ) |
| 148 | + equal_nan = bool(equal_nan) |
| 149 | + exec_q = du.get_execution_queue(tuple(a.sycl_queue for a in (a1, a2))) |
| 150 | + if exec_q is None: |
| 151 | + raise du.ExecutionPlacementError( |
| 152 | + "Execution placement can not be unambiguously inferred " |
| 153 | + "from input arguments." |
| 154 | + ) |
| 155 | + res_sh = _broadcast_shape_impl([a1.shape, a2.shape]) |
| 156 | + b1 = a1 |
| 157 | + b2 = a2 |
| 158 | + if b1.dtype == b2.dtype: |
| 159 | + res_dt = b1.dtype |
| 160 | + else: |
| 161 | + res_dt = np.promote_types(b1.dtype, b2.dtype) |
| 162 | + res_dt = _to_device_supported_dtype(res_dt, exec_q.sycl_device) |
| 163 | + b1 = dpt_ext.astype(b1, res_dt) |
| 164 | + b2 = dpt_ext.astype(b2, res_dt) |
| 165 | + |
| 166 | + b1 = dpt_ext.broadcast_to(b1, res_sh) |
| 167 | + b2 = dpt_ext.broadcast_to(b2, res_sh) |
| 168 | + |
| 169 | + k = b1.dtype.kind |
| 170 | + if k == "c": |
| 171 | + return _allclose_complex_fp(b1, b2, atol, rtol, equal_nan) |
| 172 | + elif k == "f": |
| 173 | + return _allclose_real_fp(b1, b2, atol, rtol, equal_nan) |
| 174 | + else: |
| 175 | + return _allclose_others(b1, b2) |
0 commit comments