diff --git a/README.rst b/README.rst index 9bf8f6aa..494a466e 100644 --- a/README.rst +++ b/README.rst @@ -71,7 +71,7 @@ convolution. Consider the following example: import torch import numpy as np import pywt - import ptwt # use "from src import ptwt" for a cloned the repo + import ptwt # generate an input of even length. data = np.array([0, 1, 2, 3, 4, 5, 6, 7, 7, 6, 5, 4, 3, 2, 1, 0]) @@ -133,7 +133,7 @@ Reconsidering the 1d case, try: import torch import pywt - import ptwt # use "from src import ptwt" for a cloned the repo + import ptwt # generate an input of even length. data = torch.arange(16, dtype=torch.float32) diff --git a/docs/citation.rst b/docs/citation.rst index 65376238..2d2729fa 100644 --- a/docs/citation.rst +++ b/docs/citation.rst @@ -7,17 +7,16 @@ If you use this work in a scientific context, please cite the following paper: .. code-block:: - @article{JMLR:v25:23-0636, - author = {Moritz Wolter and Felix Blanke and Jochen Garcke and Charles Tapley Hoyt}, - title = {ptwt - The PyTorch Wavelet Toolbox}, - journal = {Journal of Machine Learning Research}, - year = {2024}, - volume = {25}, - number = {80}, - pages = {1--7}, - url = {http://jmlr.org/papers/v25/23-0636.html} - } - + @article{JMLR:v25:23-0636, + author = {Moritz Wolter and Felix Blanke and Jochen Garcke and Charles Tapley Hoyt}, + title = {ptwt - The PyTorch Wavelet Toolbox}, + journal = {Journal of Machine Learning Research}, + year = {2024}, + volume = {25}, + number = {80}, + pages = {1--7}, + url = {http://jmlr.org/papers/v25/23-0636.html} + } This work builds upon `PyWavelets `_ please consider citing them as well. diff --git a/docs/examples.rst b/docs/examples.rst index b2bbf423..57c10e42 100644 --- a/docs/examples.rst +++ b/docs/examples.rst @@ -3,4 +3,5 @@ Wavelet transforms by example ============================= -Worked examples are available in the examples folder of the `GitHub repository `_ . +Worked examples are available in the examples folder of the `GitHub repository +`_ . diff --git a/docs/ref/index.rst b/docs/ref/index.rst index c15b8c5a..fe2245c7 100644 --- a/docs/ref/index.rst +++ b/docs/ref/index.rst @@ -4,17 +4,17 @@ The ptwt package -- API reference ================================= .. toctree:: - :maxdepth: 2 + :maxdepth: 2 - conv-fwt - conv-inverse-fwt - matrix-fwt - matrix-inverse-fwt - packets - stationary - cwt - return-types - boundary - wavelets-learnable - sparse-math - other + conv-fwt + conv-inverse-fwt + matrix-fwt + matrix-inverse-fwt + packets + stationary + cwt + return-types + boundary + wavelets-learnable + sparse-math + other diff --git a/docs/ref/matrix-fwt.rst b/docs/ref/matrix-fwt.rst index 2869507d..f469fc79 100644 --- a/docs/ref/matrix-fwt.rst +++ b/docs/ref/matrix-fwt.rst @@ -9,34 +9,35 @@ Sparse-matrix based Fast Wavelet Transform (FWT) --------------------------------------------- .. autoclass:: MatrixWavedec - :members: - :special-members: __call__ - :undoc-members: - :show-inheritance: + :members: + :special-members: __call__ + :undoc-members: + :show-inheritance: 2d decomposition using :class:`MatrixWavedec2` ---------------------------------------------- .. autoclass:: MatrixWavedec2 - :members: - :special-members: __call__ - :undoc-members: - :show-inheritance: + :members: + :special-members: __call__ + :undoc-members: + :show-inheritance: 3d decomposition using :class:`MatrixWavedec3` ---------------------------------------------- .. autoclass:: MatrixWavedec3 - :members: - :special-members: __call__ - :undoc-members: - :show-inheritance: - + :members: + :special-members: __call__ + :undoc-members: + :show-inheritance: Sparse-matrix FWT base class ---------------------------- -All sparse-matrix decomposition classes extend :class:`ptwt.matmul_transform.BaseMatrixWaveDec`. + +All sparse-matrix decomposition classes extend +:class:`ptwt.matmul_transform.BaseMatrixWaveDec`. .. autoclass:: ptwt.matmul_transform.BaseMatrixWaveDec - :members: - :undoc-members: + :members: + :undoc-members: diff --git a/docs/ref/matrix-inverse-fwt.rst b/docs/ref/matrix-inverse-fwt.rst index 6be49d5c..f8043885 100644 --- a/docs/ref/matrix-inverse-fwt.rst +++ b/docs/ref/matrix-inverse-fwt.rst @@ -9,22 +9,22 @@ Sparse-matrix based Inverse Fast Wavelet Transform (iFWT) --------------------------------------------- .. autoclass:: MatrixWaverec - :members: - :special-members: __call__ - :undoc-members: + :members: + :special-members: __call__ + :undoc-members: 2d reconstrucion using :class:`MatrixWaverec2` ---------------------------------------------- .. autoclass:: MatrixWaverec2 - :members: - :special-members: __call__ - :undoc-members: + :members: + :special-members: __call__ + :undoc-members: 3d reconstrucion using :class:`MatrixWaverec3` ---------------------------------------------- .. autoclass:: MatrixWaverec3 - :members: - :special-members: __call__ - :undoc-members: + :members: + :special-members: __call__ + :undoc-members: diff --git a/docs/ref/other.rst b/docs/ref/other.rst index f8a819e5..f7ce06c3 100644 --- a/docs/ref/other.rst +++ b/docs/ref/other.rst @@ -9,6 +9,6 @@ Version information ------------------- .. automodule:: ptwt.version - :members: - :undoc-members: - :show-inheritance: + :members: + :undoc-members: + :show-inheritance: diff --git a/docs/ref/packets.rst b/docs/ref/packets.rst index fcd577fb..45cd8200 100644 --- a/docs/ref/packets.rst +++ b/docs/ref/packets.rst @@ -3,23 +3,23 @@ .. currentmodule:: ptwt Wavelet Packet Transform (WPT) -==================================== +============================== Packets in 1d using :class:`WaveletPacket` ------------------------------------------ .. autoclass:: WaveletPacket - :members: - :special-members: __getitem__ - :undoc-members: + :members: + :special-members: __getitem__ + :undoc-members: Packets in 2d using :class:`WaveletPacket2D` -------------------------------------------- .. autoclass:: WaveletPacket2D - :members: - :special-members: __getitem__ - :undoc-members: + :members: + :special-members: __getitem__ + :undoc-members: Node ordering ------------- diff --git a/docs/ref/return-types.rst b/docs/ref/return-types.rst index 04b51586..69f85156 100644 --- a/docs/ref/return-types.rst +++ b/docs/ref/return-types.rst @@ -10,7 +10,6 @@ Transforms in one dimension .. autoclass:: WaveletCoeff1d - Transforms in two dimensions ---------------------------- @@ -22,7 +21,6 @@ Transforms in two dimensions :show-inheritance: :member-order: bysource - Transforms in N dimensions -------------------------- diff --git a/docs/ref/sparse-math.rst b/docs/ref/sparse-math.rst index fba77432..a0fbac85 100644 --- a/docs/ref/sparse-math.rst +++ b/docs/ref/sparse-math.rst @@ -4,6 +4,6 @@ Sparse-matrix backend functions =============================== .. automodule:: ptwt.sparse_math - :members: - :undoc-members: - :show-inheritance: + :members: + :undoc-members: + :show-inheritance: diff --git a/docs/ref/wavelets-learnable.rst b/docs/ref/wavelets-learnable.rst index 373caad5..20ff8c79 100644 --- a/docs/ref/wavelets-learnable.rst +++ b/docs/ref/wavelets-learnable.rst @@ -1,9 +1,9 @@ .. _ref-wavelets-learnable: Learnable adaptive wavelets -------------------------------- +=========================== .. automodule:: ptwt.wavelets_learnable - :members: - :undoc-members: - :show-inheritance: + :members: + :undoc-members: + :show-inheritance: diff --git a/docs/release_notes.rst b/docs/release_notes.rst index 96dcf32a..d6cacbac 100644 --- a/docs/release_notes.rst +++ b/docs/release_notes.rst @@ -3,4 +3,5 @@ Release Notes ============= -We publish releases via GitHub. The notes are available at the `GitHub release page `_ . +We publish releases via GitHub. The notes are available at the `GitHub release page +`_ . diff --git a/examples/deepfake_analysis/packet_plot.py b/examples/deepfake_analysis/packet_plot.py index dcc6baf7..580ea1ef 100644 --- a/examples/deepfake_analysis/packet_plot.py +++ b/examples/deepfake_analysis/packet_plot.py @@ -1,5 +1,4 @@ import os -from itertools import product import matplotlib.pyplot as plt import numpy as np diff --git a/examples/network_compression/mnist_compression.py b/examples/network_compression/mnist_compression.py index e1ad52bd..ff0fa6c2 100644 --- a/examples/network_compression/mnist_compression.py +++ b/examples/network_compression/mnist_compression.py @@ -239,7 +239,7 @@ def main(): ), batch_size=args.batch_size, shuffle=True, - **kwargs + **kwargs, ) test_loader = torch.utils.data.DataLoader( datasets.MNIST( @@ -251,7 +251,7 @@ def main(): ), batch_size=args.test_batch_size, shuffle=True, - **kwargs + **kwargs, ) if args.compression == "Wavelet": diff --git a/examples/network_compression/wavelet_linear.py b/examples/network_compression/wavelet_linear.py index 43bdac79..3a2ebbdb 100644 --- a/examples/network_compression/wavelet_linear.py +++ b/examples/network_compression/wavelet_linear.py @@ -70,9 +70,9 @@ def wavelet_analysis(self, x): c_lst = wavedec(x.unsqueeze(1), self.wavelet, level=self.scales) shape_lst = [c_el.shape[-1] for c_el in c_lst] c_tensor = torch.cat([c for c in c_lst], -1) - assert ( - shape_lst == self.coefficient_len_lst[::-1] - ), "Wavelet shape assumptions false. This is a bug." + assert shape_lst == self.coefficient_len_lst[::-1], ( + "Wavelet shape assumptions false. This is a bug." + ) return c_tensor def wavelet_reconstruction(self, x): diff --git a/examples/speed_tests/timeitconv_1d.py b/examples/speed_tests/timeitconv_1d.py index 18b317df..2b0ef507 100644 --- a/examples/speed_tests/timeitconv_1d.py +++ b/examples/speed_tests/timeitconv_1d.py @@ -1,5 +1,4 @@ import time -from typing import NamedTuple import matplotlib.pyplot as plt import numpy as np diff --git a/examples/speed_tests/timeitconv_2d.py b/examples/speed_tests/timeitconv_2d.py index f0e6c5dd..1dfadd66 100644 --- a/examples/speed_tests/timeitconv_2d.py +++ b/examples/speed_tests/timeitconv_2d.py @@ -1,5 +1,4 @@ import time -from typing import NamedTuple import matplotlib.pyplot as plt import numpy as np diff --git a/examples/speed_tests/timeitconv_2d_separable.py b/examples/speed_tests/timeitconv_2d_separable.py index 98cdc557..1533a312 100644 --- a/examples/speed_tests/timeitconv_2d_separable.py +++ b/examples/speed_tests/timeitconv_2d_separable.py @@ -1,5 +1,4 @@ import time -from typing import NamedTuple import matplotlib.pyplot as plt import numpy as np diff --git a/examples/speed_tests/timeitconv_3d.py b/examples/speed_tests/timeitconv_3d.py index 8a933186..7444e99e 100644 --- a/examples/speed_tests/timeitconv_3d.py +++ b/examples/speed_tests/timeitconv_3d.py @@ -1,5 +1,4 @@ import time -from typing import NamedTuple import matplotlib.pyplot as plt import numpy as np diff --git a/examples/wavelet_packet_chirp_analysis/chirp_analysis.py b/examples/wavelet_packet_chirp_analysis/chirp_analysis.py index 24617edd..4770c194 100644 --- a/examples/wavelet_packet_chirp_analysis/chirp_analysis.py +++ b/examples/wavelet_packet_chirp_analysis/chirp_analysis.py @@ -1,10 +1,8 @@ import matplotlib.pyplot as plt import numpy as np import pywt -import scipy.signal import torch -# use from src.ptwt.packets if you cloned the repo instead of using pip. from ptwt import WaveletPacket fs = 1000 diff --git a/noxfile.py b/noxfile.py index 89877589..35c223cc 100644 --- a/noxfile.py +++ b/noxfile.py @@ -17,6 +17,13 @@ def run_test_fast(session): session.run("pytest", "-m", "not slow") +@nox.session(name="doctests") +def run_doctests(session): + """Run tests in docstrings.""" + session.install(".[tests]", "xdoctest", "pygments", "matplotlib") + session.run("xdoctest", "-m", "ptwt", "--quiet") + + @nox.session(name="lint") def lint(session): """Check code conventions.""" diff --git a/src/ptwt/__init__.py b/src/ptwt/__init__.py index b10b4d9c..b306c629 100644 --- a/src/ptwt/__init__.py +++ b/src/ptwt/__init__.py @@ -19,3 +19,34 @@ from .packets import WaveletPacket, WaveletPacket2D from .separable_conv_transform import fswavedec2, fswavedec3, fswaverec2, fswaverec3 from .stationary_transform import iswt, swt + +__all__ = [ + "Wavelet", + "WaveletDetailTuple2d", + "WaveletCoeff2d", + "WaveletCoeff2dSeparable", + "WaveletCoeffNd", + "WaveletDetailDict", + "WaveletTensorTuple", + "cwt", + "wavedec", + "waverec", + "wavedec2", + "waverec2", + "wavedec3", + "waverec3", + "MatrixWavedec", + "MatrixWaverec", + "MatrixWavedec2", + "MatrixWaverec2", + "MatrixWavedec3", + "MatrixWaverec3", + "WaveletPacket", + "WaveletPacket2D", + "fswavedec2", + "fswavedec3", + "fswaverec2", + "fswaverec3", + "iswt", + "swt", +] diff --git a/src/ptwt/constants.py b/src/ptwt/constants.py index c86c72cb..de668598 100644 --- a/src/ptwt/constants.py +++ b/src/ptwt/constants.py @@ -21,6 +21,7 @@ "WaveletCoeffNd", "WaveletDetailDict", "WaveletTensorTuple", + "WaveletCoeff1d", ] SUPPORTED_DTYPES = {torch.float32, torch.float64} diff --git a/src/ptwt/continuous_transform.py b/src/ptwt/continuous_transform.py index 8598905a..9bceb7c9 100644 --- a/src/ptwt/continuous_transform.py +++ b/src/ptwt/continuous_transform.py @@ -180,7 +180,8 @@ def _integrate_wavelet( for other wavelets, a tuple (int_psi_d, int_psi_r, x) is returned instead. Example: - >>> from pywt import Wavelet, _integrate_wavelet + >>> from pywt import Wavelet + >>> from ptwt.continuous_transform import _integrate_wavelet >>> wavelet1 = Wavelet('db2') >>> [int_psi, x] = _integrate_wavelet(wavelet1, precision=5) >>> wavelet2 = Wavelet('bior1.3') @@ -230,7 +231,8 @@ class _WaveletParameter(torch.nn.Parameter): class _DifferentiableContinuousWavelet( - torch.nn.Module, ContinuousWavelet # type: ignore + torch.nn.Module, + ContinuousWavelet, # type: ignore ): """A base class for learnable Continuous Wavelets.""" diff --git a/src/ptwt/packets.py b/src/ptwt/packets.py index da5c2268..6c4b1551 100644 --- a/src/ptwt/packets.py +++ b/src/ptwt/packets.py @@ -83,14 +83,14 @@ def __init__( Args: data (torch.Tensor, optional): The input time series to transform. - By default the last axis is transformed. + By default, the last axis is transformed. Use the `axis` argument to choose another dimension. If None, the object is initialized without performing a decomposition. wavelet (Wavelet or str): A pywt wavelet compatible object or the name of a pywt wavelet. Refer to the output from ``pywt.wavelist(kind='discrete')`` for possible choices. - mode: The desired mode to handle signal boundaries. Select either the + mode: The desired mode to handle signal boundaries. Select either the sparse-matrix backend (``boundary``) or a padding mode. See :data:`ptwt.constants.ExtendedBoundaryMode`. Defaults to ``reflect``. @@ -112,6 +112,7 @@ def __init__( is not supported. Example: + >>> import torch, pywt, ptwt >>> import numpy as np >>> import scipy.signal @@ -123,7 +124,6 @@ def __init__( >>> np_lst = [wp[node] for node in wp.get_level(5)] >>> viz = np.stack(np_lst).squeeze() >>> plt.imshow(np.abs(viz)) - >>> plt.show() """ self.wavelet = _as_wavelet(wavelet) self.mode = mode diff --git a/src/ptwt/separable_conv_transform.py b/src/ptwt/separable_conv_transform.py index f046ba87..60d1416a 100644 --- a/src/ptwt/separable_conv_transform.py +++ b/src/ptwt/separable_conv_transform.py @@ -31,7 +31,7 @@ ) from .conv_transform import wavedec, waverec -__all__ = ["fswavedec2", "fswavedec3"] +__all__ = ["fswavedec2", "fswavedec3", "fswaverec2", "fswaverec3"] def _separable_conv_dwtn_( diff --git a/src/ptwt/stationary_transform.py b/src/ptwt/stationary_transform.py index 40f1b372..9dabf9e1 100644 --- a/src/ptwt/stationary_transform.py +++ b/src/ptwt/stationary_transform.py @@ -17,6 +17,8 @@ ) from .constants import Wavelet, WaveletCoeff1d +__all__ = ["iswt", "swt"] + def _circular_pad(x: torch.Tensor, padding_dimensions: Sequence[int]) -> torch.Tensor: """Pad a tensor in circular mode, more than once if needed.""" @@ -58,19 +60,20 @@ def swt( ) -> list[torch.Tensor]: """Compute a multilevel 1d stationary wavelet transform. - This fuctions is equivalent to pywt's :func:`pywt.swt` - with `trim_approx=True` and `norm=False`. + This fuctions is equivalent to pywt's :func:`pywt.swt` with `trim_approx=True` and + `norm=False`. Args: data (torch.Tensor): The input time series to transform. By default the last axis is transformed. + wavelet (Wavelet or str): A pywt wavelet compatible object or - the name of a pywt wavelet. - Refer to the output from ``pywt.wavelist(kind='discrete')`` - for possible choices. + the name of a pywt wavelet. Refer to the output from + ``pywt.wavelist(kind='discrete')`` for possible choices. + level (int, optional): The maximum decomposition level. - If None, the level is computed based on the signal shape. - Defaults to None. + If None, the level is computed based on the signal shape. Defaults to None. + axis (int): Compute the transform over this axis of the `data` tensor. Defaults to -1. @@ -115,8 +118,10 @@ def iswt( Args: coeffs: The wavelet coefficient sequence produced by the forward transform :func:`swt`. See :data:`ptwt.constants.WaveletCoeff1d`. + wavelet (Wavelet or str): A pywt wavelet compatible object or the name of a pywt wavelet, as used in the forward transform. + axis (int): Compute the transform over this axis of the `data` tensor. Defaults to -1. diff --git a/src/ptwt/wavelets_learnable.py b/src/ptwt/wavelets_learnable.py index c28b541b..48c5fb1c 100644 --- a/src/ptwt/wavelets_learnable.py +++ b/src/ptwt/wavelets_learnable.py @@ -12,9 +12,8 @@ class WaveletFilter(ABC): """Interface for learnable wavelets. - Each wavelet has a filter bank loss function - and comes with functionality that tests the perfect - reconstruction and anti-aliasing conditions. + Each wavelet has a filter bank loss function and comes with functionality that tests + the perfect reconstruction and antialiasing conditions. """ @property @@ -44,13 +43,12 @@ def pf_alias_cancellation_loss( ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: """Return the product filter-alias cancellation loss. - See: Strang+Nguyen 105: $$F_0(z) = H_1(-z); F_1(z) = -H_0(-z)$$ - Alternating sign convention from 0 to N see Strang overview - on the back of the cover. + See: Strang+Nguyen 105: $$F_0(z) = H_1(-z); F_1(z) = -H_0(-z)$$ Alternating sign + convention from 0 to N see Strang overview on the back of the cover. Returns: - The numerical value of the alias cancellation loss, - as well as both loss components for analysis. + The numerical value of the alias cancellation loss, as well as both loss + components for analysis. """ dec_lo, dec_hi, rec_lo, rec_hi = self.filter_bank m1 = torch.tensor([-1], device=dec_lo.device, dtype=dec_lo.dtype) @@ -78,13 +76,12 @@ def alias_cancellation_loss( ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: """Return the alias cancellation loss. - Implementation of the ac-loss as described - on page 104 of Strang+Nguyen. + Implementation of the ac-loss as described on page 104 of Strang+Nguyen. $$F_0(z)H_0(-z) + F_1(z)H_1(-z) = 0$$ Returns: - The numerical value of the alias cancellation loss, - as well as both loss components for analysis. + The numerical value of the alias cancellation loss, as well as both loss + components for analysis. """ dec_lo, dec_hi, rec_lo, rec_hi = self.filter_bank m1 = torch.tensor([-1], device=dec_lo.device, dtype=dec_lo.dtype) @@ -120,8 +117,8 @@ def perfect_reconstruction_loss( """Return the perfect reconstruction loss. Returns: - The numerical value of the alias cancellation loss, - as well as both intermediate values for analysis. + The numerical value of the alias cancellation loss, as well as both + intermediate values for analysis. """ # Strang 107: Assuming alias cancellation holds: # P(z) = F(z)H(z) @@ -174,14 +171,14 @@ def __init__( dec_hi: torch.Tensor, rec_lo: torch.Tensor, rec_hi: torch.Tensor, - ): + ) -> None: """Create a Product filter object. Args: - dec_lo (torch.Tensor): Low pass analysis filter. - dec_hi (torch.Tensor): High pass analysis filter. - rec_lo (torch.Tensor): Low pass synthesis filter. - rec_hi (torch.Tensor): High pass synthesis filter. + dec_lo : Low pass analysis filter. + dec_hi : High pass analysis filter. + rec_lo : Low pass synthesis filter. + rec_hi : High pass synthesis filter. """ super().__init__() self.dec_lo = torch.nn.Parameter(dec_lo) @@ -223,29 +220,11 @@ def wavelet_loss(self) -> torch.Tensor: class SoftOrthogonalWavelet(ProductFilter, torch.nn.Module): """Orthogonal wavelets with a soft orthogonality constraint.""" - def __init__( - self, - dec_lo: torch.Tensor, - dec_hi: torch.Tensor, - rec_lo: torch.Tensor, - rec_hi: torch.Tensor, - ): - """Create a SoftOrthogonalWavelet object. - - Args: - dec_lo (torch.Tensor): Low pass analysis filter. - dec_hi (torch.Tensor): High pass analysis filter. - rec_lo (torch.Tensor): Low pass synthesis filter. - rec_hi (torch.Tensor): High pass synthesis filter. - """ - super().__init__(dec_lo, dec_hi, rec_lo, rec_hi) - def rec_lo_orthogonality_loss(self) -> torch.Tensor: """Return a Strang inspired soft orthogonality loss. - See Strang p. 148/149 or Harbo p. 80. - Since L is a convolution matrix, LL^T can be evaluated - trough convolution. + See Strang p. 148/149 or Harbo p. 80. Since L is a convolution matrix, LL^T can + be evaluated trough convolution. Returns: A tensor with the orthogonality constraint value. @@ -276,10 +255,9 @@ def rec_lo_orthogonality_loss(self) -> torch.Tensor: def filt_bank_orthogonality_loss(self) -> torch.Tensor: """Return a Jensen+Harbo inspired soft orthogonality loss. - On Page 79 of the Book Ripples in Mathematics - by Jensen la Cour-Harbo, the constraint - g0[k] = h0[-k] and g1[k] = h1[-k] for orthogonal filters - is presented. A measurement is implemented below. + On Page 79 of the Book Ripples in Mathematics by Jensen la Cour-Harbo, the + constraint g0[k] = h0[-k] and g1[k] = h1[-k] for orthogonal filters is + presented. A measurement is implemented below. Returns: A tensor with the orthogonality constraint value. diff --git a/tests/test_jit.py b/tests/test_jit.py index 419cbf53..7c9d4894 100644 --- a/tests/test_jit.py +++ b/tests/test_jit.py @@ -101,9 +101,7 @@ def test_conv_fwt_jit_2d() -> None: ) jit_ptcoeff = jit_wavedec2(data, wavelet) # unstack the lists. - jit_waverec = torch.jit.trace( - _to_jit_waverec_2, (jit_ptcoeff, wavelet) - ) # type: ignore + jit_waverec = torch.jit.trace(_to_jit_waverec_2, (jit_ptcoeff, wavelet)) # type: ignore rec = jit_waverec(jit_ptcoeff, wavelet) assert np.allclose(rec.squeeze(1).numpy(), data.numpy(), atol=1e-7) @@ -155,9 +153,7 @@ def test_conv_fwt_jit_3d() -> None: ) jit_ptcoeff = jit_wavedec3(data, wavelet) # unstack the lists. - jit_waverec = torch.jit.trace( - _to_jit_waverec_3, (jit_ptcoeff, wavelet) - ) # type: ignore + jit_waverec = torch.jit.trace(_to_jit_waverec_3, (jit_ptcoeff, wavelet)) # type: ignore rec = jit_waverec(jit_ptcoeff, wavelet) assert np.allclose(rec.squeeze(1).numpy(), data.numpy(), atol=1e-7) diff --git a/tests/test_sparse_math.py b/tests/test_sparse_math.py index bf26f750..bfdcbcda 100644 --- a/tests/test_sparse_math.py +++ b/tests/test_sparse_math.py @@ -22,8 +22,7 @@ def test_kron() -> None: """Test the implementation by evaluation. - The example is taken from - https://de.wikipedia.org/wiki/Kronecker-Produkt + The example is taken from https://de.wikipedia.org/wiki/Kronecker-Produkt """ a = torch.tensor([[1, 2], [3, 2], [5, 6]]).to_sparse() b = torch.tensor([[7, 8], [9, 0]]).to_sparse()