From 42485a75ee504a3356e42d42e02a83b028a19171 Mon Sep 17 00:00:00 2001 From: Nikita Grigorian Date: Mon, 22 Sep 2025 10:52:59 -0700 Subject: [PATCH 1/9] add manual patching to mkl_fft --- mkl_fft/__init__.py | 11 ++++ mkl_fft/_patch_numpy.py | 122 ++++++++++++++++++++++++++++++++++++++++ 2 files changed, 133 insertions(+) create mode 100644 mkl_fft/_patch_numpy.py diff --git a/mkl_fft/__init__.py b/mkl_fft/__init__.py index 04586ea..14a1462 100644 --- a/mkl_fft/__init__.py +++ b/mkl_fft/__init__.py @@ -39,10 +39,17 @@ rfft2, rfftn, ) +from ._patch_numpy import ( + is_patched, + mkl_fft, + patch_numpy_fft, + restore_numpy_fft, +) from ._version import __version__ import mkl_fft.interfaces # isort: skip + __all__ = [ "fft", "ifft", @@ -57,6 +64,10 @@ "rfftn", "irfftn", "interfaces", + "mkl_fft", + "patch_numpy_fft", + "restore_numpy_fft", + "is_patched", ] del _init_helper diff --git a/mkl_fft/_patch_numpy.py b/mkl_fft/_patch_numpy.py new file mode 100644 index 0000000..fa220e0 --- /dev/null +++ b/mkl_fft/_patch_numpy.py @@ -0,0 +1,122 @@ +#!/usr/bin/env python +# Copyright (c) 2017, Intel Corporation +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# * Redistributions of source code must retain the above copyright notice, +# this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above copyright +# notice, this list of conditions and the following disclaimer in the +# documentation and/or other materials provided with the distribution. +# * Neither the name of Intel Corporation nor the names of its contributors +# may be used to endorse or promote products derived from this software +# without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +"""Define functions for patching NumPy with MKL-based NumPy interface.""" + +from contextlib import ContextDecorator +from threading import local as threading_local + +import numpy as np + +import mkl_fft.interfaces.numpy_fft as _nfft + +_tls = threading_local() + + +class _Patch: + """Internal object for patching NumPy with mkl_fft interfaces.""" + + _is_patched = False + __patched_functions__ = _nfft.__all__ + _restore_dict = {} + + def _register_func(self, name, func): + if name not in self.__patched_functions__: + raise ValueError("%s not an mkl_fft function." % name) + f = getattr(np.fft, name) + self._restore_dict[name] = f + setattr(np.fft, name, func) + + def _restore_func(self, name, verbose=False): + if name not in self.__patched_functions__: + raise ValueError("%s not an mkl_fft function." % name) + try: + val = self._restore_dict[name] + except KeyError: + if verbose: + print("failed to restore") + return + else: + if verbose: + print("found and restoring...") + setattr(np.fft, name, val) + + def restore(self, verbose=False): + for name in self._restore_dict.keys(): + self._restore_func(name, verbose=verbose) + self._is_patched = False + + def do_patch(self): + for f in self.__patched_functions__: + self._register_func(f, getattr(_nfft, f)) + self._is_patched = True + + def is_patched(self): + return self._is_patched + + +def _initialize_tls(): + _tls.patch = _Patch() + _tls.initialized = True + + +def _is_tls_initialized(): + return (getattr(_tls, "initialized", None) is not None) and ( + _tls.initialized is True + ) + + +def patch_numpy_fft(verbose=False): + if verbose: + print("Now patching NumPy FFT submodule with mkl_fft NumPy interface.") + print("Please direct bug reports to https://github.com/IntelPython/mkl_fft") + if not _is_tls_initialized(): + _initialize_tls() + _tls.patch.do_patch() + + +def restore_numpy_fft(verbose=False): + if verbose: + print("Now restoring original NumPy FFT submodule.") + if not _is_tls_initialized(): + _initialize_tls() + _tls.patch.restore(verbose=verbose) + + +def is_patched(): + if not _is_tls_initialized(): + _initialize_tls() + return _tls.patch.is_patched() + + +class mkl_fft(ContextDecorator): + def __enter__(self): + patch_numpy_fft() + return self + + def __exit__(self, *exc): + restore_numpy_fft() + return False From db76eb2919970004b626ed0b233b907cab7267ef Mon Sep 17 00:00:00 2001 From: Nikita Grigorian Date: Mon, 22 Sep 2025 12:05:08 -0700 Subject: [PATCH 2/9] add test for NumPy patching function --- mkl_fft/tests/test_patch.py | 43 +++++++++++++++++++++++++++++++++++++ 1 file changed, 43 insertions(+) create mode 100644 mkl_fft/tests/test_patch.py diff --git a/mkl_fft/tests/test_patch.py b/mkl_fft/tests/test_patch.py new file mode 100644 index 0000000..3a5206d --- /dev/null +++ b/mkl_fft/tests/test_patch.py @@ -0,0 +1,43 @@ +# Copyright (c) 2017, Intel Corporation +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# * Redistributions of source code must retain the above copyright notice, +# this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above copyright +# notice, this list of conditions and the following disclaimer in the +# documentation and/or other materials provided with the distribution. +# * Neither the name of Intel Corporation nor the names of its contributors +# may be used to endorse or promote products derived from this software +# without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +import numpy as np + +import mkl_fft +import mkl_fft.interfaces.numpy_fft as _nfft + + +def test_patch(): + mkl_fft.restore_numpy_fft() + assert not mkl_fft.is_patched() + assert (np.fft.fft.__module__ == "numpy.fft") + + mkl_fft.patch_numpy_fft() # Enable mkl_fft in Numpy + assert mkl_fft.is_patched() + assert (np.fft.fft.__module__ == _nfft.fft.__module__) + + mkl_fft.restore_numpy_fft() # Disable mkl_fft in Numpy + assert not mkl_fft.is_patched() + assert (np.fft.fft.__module__ == "numpy.fft") From efbd93b5b2f04a3b75586a783ac6523d3fee04f9 Mon Sep 17 00:00:00 2001 From: Nikita Grigorian Date: Fri, 20 Feb 2026 01:09:31 -0800 Subject: [PATCH 3/9] fix patching test --- mkl_fft/_patch_numpy.py | 4 +++- mkl_fft/tests/test_patch.py | 7 +++---- 2 files changed, 6 insertions(+), 5 deletions(-) diff --git a/mkl_fft/_patch_numpy.py b/mkl_fft/_patch_numpy.py index fa220e0..071993e 100644 --- a/mkl_fft/_patch_numpy.py +++ b/mkl_fft/_patch_numpy.py @@ -92,7 +92,9 @@ def _is_tls_initialized(): def patch_numpy_fft(verbose=False): if verbose: print("Now patching NumPy FFT submodule with mkl_fft NumPy interface.") - print("Please direct bug reports to https://github.com/IntelPython/mkl_fft") + print( + "Please direct bug reports to https://github.com/IntelPython/mkl_fft" + ) if not _is_tls_initialized(): _initialize_tls() _tls.patch.do_patch() diff --git a/mkl_fft/tests/test_patch.py b/mkl_fft/tests/test_patch.py index 3a5206d..5a4a090 100644 --- a/mkl_fft/tests/test_patch.py +++ b/mkl_fft/tests/test_patch.py @@ -30,14 +30,13 @@ def test_patch(): - mkl_fft.restore_numpy_fft() + old_module = np.fft.fft.__module__ assert not mkl_fft.is_patched() - assert (np.fft.fft.__module__ == "numpy.fft") mkl_fft.patch_numpy_fft() # Enable mkl_fft in Numpy assert mkl_fft.is_patched() - assert (np.fft.fft.__module__ == _nfft.fft.__module__) + assert np.fft.fft.__module__ == _nfft.fft.__module__ mkl_fft.restore_numpy_fft() # Disable mkl_fft in Numpy assert not mkl_fft.is_patched() - assert (np.fft.fft.__module__ == "numpy.fft") + assert np.fft.fft.__module__ == old_module From 23edc895fc3f82e34945370bc2751eff980e68d2 Mon Sep 17 00:00:00 2001 From: Nikita Grigorian Date: Tue, 3 Mar 2026 01:28:18 -0800 Subject: [PATCH 4/9] address feedback about patching * improve thread-safety of patching * gate repeated patch calls there still exist problematic edge cases (race condition where one thread restores while another is using patched functions) --- mkl_fft/__init__.py | 1 - mkl_fft/_patch_numpy.py | 137 +++++++++++++++++++++++------------- mkl_fft/tests/test_patch.py | 19 +++++ 3 files changed, 106 insertions(+), 51 deletions(-) diff --git a/mkl_fft/__init__.py b/mkl_fft/__init__.py index 14a1462..cefcc47 100644 --- a/mkl_fft/__init__.py +++ b/mkl_fft/__init__.py @@ -49,7 +49,6 @@ import mkl_fft.interfaces # isort: skip - __all__ = [ "fft", "ifft", diff --git a/mkl_fft/_patch_numpy.py b/mkl_fft/_patch_numpy.py index 071993e..ef089e4 100644 --- a/mkl_fft/_patch_numpy.py +++ b/mkl_fft/_patch_numpy.py @@ -27,94 +27,131 @@ """Define functions for patching NumPy with MKL-based NumPy interface.""" from contextlib import ContextDecorator -from threading import local as threading_local +from threading import Lock import numpy as np import mkl_fft.interfaces.numpy_fft as _nfft -_tls = threading_local() - -class _Patch: - """Internal object for patching NumPy with mkl_fft interfaces.""" - - _is_patched = False - __patched_functions__ = _nfft.__all__ - _restore_dict = {} +class _GlobalPatch(ContextDecorator): + def __init__(self): + self._lock = Lock() + self._patch_count = 0 + self._restore_dict = {} + # make _patched_functions a tuple (immutable) + self._patched_functions = tuple(_nfft.__all__) def _register_func(self, name, func): - if name not in self.__patched_functions__: - raise ValueError("%s not an mkl_fft function." % name) - f = getattr(np.fft, name) - self._restore_dict[name] = f + if name not in self._patched_functions: + raise ValueError(f"{name} not an mkl_fft function.") + if name not in self._restore_dict: + self._restore_dict[name] = getattr(np.fft, name) setattr(np.fft, name, func) def _restore_func(self, name, verbose=False): - if name not in self.__patched_functions__: - raise ValueError("%s not an mkl_fft function." % name) + if name not in self._patched_functions: + raise ValueError(f"{name} not an mkl_fft function.") try: val = self._restore_dict[name] except KeyError: if verbose: - print("failed to restore") + print(f"failed to restore {name}") return else: if verbose: - print("found and restoring...") + print(f"found and restoring {name}...") setattr(np.fft, name, val) - def restore(self, verbose=False): - for name in self._restore_dict.keys(): - self._restore_func(name, verbose=verbose) - self._is_patched = False - - def do_patch(self): - for f in self.__patched_functions__: - self._register_func(f, getattr(_nfft, f)) - self._is_patched = True + def do_patch(self, verbose=False): + with self._lock: + if self._patch_count == 0: + if verbose: + print("Now patching NumPy FFT submodule with mkl_fft NumPy interface.") + print( + "Please direct bug reports to https://github.com/IntelPython/mkl_fft" + ) + for f in self._patched_functions: + self._register_func(f, getattr(_nfft, f)) + self._patch_count += 1 + + def do_restore(self, verbose=False): + with self._lock: + if self._patch_count > 0: + self._patch_count -= 1 + if self._patch_count == 0: + if verbose: + print("Now restoring original NumPy FFT submodule.") + for name in tuple(self._restore_dict): + self._restore_func(name, verbose=verbose) + self._restore_dict.clear() def is_patched(self): - return self._is_patched + with self._lock: + return self._patch_count > 0 + def __enter__(self): + self.do_patch() + return self -def _initialize_tls(): - _tls.patch = _Patch() - _tls.initialized = True + def __exit__(self, *exc): + self.do_restore() + return False -def _is_tls_initialized(): - return (getattr(_tls, "initialized", None) is not None) and ( - _tls.initialized is True - ) +_patch = _GlobalPatch() def patch_numpy_fft(verbose=False): - if verbose: - print("Now patching NumPy FFT submodule with mkl_fft NumPy interface.") - print( - "Please direct bug reports to https://github.com/IntelPython/mkl_fft" - ) - if not _is_tls_initialized(): - _initialize_tls() - _tls.patch.do_patch() + """Patch NumPy's fft submodule with mkl_fft's numpy_interface. + + Parameters + ---------- + verbose : bool, optional + print message when starting the patching process. + + """ + _patch.do_patch(verbose=verbose) def restore_numpy_fft(verbose=False): - if verbose: - print("Now restoring original NumPy FFT submodule.") - if not _is_tls_initialized(): - _initialize_tls() - _tls.patch.restore(verbose=verbose) + """ + Restore NumPy's fft submodule to its original implementations. + + Parameters + ---------- + verbose : bool, optional + print message when starting restoration process. + + """ + _patch.do_restore(verbose=verbose) def is_patched(): - if not _is_tls_initialized(): - _initialize_tls() - return _tls.patch.is_patched() + """Return True if NumPy's fft submodule is currently patched by mkl_fft.""" + return _patch.is_patched() class mkl_fft(ContextDecorator): + """ + Context manager and decorator to temporarily patch NumPy fft submodule + with MKL-based implementations. + + Examples + -------- + >>> import mkl_fft + >>> mkl_fft.is_patched() + # False + + >>> with mkl_fft.mkl_fft(): # Enable mkl_fft in Numpy + >>> print(mkl_fft.is_patched()) + # True + + >>> mkl_fft.is_patched() + # False + + """ + def __enter__(self): patch_numpy_fft() return self diff --git a/mkl_fft/tests/test_patch.py b/mkl_fft/tests/test_patch.py index 5a4a090..769adbf 100644 --- a/mkl_fft/tests/test_patch.py +++ b/mkl_fft/tests/test_patch.py @@ -40,3 +40,22 @@ def test_patch(): mkl_fft.restore_numpy_fft() # Disable mkl_fft in Numpy assert not mkl_fft.is_patched() assert np.fft.fft.__module__ == old_module + + +def test_patch_redundant_patching(): + old_module = np.fft.fft.__module__ + assert not mkl_fft.is_patched() + + mkl_fft.patch_numpy_fft() + mkl_fft.patch_numpy_fft() + + assert mkl_fft.is_patched() + assert np.fft.fft.__module__ == _nfft.fft.__module__ + + mkl_fft.restore_numpy_fft() + assert mkl_fft.is_patched() + assert np.fft.fft.__module__ == _nfft.fft.__module__ + + mkl_fft.restore_numpy_fft() + assert not mkl_fft.is_patched() + assert np.fft.fft.__module__ == old_module From 14d37054bfa61d80a42268bbfa2a18442ef0f7c2 Mon Sep 17 00:00:00 2001 From: Nikita Grigorian Date: Tue, 3 Mar 2026 02:04:05 -0800 Subject: [PATCH 5/9] use thread-local storage for bookkeeping patch calls per thread --- mkl_fft/_patch_numpy.py | 44 +++++++++++++++++++++++++++++++++++------ 1 file changed, 38 insertions(+), 6 deletions(-) diff --git a/mkl_fft/_patch_numpy.py b/mkl_fft/_patch_numpy.py index ef089e4..67074d0 100644 --- a/mkl_fft/_patch_numpy.py +++ b/mkl_fft/_patch_numpy.py @@ -27,7 +27,7 @@ """Define functions for patching NumPy with MKL-based NumPy interface.""" from contextlib import ContextDecorator -from threading import Lock +from threading import Lock, local import numpy as np @@ -41,6 +41,7 @@ def __init__(self): self._restore_dict = {} # make _patched_functions a tuple (immutable) self._patched_functions = tuple(_nfft.__all__) + self._tls = local() def _register_func(self, name, func): if name not in self._patched_functions: @@ -65,20 +66,34 @@ def _restore_func(self, name, verbose=False): def do_patch(self, verbose=False): with self._lock: + local_count = getattr(self._tls, "local_count", 0) if self._patch_count == 0: if verbose: - print("Now patching NumPy FFT submodule with mkl_fft NumPy interface.") print( - "Please direct bug reports to https://github.com/IntelPython/mkl_fft" + "Now patching NumPy FFT submodule with mkl_fft NumPy " + "interface." + ) + print( + "Please direct bug reports to " + "https://github.com/IntelPython/mkl_fft" ) for f in self._patched_functions: self._register_func(f, getattr(_nfft, f)) self._patch_count += 1 + self._tls.local_count = local_count + 1 def do_restore(self, verbose=False): with self._lock: - if self._patch_count > 0: - self._patch_count -= 1 + local_count = getattr(self._tls, "local_count", 0) + if local_count <= 0: + if verbose: + print( + "Warning: restore_numpy_fft called more times than " + "patch_numpy_fft in this thread." + ) + return + self._tls.local_count -= 1 + self._patch_count -= 1 if self._patch_count == 0: if verbose: print("Now restoring original NumPy FFT submodule.") @@ -103,13 +118,22 @@ def __exit__(self, *exc): def patch_numpy_fft(verbose=False): - """Patch NumPy's fft submodule with mkl_fft's numpy_interface. + """ + Patch NumPy's fft submodule with mkl_fft's numpy_interface. Parameters ---------- verbose : bool, optional print message when starting the patching process. + Notes + ----- + This function uses reference-counted semantics. Each call increments a + global patch counter. Restoration requires a matching number of calls + between `patch_numpy_fft` and `restore_numpy_fft`. + + In multi-threaded programs, prefer the `mkl_fft` context manager. + """ _patch.do_patch(verbose=verbose) @@ -123,6 +147,14 @@ def restore_numpy_fft(verbose=False): verbose : bool, optional print message when starting restoration process. + Notes + ----- + This function uses reference-counted semantics. Each call decrements a + global patch counter. Restoration requires a matching number of calls + between `patch_numpy_fft` and `restore_numpy_fft`. + + In multi-threaded programs, prefer the `mkl_fft` context manager. + """ _patch.do_restore(verbose=verbose) From f9ec9cd6819b175bb9fe50d0fa9731d4e4ae0599 Mon Sep 17 00:00:00 2001 From: Nikita Grigorian Date: Tue, 3 Mar 2026 17:09:54 -0800 Subject: [PATCH 6/9] drop ContextDecorator as a parent to GlobalPatch leftover from idea to use GlobalPatch as parent to mkl_fft context manager --- mkl_fft/_patch_numpy.py | 10 +--------- 1 file changed, 1 insertion(+), 9 deletions(-) diff --git a/mkl_fft/_patch_numpy.py b/mkl_fft/_patch_numpy.py index 67074d0..7bdbb46 100644 --- a/mkl_fft/_patch_numpy.py +++ b/mkl_fft/_patch_numpy.py @@ -34,7 +34,7 @@ import mkl_fft.interfaces.numpy_fft as _nfft -class _GlobalPatch(ContextDecorator): +class _GlobalPatch: def __init__(self): self._lock = Lock() self._patch_count = 0 @@ -105,14 +105,6 @@ def is_patched(self): with self._lock: return self._patch_count > 0 - def __enter__(self): - self.do_patch() - return self - - def __exit__(self, *exc): - self.do_restore() - return False - _patch = _GlobalPatch() From 25a8f4947e2c156ccd3e629b4e8a3489a8b70fbb Mon Sep 17 00:00:00 2001 From: Nikita Grigorian Date: Tue, 3 Mar 2026 17:31:26 -0800 Subject: [PATCH 7/9] add patching to changelog --- CHANGELOG.md | 3 +++ 1 file changed, 3 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index e0332c4..d20d4ef 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -6,6 +6,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## [dev] - YYYY-MM-DD +### Added +* Added `mkl_fft` patching for NumPy, with `mkl_fft` context manager, `is_patched` query, and `patch_numpy_fft` and `restore_numpy_fft` calls to replace `numpy.fft` calls with calls from `mkl_fft.interfaces.numpy_fft` [gh-224](https://github.com/IntelPython/mkl_fft/pull/224) + ### Removed * Dropped support for Python 3.9 [gh-243](https://github.com/IntelPython/mkl_fft/pull/243) From 977169756e6a1f27655f5dd15a2be55fd75293be Mon Sep 17 00:00:00 2001 From: Nikita Grigorian Date: Tue, 3 Mar 2026 17:33:01 -0800 Subject: [PATCH 8/9] remove unnecessary shebangs --- mkl_fft/__init__.py | 1 - mkl_fft/_fft_utils.py | 1 - mkl_fft/_mkl_fft.py | 1 - mkl_fft/_patch_numpy.py | 1 - mkl_fft/interfaces/_float_utils.py | 1 - mkl_fft/interfaces/_numpy_fft.py | 1 - mkl_fft/interfaces/_numpy_helper.py | 1 - mkl_fft/interfaces/_scipy_fft.py | 1 - mkl_fft/interfaces/numpy_fft.py | 1 - mkl_fft/interfaces/scipy_fft.py | 1 - mkl_fft/tests/helper.py | 1 - mkl_fft/tests/test_fft1d.py | 1 - mkl_fft/tests/test_fftnd.py | 1 - 13 files changed, 13 deletions(-) diff --git a/mkl_fft/__init__.py b/mkl_fft/__init__.py index cefcc47..c2e418f 100644 --- a/mkl_fft/__init__.py +++ b/mkl_fft/__init__.py @@ -1,4 +1,3 @@ -#!/usr/bin/env python # Copyright (c) 2017, Intel Corporation # # Redistribution and use in source and binary forms, with or without diff --git a/mkl_fft/_fft_utils.py b/mkl_fft/_fft_utils.py index a012e31..ad6a055 100644 --- a/mkl_fft/_fft_utils.py +++ b/mkl_fft/_fft_utils.py @@ -1,4 +1,3 @@ -#!/usr/bin/env python # Copyright (c) 2025, Intel Corporation # # Redistribution and use in source and binary forms, with or without diff --git a/mkl_fft/_mkl_fft.py b/mkl_fft/_mkl_fft.py index 1cd49b9..3ab60c9 100644 --- a/mkl_fft/_mkl_fft.py +++ b/mkl_fft/_mkl_fft.py @@ -1,4 +1,3 @@ -#!/usr/bin/env python # Copyright (c) 2025, Intel Corporation # # Redistribution and use in source and binary forms, with or without diff --git a/mkl_fft/_patch_numpy.py b/mkl_fft/_patch_numpy.py index 7bdbb46..5c6b1ca 100644 --- a/mkl_fft/_patch_numpy.py +++ b/mkl_fft/_patch_numpy.py @@ -1,4 +1,3 @@ -#!/usr/bin/env python # Copyright (c) 2017, Intel Corporation # # Redistribution and use in source and binary forms, with or without diff --git a/mkl_fft/interfaces/_float_utils.py b/mkl_fft/interfaces/_float_utils.py index 19e044d..af0f1da 100644 --- a/mkl_fft/interfaces/_float_utils.py +++ b/mkl_fft/interfaces/_float_utils.py @@ -1,4 +1,3 @@ -#!/usr/bin/env python # Copyright (c) 2017, Intel Corporation # # Redistribution and use in source and binary forms, with or without diff --git a/mkl_fft/interfaces/_numpy_fft.py b/mkl_fft/interfaces/_numpy_fft.py index d83fe0b..d752757 100644 --- a/mkl_fft/interfaces/_numpy_fft.py +++ b/mkl_fft/interfaces/_numpy_fft.py @@ -1,4 +1,3 @@ -#!/usr/bin/env python # Copyright (c) 2017, Intel Corporation # # Redistribution and use in source and binary forms, with or without diff --git a/mkl_fft/interfaces/_numpy_helper.py b/mkl_fft/interfaces/_numpy_helper.py index 1a67812..eeb154f 100644 --- a/mkl_fft/interfaces/_numpy_helper.py +++ b/mkl_fft/interfaces/_numpy_helper.py @@ -1,4 +1,3 @@ -#!/usr/bin/env python # Copyright (c) 2017, Intel Corporation # # Redistribution and use in source and binary forms, with or without diff --git a/mkl_fft/interfaces/_scipy_fft.py b/mkl_fft/interfaces/_scipy_fft.py index 64ccaf0..8938323 100644 --- a/mkl_fft/interfaces/_scipy_fft.py +++ b/mkl_fft/interfaces/_scipy_fft.py @@ -1,4 +1,3 @@ -#!/usr/bin/env python # Copyright (c) 2017, Intel Corporation # # Redistribution and use in source and binary forms, with or without diff --git a/mkl_fft/interfaces/numpy_fft.py b/mkl_fft/interfaces/numpy_fft.py index aa74f3d..6c36e65 100644 --- a/mkl_fft/interfaces/numpy_fft.py +++ b/mkl_fft/interfaces/numpy_fft.py @@ -1,4 +1,3 @@ -#!/usr/bin/env python # Copyright (c) 2017, Intel Corporation # # Redistribution and use in source and binary forms, with or without diff --git a/mkl_fft/interfaces/scipy_fft.py b/mkl_fft/interfaces/scipy_fft.py index 4adce52..2e4d007 100644 --- a/mkl_fft/interfaces/scipy_fft.py +++ b/mkl_fft/interfaces/scipy_fft.py @@ -1,4 +1,3 @@ -#!/usr/bin/env python # Copyright (c) 2017, Intel Corporation # # Redistribution and use in source and binary forms, with or without diff --git a/mkl_fft/tests/helper.py b/mkl_fft/tests/helper.py index 6a59664..77eff36 100644 --- a/mkl_fft/tests/helper.py +++ b/mkl_fft/tests/helper.py @@ -1,4 +1,3 @@ -#!/usr/bin/env python # Copyright (c) 2025, Intel Corporation # # Redistribution and use in source and binary forms, with or without diff --git a/mkl_fft/tests/test_fft1d.py b/mkl_fft/tests/test_fft1d.py index 7a1e0fd..16464e7 100644 --- a/mkl_fft/tests/test_fft1d.py +++ b/mkl_fft/tests/test_fft1d.py @@ -1,4 +1,3 @@ -#!/usr/bin/env python # Copyright (c) 2017, Intel Corporation # # Redistribution and use in source and binary forms, with or without diff --git a/mkl_fft/tests/test_fftnd.py b/mkl_fft/tests/test_fftnd.py index 007f0bd..4f04a11 100644 --- a/mkl_fft/tests/test_fftnd.py +++ b/mkl_fft/tests/test_fftnd.py @@ -1,4 +1,3 @@ -#!/usr/bin/env python # Copyright (c) 2017, Intel Corporation # # Redistribution and use in source and binary forms, with or without From a1626859857eb6597926eaf5153b296e46f9b778 Mon Sep 17 00:00:00 2001 From: Nikita Grigorian Date: Wed, 4 Mar 2026 13:30:06 -0800 Subject: [PATCH 9/9] add test for reentrant behavior with mkl_fft context manager also use from mkl_fft import interfaces to avoid redundant mkl_fft module in namespace --- mkl_fft/__init__.py | 2 +- mkl_fft/tests/test_patch.py | 19 +++++++++++++++++++ 2 files changed, 20 insertions(+), 1 deletion(-) diff --git a/mkl_fft/__init__.py b/mkl_fft/__init__.py index c2e418f..997566b 100644 --- a/mkl_fft/__init__.py +++ b/mkl_fft/__init__.py @@ -46,7 +46,7 @@ ) from ._version import __version__ -import mkl_fft.interfaces # isort: skip +from mkl_fft import interfaces # isort: skip __all__ = [ "fft", diff --git a/mkl_fft/tests/test_patch.py b/mkl_fft/tests/test_patch.py index 769adbf..b68dc2d 100644 --- a/mkl_fft/tests/test_patch.py +++ b/mkl_fft/tests/test_patch.py @@ -59,3 +59,22 @@ def test_patch_redundant_patching(): mkl_fft.restore_numpy_fft() assert not mkl_fft.is_patched() assert np.fft.fft.__module__ == old_module + + +def test_patch_reentrant(): + old_module = np.fft.fft.__module__ + assert not mkl_fft.is_patched() + + with mkl_fft.mkl_fft(): + assert mkl_fft.is_patched() + assert np.fft.fft.__module__ == _nfft.fft.__module__ + + with mkl_fft.mkl_fft(): + assert mkl_fft.is_patched() + assert np.fft.fft.__module__ == _nfft.fft.__module__ + + assert mkl_fft.is_patched() + assert np.fft.fft.__module__ == _nfft.fft.__module__ + + assert not mkl_fft.is_patched() + assert np.fft.fft.__module__ == old_module