diff --git a/CHANGELOG.md b/CHANGELOG.md index e0332c4..d20d4ef 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -6,6 +6,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## [dev] - YYYY-MM-DD +### Added +* Added `mkl_fft` patching for NumPy, with `mkl_fft` context manager, `is_patched` query, and `patch_numpy_fft` and `restore_numpy_fft` calls to replace `numpy.fft` calls with calls from `mkl_fft.interfaces.numpy_fft` [gh-224](https://github.com/IntelPython/mkl_fft/pull/224) + ### Removed * Dropped support for Python 3.9 [gh-243](https://github.com/IntelPython/mkl_fft/pull/243) diff --git a/mkl_fft/__init__.py b/mkl_fft/__init__.py index 04586ea..997566b 100644 --- a/mkl_fft/__init__.py +++ b/mkl_fft/__init__.py @@ -1,4 +1,3 @@ -#!/usr/bin/env python # Copyright (c) 2017, Intel Corporation # # Redistribution and use in source and binary forms, with or without @@ -39,9 +38,15 @@ rfft2, rfftn, ) +from ._patch_numpy import ( + is_patched, + mkl_fft, + patch_numpy_fft, + restore_numpy_fft, +) from ._version import __version__ -import mkl_fft.interfaces # isort: skip +from mkl_fft import interfaces # isort: skip __all__ = [ "fft", @@ -57,6 +62,10 @@ "rfftn", "irfftn", "interfaces", + "mkl_fft", + "patch_numpy_fft", + "restore_numpy_fft", + "is_patched", ] del _init_helper diff --git a/mkl_fft/_fft_utils.py b/mkl_fft/_fft_utils.py index a012e31..ad6a055 100644 --- a/mkl_fft/_fft_utils.py +++ b/mkl_fft/_fft_utils.py @@ -1,4 +1,3 @@ -#!/usr/bin/env python # Copyright (c) 2025, Intel Corporation # # Redistribution and use in source and binary forms, with or without diff --git a/mkl_fft/_mkl_fft.py b/mkl_fft/_mkl_fft.py index 1cd49b9..3ab60c9 100644 --- a/mkl_fft/_mkl_fft.py +++ b/mkl_fft/_mkl_fft.py @@ -1,4 +1,3 @@ -#!/usr/bin/env python # Copyright (c) 2025, Intel Corporation # # Redistribution and use in source and binary forms, with or without diff --git a/mkl_fft/_patch_numpy.py b/mkl_fft/_patch_numpy.py new file mode 100644 index 0000000..5c6b1ca --- /dev/null +++ b/mkl_fft/_patch_numpy.py @@ -0,0 +1,184 @@ +# Copyright (c) 2017, Intel Corporation +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# * Redistributions of source code must retain the above copyright notice, +# this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above copyright +# notice, this list of conditions and the following disclaimer in the +# documentation and/or other materials provided with the distribution. +# * Neither the name of Intel Corporation nor the names of its contributors +# may be used to endorse or promote products derived from this software +# without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +"""Define functions for patching NumPy with MKL-based NumPy interface.""" + +from contextlib import ContextDecorator +from threading import Lock, local + +import numpy as np + +import mkl_fft.interfaces.numpy_fft as _nfft + + +class _GlobalPatch: + def __init__(self): + self._lock = Lock() + self._patch_count = 0 + self._restore_dict = {} + # make _patched_functions a tuple (immutable) + self._patched_functions = tuple(_nfft.__all__) + self._tls = local() + + def _register_func(self, name, func): + if name not in self._patched_functions: + raise ValueError(f"{name} not an mkl_fft function.") + if name not in self._restore_dict: + self._restore_dict[name] = getattr(np.fft, name) + setattr(np.fft, name, func) + + def _restore_func(self, name, verbose=False): + if name not in self._patched_functions: + raise ValueError(f"{name} not an mkl_fft function.") + try: + val = self._restore_dict[name] + except KeyError: + if verbose: + print(f"failed to restore {name}") + return + else: + if verbose: + print(f"found and restoring {name}...") + setattr(np.fft, name, val) + + def do_patch(self, verbose=False): + with self._lock: + local_count = getattr(self._tls, "local_count", 0) + if self._patch_count == 0: + if verbose: + print( + "Now patching NumPy FFT submodule with mkl_fft NumPy " + "interface." + ) + print( + "Please direct bug reports to " + "https://github.com/IntelPython/mkl_fft" + ) + for f in self._patched_functions: + self._register_func(f, getattr(_nfft, f)) + self._patch_count += 1 + self._tls.local_count = local_count + 1 + + def do_restore(self, verbose=False): + with self._lock: + local_count = getattr(self._tls, "local_count", 0) + if local_count <= 0: + if verbose: + print( + "Warning: restore_numpy_fft called more times than " + "patch_numpy_fft in this thread." + ) + return + self._tls.local_count -= 1 + self._patch_count -= 1 + if self._patch_count == 0: + if verbose: + print("Now restoring original NumPy FFT submodule.") + for name in tuple(self._restore_dict): + self._restore_func(name, verbose=verbose) + self._restore_dict.clear() + + def is_patched(self): + with self._lock: + return self._patch_count > 0 + + +_patch = _GlobalPatch() + + +def patch_numpy_fft(verbose=False): + """ + Patch NumPy's fft submodule with mkl_fft's numpy_interface. + + Parameters + ---------- + verbose : bool, optional + print message when starting the patching process. + + Notes + ----- + This function uses reference-counted semantics. Each call increments a + global patch counter. Restoration requires a matching number of calls + between `patch_numpy_fft` and `restore_numpy_fft`. + + In multi-threaded programs, prefer the `mkl_fft` context manager. + + """ + _patch.do_patch(verbose=verbose) + + +def restore_numpy_fft(verbose=False): + """ + Restore NumPy's fft submodule to its original implementations. + + Parameters + ---------- + verbose : bool, optional + print message when starting restoration process. + + Notes + ----- + This function uses reference-counted semantics. Each call decrements a + global patch counter. Restoration requires a matching number of calls + between `patch_numpy_fft` and `restore_numpy_fft`. + + In multi-threaded programs, prefer the `mkl_fft` context manager. + + """ + _patch.do_restore(verbose=verbose) + + +def is_patched(): + """Return True if NumPy's fft submodule is currently patched by mkl_fft.""" + return _patch.is_patched() + + +class mkl_fft(ContextDecorator): + """ + Context manager and decorator to temporarily patch NumPy fft submodule + with MKL-based implementations. + + Examples + -------- + >>> import mkl_fft + >>> mkl_fft.is_patched() + # False + + >>> with mkl_fft.mkl_fft(): # Enable mkl_fft in Numpy + >>> print(mkl_fft.is_patched()) + # True + + >>> mkl_fft.is_patched() + # False + + """ + + def __enter__(self): + patch_numpy_fft() + return self + + def __exit__(self, *exc): + restore_numpy_fft() + return False diff --git a/mkl_fft/interfaces/_float_utils.py b/mkl_fft/interfaces/_float_utils.py index 19e044d..af0f1da 100644 --- a/mkl_fft/interfaces/_float_utils.py +++ b/mkl_fft/interfaces/_float_utils.py @@ -1,4 +1,3 @@ -#!/usr/bin/env python # Copyright (c) 2017, Intel Corporation # # Redistribution and use in source and binary forms, with or without diff --git a/mkl_fft/interfaces/_numpy_fft.py b/mkl_fft/interfaces/_numpy_fft.py index d83fe0b..d752757 100644 --- a/mkl_fft/interfaces/_numpy_fft.py +++ b/mkl_fft/interfaces/_numpy_fft.py @@ -1,4 +1,3 @@ -#!/usr/bin/env python # Copyright (c) 2017, Intel Corporation # # Redistribution and use in source and binary forms, with or without diff --git a/mkl_fft/interfaces/_numpy_helper.py b/mkl_fft/interfaces/_numpy_helper.py index 1a67812..eeb154f 100644 --- a/mkl_fft/interfaces/_numpy_helper.py +++ b/mkl_fft/interfaces/_numpy_helper.py @@ -1,4 +1,3 @@ -#!/usr/bin/env python # Copyright (c) 2017, Intel Corporation # # Redistribution and use in source and binary forms, with or without diff --git a/mkl_fft/interfaces/_scipy_fft.py b/mkl_fft/interfaces/_scipy_fft.py index 64ccaf0..8938323 100644 --- a/mkl_fft/interfaces/_scipy_fft.py +++ b/mkl_fft/interfaces/_scipy_fft.py @@ -1,4 +1,3 @@ -#!/usr/bin/env python # Copyright (c) 2017, Intel Corporation # # Redistribution and use in source and binary forms, with or without diff --git a/mkl_fft/interfaces/numpy_fft.py b/mkl_fft/interfaces/numpy_fft.py index aa74f3d..6c36e65 100644 --- a/mkl_fft/interfaces/numpy_fft.py +++ b/mkl_fft/interfaces/numpy_fft.py @@ -1,4 +1,3 @@ -#!/usr/bin/env python # Copyright (c) 2017, Intel Corporation # # Redistribution and use in source and binary forms, with or without diff --git a/mkl_fft/interfaces/scipy_fft.py b/mkl_fft/interfaces/scipy_fft.py index 4adce52..2e4d007 100644 --- a/mkl_fft/interfaces/scipy_fft.py +++ b/mkl_fft/interfaces/scipy_fft.py @@ -1,4 +1,3 @@ -#!/usr/bin/env python # Copyright (c) 2017, Intel Corporation # # Redistribution and use in source and binary forms, with or without diff --git a/mkl_fft/tests/helper.py b/mkl_fft/tests/helper.py index 6a59664..77eff36 100644 --- a/mkl_fft/tests/helper.py +++ b/mkl_fft/tests/helper.py @@ -1,4 +1,3 @@ -#!/usr/bin/env python # Copyright (c) 2025, Intel Corporation # # Redistribution and use in source and binary forms, with or without diff --git a/mkl_fft/tests/test_fft1d.py b/mkl_fft/tests/test_fft1d.py index 7a1e0fd..16464e7 100644 --- a/mkl_fft/tests/test_fft1d.py +++ b/mkl_fft/tests/test_fft1d.py @@ -1,4 +1,3 @@ -#!/usr/bin/env python # Copyright (c) 2017, Intel Corporation # # Redistribution and use in source and binary forms, with or without diff --git a/mkl_fft/tests/test_fftnd.py b/mkl_fft/tests/test_fftnd.py index 007f0bd..4f04a11 100644 --- a/mkl_fft/tests/test_fftnd.py +++ b/mkl_fft/tests/test_fftnd.py @@ -1,4 +1,3 @@ -#!/usr/bin/env python # Copyright (c) 2017, Intel Corporation # # Redistribution and use in source and binary forms, with or without diff --git a/mkl_fft/tests/test_patch.py b/mkl_fft/tests/test_patch.py new file mode 100644 index 0000000..b68dc2d --- /dev/null +++ b/mkl_fft/tests/test_patch.py @@ -0,0 +1,80 @@ +# Copyright (c) 2017, Intel Corporation +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# * Redistributions of source code must retain the above copyright notice, +# this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above copyright +# notice, this list of conditions and the following disclaimer in the +# documentation and/or other materials provided with the distribution. +# * Neither the name of Intel Corporation nor the names of its contributors +# may be used to endorse or promote products derived from this software +# without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +import numpy as np + +import mkl_fft +import mkl_fft.interfaces.numpy_fft as _nfft + + +def test_patch(): + old_module = np.fft.fft.__module__ + assert not mkl_fft.is_patched() + + mkl_fft.patch_numpy_fft() # Enable mkl_fft in Numpy + assert mkl_fft.is_patched() + assert np.fft.fft.__module__ == _nfft.fft.__module__ + + mkl_fft.restore_numpy_fft() # Disable mkl_fft in Numpy + assert not mkl_fft.is_patched() + assert np.fft.fft.__module__ == old_module + + +def test_patch_redundant_patching(): + old_module = np.fft.fft.__module__ + assert not mkl_fft.is_patched() + + mkl_fft.patch_numpy_fft() + mkl_fft.patch_numpy_fft() + + assert mkl_fft.is_patched() + assert np.fft.fft.__module__ == _nfft.fft.__module__ + + mkl_fft.restore_numpy_fft() + assert mkl_fft.is_patched() + assert np.fft.fft.__module__ == _nfft.fft.__module__ + + mkl_fft.restore_numpy_fft() + assert not mkl_fft.is_patched() + assert np.fft.fft.__module__ == old_module + + +def test_patch_reentrant(): + old_module = np.fft.fft.__module__ + assert not mkl_fft.is_patched() + + with mkl_fft.mkl_fft(): + assert mkl_fft.is_patched() + assert np.fft.fft.__module__ == _nfft.fft.__module__ + + with mkl_fft.mkl_fft(): + assert mkl_fft.is_patched() + assert np.fft.fft.__module__ == _nfft.fft.__module__ + + assert mkl_fft.is_patched() + assert np.fft.fft.__module__ == _nfft.fft.__module__ + + assert not mkl_fft.is_patched() + assert np.fft.fft.__module__ == old_module