diff --git a/README.rst b/README.rst
index 9bf8f6aa..494a466e 100644
--- a/README.rst
+++ b/README.rst
@@ -71,7 +71,7 @@ convolution. Consider the following example:
import torch
import numpy as np
import pywt
- import ptwt # use "from src import ptwt" for a cloned the repo
+ import ptwt
# generate an input of even length.
data = np.array([0, 1, 2, 3, 4, 5, 6, 7, 7, 6, 5, 4, 3, 2, 1, 0])
@@ -133,7 +133,7 @@ Reconsidering the 1d case, try:
import torch
import pywt
- import ptwt # use "from src import ptwt" for a cloned the repo
+ import ptwt
# generate an input of even length.
data = torch.arange(16, dtype=torch.float32)
diff --git a/docs/citation.rst b/docs/citation.rst
index 65376238..2d2729fa 100644
--- a/docs/citation.rst
+++ b/docs/citation.rst
@@ -7,17 +7,16 @@ If you use this work in a scientific context, please cite the following paper:
.. code-block::
- @article{JMLR:v25:23-0636,
- author = {Moritz Wolter and Felix Blanke and Jochen Garcke and Charles Tapley Hoyt},
- title = {ptwt - The PyTorch Wavelet Toolbox},
- journal = {Journal of Machine Learning Research},
- year = {2024},
- volume = {25},
- number = {80},
- pages = {1--7},
- url = {http://jmlr.org/papers/v25/23-0636.html}
- }
-
+ @article{JMLR:v25:23-0636,
+ author = {Moritz Wolter and Felix Blanke and Jochen Garcke and Charles Tapley Hoyt},
+ title = {ptwt - The PyTorch Wavelet Toolbox},
+ journal = {Journal of Machine Learning Research},
+ year = {2024},
+ volume = {25},
+ number = {80},
+ pages = {1--7},
+ url = {http://jmlr.org/papers/v25/23-0636.html}
+ }
This work builds upon `PyWavelets `_
please consider citing them as well.
diff --git a/docs/examples.rst b/docs/examples.rst
index b2bbf423..57c10e42 100644
--- a/docs/examples.rst
+++ b/docs/examples.rst
@@ -3,4 +3,5 @@
Wavelet transforms by example
=============================
-Worked examples are available in the examples folder of the `GitHub repository `_ .
+Worked examples are available in the examples folder of the `GitHub repository
+`_ .
diff --git a/docs/ref/index.rst b/docs/ref/index.rst
index c15b8c5a..fe2245c7 100644
--- a/docs/ref/index.rst
+++ b/docs/ref/index.rst
@@ -4,17 +4,17 @@ The ptwt package -- API reference
=================================
.. toctree::
- :maxdepth: 2
+ :maxdepth: 2
- conv-fwt
- conv-inverse-fwt
- matrix-fwt
- matrix-inverse-fwt
- packets
- stationary
- cwt
- return-types
- boundary
- wavelets-learnable
- sparse-math
- other
+ conv-fwt
+ conv-inverse-fwt
+ matrix-fwt
+ matrix-inverse-fwt
+ packets
+ stationary
+ cwt
+ return-types
+ boundary
+ wavelets-learnable
+ sparse-math
+ other
diff --git a/docs/ref/matrix-fwt.rst b/docs/ref/matrix-fwt.rst
index 2869507d..f469fc79 100644
--- a/docs/ref/matrix-fwt.rst
+++ b/docs/ref/matrix-fwt.rst
@@ -9,34 +9,35 @@ Sparse-matrix based Fast Wavelet Transform (FWT)
---------------------------------------------
.. autoclass:: MatrixWavedec
- :members:
- :special-members: __call__
- :undoc-members:
- :show-inheritance:
+ :members:
+ :special-members: __call__
+ :undoc-members:
+ :show-inheritance:
2d decomposition using :class:`MatrixWavedec2`
----------------------------------------------
.. autoclass:: MatrixWavedec2
- :members:
- :special-members: __call__
- :undoc-members:
- :show-inheritance:
+ :members:
+ :special-members: __call__
+ :undoc-members:
+ :show-inheritance:
3d decomposition using :class:`MatrixWavedec3`
----------------------------------------------
.. autoclass:: MatrixWavedec3
- :members:
- :special-members: __call__
- :undoc-members:
- :show-inheritance:
-
+ :members:
+ :special-members: __call__
+ :undoc-members:
+ :show-inheritance:
Sparse-matrix FWT base class
----------------------------
-All sparse-matrix decomposition classes extend :class:`ptwt.matmul_transform.BaseMatrixWaveDec`.
+
+All sparse-matrix decomposition classes extend
+:class:`ptwt.matmul_transform.BaseMatrixWaveDec`.
.. autoclass:: ptwt.matmul_transform.BaseMatrixWaveDec
- :members:
- :undoc-members:
+ :members:
+ :undoc-members:
diff --git a/docs/ref/matrix-inverse-fwt.rst b/docs/ref/matrix-inverse-fwt.rst
index 6be49d5c..f8043885 100644
--- a/docs/ref/matrix-inverse-fwt.rst
+++ b/docs/ref/matrix-inverse-fwt.rst
@@ -9,22 +9,22 @@ Sparse-matrix based Inverse Fast Wavelet Transform (iFWT)
---------------------------------------------
.. autoclass:: MatrixWaverec
- :members:
- :special-members: __call__
- :undoc-members:
+ :members:
+ :special-members: __call__
+ :undoc-members:
2d reconstrucion using :class:`MatrixWaverec2`
----------------------------------------------
.. autoclass:: MatrixWaverec2
- :members:
- :special-members: __call__
- :undoc-members:
+ :members:
+ :special-members: __call__
+ :undoc-members:
3d reconstrucion using :class:`MatrixWaverec3`
----------------------------------------------
.. autoclass:: MatrixWaverec3
- :members:
- :special-members: __call__
- :undoc-members:
+ :members:
+ :special-members: __call__
+ :undoc-members:
diff --git a/docs/ref/other.rst b/docs/ref/other.rst
index f8a819e5..f7ce06c3 100644
--- a/docs/ref/other.rst
+++ b/docs/ref/other.rst
@@ -9,6 +9,6 @@ Version information
-------------------
.. automodule:: ptwt.version
- :members:
- :undoc-members:
- :show-inheritance:
+ :members:
+ :undoc-members:
+ :show-inheritance:
diff --git a/docs/ref/packets.rst b/docs/ref/packets.rst
index fcd577fb..45cd8200 100644
--- a/docs/ref/packets.rst
+++ b/docs/ref/packets.rst
@@ -3,23 +3,23 @@
.. currentmodule:: ptwt
Wavelet Packet Transform (WPT)
-====================================
+==============================
Packets in 1d using :class:`WaveletPacket`
------------------------------------------
.. autoclass:: WaveletPacket
- :members:
- :special-members: __getitem__
- :undoc-members:
+ :members:
+ :special-members: __getitem__
+ :undoc-members:
Packets in 2d using :class:`WaveletPacket2D`
--------------------------------------------
.. autoclass:: WaveletPacket2D
- :members:
- :special-members: __getitem__
- :undoc-members:
+ :members:
+ :special-members: __getitem__
+ :undoc-members:
Node ordering
-------------
diff --git a/docs/ref/return-types.rst b/docs/ref/return-types.rst
index 04b51586..69f85156 100644
--- a/docs/ref/return-types.rst
+++ b/docs/ref/return-types.rst
@@ -10,7 +10,6 @@ Transforms in one dimension
.. autoclass:: WaveletCoeff1d
-
Transforms in two dimensions
----------------------------
@@ -22,7 +21,6 @@ Transforms in two dimensions
:show-inheritance:
:member-order: bysource
-
Transforms in N dimensions
--------------------------
diff --git a/docs/ref/sparse-math.rst b/docs/ref/sparse-math.rst
index fba77432..a0fbac85 100644
--- a/docs/ref/sparse-math.rst
+++ b/docs/ref/sparse-math.rst
@@ -4,6 +4,6 @@ Sparse-matrix backend functions
===============================
.. automodule:: ptwt.sparse_math
- :members:
- :undoc-members:
- :show-inheritance:
+ :members:
+ :undoc-members:
+ :show-inheritance:
diff --git a/docs/ref/wavelets-learnable.rst b/docs/ref/wavelets-learnable.rst
index 373caad5..20ff8c79 100644
--- a/docs/ref/wavelets-learnable.rst
+++ b/docs/ref/wavelets-learnable.rst
@@ -1,9 +1,9 @@
.. _ref-wavelets-learnable:
Learnable adaptive wavelets
--------------------------------
+===========================
.. automodule:: ptwt.wavelets_learnable
- :members:
- :undoc-members:
- :show-inheritance:
+ :members:
+ :undoc-members:
+ :show-inheritance:
diff --git a/docs/release_notes.rst b/docs/release_notes.rst
index 96dcf32a..d6cacbac 100644
--- a/docs/release_notes.rst
+++ b/docs/release_notes.rst
@@ -3,4 +3,5 @@
Release Notes
=============
-We publish releases via GitHub. The notes are available at the `GitHub release page `_ .
+We publish releases via GitHub. The notes are available at the `GitHub release page
+`_ .
diff --git a/examples/deepfake_analysis/packet_plot.py b/examples/deepfake_analysis/packet_plot.py
index dcc6baf7..580ea1ef 100644
--- a/examples/deepfake_analysis/packet_plot.py
+++ b/examples/deepfake_analysis/packet_plot.py
@@ -1,5 +1,4 @@
import os
-from itertools import product
import matplotlib.pyplot as plt
import numpy as np
diff --git a/examples/network_compression/mnist_compression.py b/examples/network_compression/mnist_compression.py
index e1ad52bd..ff0fa6c2 100644
--- a/examples/network_compression/mnist_compression.py
+++ b/examples/network_compression/mnist_compression.py
@@ -239,7 +239,7 @@ def main():
),
batch_size=args.batch_size,
shuffle=True,
- **kwargs
+ **kwargs,
)
test_loader = torch.utils.data.DataLoader(
datasets.MNIST(
@@ -251,7 +251,7 @@ def main():
),
batch_size=args.test_batch_size,
shuffle=True,
- **kwargs
+ **kwargs,
)
if args.compression == "Wavelet":
diff --git a/examples/network_compression/wavelet_linear.py b/examples/network_compression/wavelet_linear.py
index 43bdac79..3a2ebbdb 100644
--- a/examples/network_compression/wavelet_linear.py
+++ b/examples/network_compression/wavelet_linear.py
@@ -70,9 +70,9 @@ def wavelet_analysis(self, x):
c_lst = wavedec(x.unsqueeze(1), self.wavelet, level=self.scales)
shape_lst = [c_el.shape[-1] for c_el in c_lst]
c_tensor = torch.cat([c for c in c_lst], -1)
- assert (
- shape_lst == self.coefficient_len_lst[::-1]
- ), "Wavelet shape assumptions false. This is a bug."
+ assert shape_lst == self.coefficient_len_lst[::-1], (
+ "Wavelet shape assumptions false. This is a bug."
+ )
return c_tensor
def wavelet_reconstruction(self, x):
diff --git a/examples/speed_tests/timeitconv_1d.py b/examples/speed_tests/timeitconv_1d.py
index 18b317df..2b0ef507 100644
--- a/examples/speed_tests/timeitconv_1d.py
+++ b/examples/speed_tests/timeitconv_1d.py
@@ -1,5 +1,4 @@
import time
-from typing import NamedTuple
import matplotlib.pyplot as plt
import numpy as np
diff --git a/examples/speed_tests/timeitconv_2d.py b/examples/speed_tests/timeitconv_2d.py
index f0e6c5dd..1dfadd66 100644
--- a/examples/speed_tests/timeitconv_2d.py
+++ b/examples/speed_tests/timeitconv_2d.py
@@ -1,5 +1,4 @@
import time
-from typing import NamedTuple
import matplotlib.pyplot as plt
import numpy as np
diff --git a/examples/speed_tests/timeitconv_2d_separable.py b/examples/speed_tests/timeitconv_2d_separable.py
index 98cdc557..1533a312 100644
--- a/examples/speed_tests/timeitconv_2d_separable.py
+++ b/examples/speed_tests/timeitconv_2d_separable.py
@@ -1,5 +1,4 @@
import time
-from typing import NamedTuple
import matplotlib.pyplot as plt
import numpy as np
diff --git a/examples/speed_tests/timeitconv_3d.py b/examples/speed_tests/timeitconv_3d.py
index 8a933186..7444e99e 100644
--- a/examples/speed_tests/timeitconv_3d.py
+++ b/examples/speed_tests/timeitconv_3d.py
@@ -1,5 +1,4 @@
import time
-from typing import NamedTuple
import matplotlib.pyplot as plt
import numpy as np
diff --git a/examples/wavelet_packet_chirp_analysis/chirp_analysis.py b/examples/wavelet_packet_chirp_analysis/chirp_analysis.py
index 24617edd..4770c194 100644
--- a/examples/wavelet_packet_chirp_analysis/chirp_analysis.py
+++ b/examples/wavelet_packet_chirp_analysis/chirp_analysis.py
@@ -1,10 +1,8 @@
import matplotlib.pyplot as plt
import numpy as np
import pywt
-import scipy.signal
import torch
-# use from src.ptwt.packets if you cloned the repo instead of using pip.
from ptwt import WaveletPacket
fs = 1000
diff --git a/noxfile.py b/noxfile.py
index 89877589..35c223cc 100644
--- a/noxfile.py
+++ b/noxfile.py
@@ -17,6 +17,13 @@ def run_test_fast(session):
session.run("pytest", "-m", "not slow")
+@nox.session(name="doctests")
+def run_doctests(session):
+ """Run tests in docstrings."""
+ session.install(".[tests]", "xdoctest", "pygments", "matplotlib")
+ session.run("xdoctest", "-m", "ptwt", "--quiet")
+
+
@nox.session(name="lint")
def lint(session):
"""Check code conventions."""
diff --git a/src/ptwt/__init__.py b/src/ptwt/__init__.py
index b10b4d9c..b306c629 100644
--- a/src/ptwt/__init__.py
+++ b/src/ptwt/__init__.py
@@ -19,3 +19,34 @@
from .packets import WaveletPacket, WaveletPacket2D
from .separable_conv_transform import fswavedec2, fswavedec3, fswaverec2, fswaverec3
from .stationary_transform import iswt, swt
+
+__all__ = [
+ "Wavelet",
+ "WaveletDetailTuple2d",
+ "WaveletCoeff2d",
+ "WaveletCoeff2dSeparable",
+ "WaveletCoeffNd",
+ "WaveletDetailDict",
+ "WaveletTensorTuple",
+ "cwt",
+ "wavedec",
+ "waverec",
+ "wavedec2",
+ "waverec2",
+ "wavedec3",
+ "waverec3",
+ "MatrixWavedec",
+ "MatrixWaverec",
+ "MatrixWavedec2",
+ "MatrixWaverec2",
+ "MatrixWavedec3",
+ "MatrixWaverec3",
+ "WaveletPacket",
+ "WaveletPacket2D",
+ "fswavedec2",
+ "fswavedec3",
+ "fswaverec2",
+ "fswaverec3",
+ "iswt",
+ "swt",
+]
diff --git a/src/ptwt/constants.py b/src/ptwt/constants.py
index c86c72cb..de668598 100644
--- a/src/ptwt/constants.py
+++ b/src/ptwt/constants.py
@@ -21,6 +21,7 @@
"WaveletCoeffNd",
"WaveletDetailDict",
"WaveletTensorTuple",
+ "WaveletCoeff1d",
]
SUPPORTED_DTYPES = {torch.float32, torch.float64}
diff --git a/src/ptwt/continuous_transform.py b/src/ptwt/continuous_transform.py
index 8598905a..9bceb7c9 100644
--- a/src/ptwt/continuous_transform.py
+++ b/src/ptwt/continuous_transform.py
@@ -180,7 +180,8 @@ def _integrate_wavelet(
for other wavelets, a tuple (int_psi_d, int_psi_r, x) is returned instead.
Example:
- >>> from pywt import Wavelet, _integrate_wavelet
+ >>> from pywt import Wavelet
+ >>> from ptwt.continuous_transform import _integrate_wavelet
>>> wavelet1 = Wavelet('db2')
>>> [int_psi, x] = _integrate_wavelet(wavelet1, precision=5)
>>> wavelet2 = Wavelet('bior1.3')
@@ -230,7 +231,8 @@ class _WaveletParameter(torch.nn.Parameter):
class _DifferentiableContinuousWavelet(
- torch.nn.Module, ContinuousWavelet # type: ignore
+ torch.nn.Module,
+ ContinuousWavelet, # type: ignore
):
"""A base class for learnable Continuous Wavelets."""
diff --git a/src/ptwt/packets.py b/src/ptwt/packets.py
index da5c2268..6c4b1551 100644
--- a/src/ptwt/packets.py
+++ b/src/ptwt/packets.py
@@ -83,14 +83,14 @@ def __init__(
Args:
data (torch.Tensor, optional): The input time series to transform.
- By default the last axis is transformed.
+ By default, the last axis is transformed.
Use the `axis` argument to choose another dimension.
If None, the object is initialized without performing a decomposition.
wavelet (Wavelet or str): A pywt wavelet compatible object or
the name of a pywt wavelet.
Refer to the output from ``pywt.wavelist(kind='discrete')``
for possible choices.
- mode: The desired mode to handle signal boundaries. Select either the
+ mode: The desired mode to handle signal boundaries. Select either
the sparse-matrix backend (``boundary``) or a padding mode.
See :data:`ptwt.constants.ExtendedBoundaryMode`.
Defaults to ``reflect``.
@@ -112,6 +112,7 @@ def __init__(
is not supported.
Example:
+
>>> import torch, pywt, ptwt
>>> import numpy as np
>>> import scipy.signal
@@ -123,7 +124,6 @@ def __init__(
>>> np_lst = [wp[node] for node in wp.get_level(5)]
>>> viz = np.stack(np_lst).squeeze()
>>> plt.imshow(np.abs(viz))
- >>> plt.show()
"""
self.wavelet = _as_wavelet(wavelet)
self.mode = mode
diff --git a/src/ptwt/separable_conv_transform.py b/src/ptwt/separable_conv_transform.py
index f046ba87..60d1416a 100644
--- a/src/ptwt/separable_conv_transform.py
+++ b/src/ptwt/separable_conv_transform.py
@@ -31,7 +31,7 @@
)
from .conv_transform import wavedec, waverec
-__all__ = ["fswavedec2", "fswavedec3"]
+__all__ = ["fswavedec2", "fswavedec3", "fswaverec2", "fswaverec3"]
def _separable_conv_dwtn_(
diff --git a/src/ptwt/stationary_transform.py b/src/ptwt/stationary_transform.py
index 40f1b372..9dabf9e1 100644
--- a/src/ptwt/stationary_transform.py
+++ b/src/ptwt/stationary_transform.py
@@ -17,6 +17,8 @@
)
from .constants import Wavelet, WaveletCoeff1d
+__all__ = ["iswt", "swt"]
+
def _circular_pad(x: torch.Tensor, padding_dimensions: Sequence[int]) -> torch.Tensor:
"""Pad a tensor in circular mode, more than once if needed."""
@@ -58,19 +60,20 @@ def swt(
) -> list[torch.Tensor]:
"""Compute a multilevel 1d stationary wavelet transform.
- This fuctions is equivalent to pywt's :func:`pywt.swt`
- with `trim_approx=True` and `norm=False`.
+ This function is equivalent to pywt's :func:`pywt.swt` with `trim_approx=True` and
+ `norm=False`.
Args:
data (torch.Tensor): The input time series to transform.
By default the last axis is transformed.
+
wavelet (Wavelet or str): A pywt wavelet compatible object or
- the name of a pywt wavelet.
- Refer to the output from ``pywt.wavelist(kind='discrete')``
- for possible choices.
+ the name of a pywt wavelet. Refer to the output from
+ ``pywt.wavelist(kind='discrete')`` for possible choices.
+
level (int, optional): The maximum decomposition level.
- If None, the level is computed based on the signal shape.
- Defaults to None.
+ If None, the level is computed based on the signal shape. Defaults to None.
+
axis (int): Compute the transform over this axis of the `data` tensor.
Defaults to -1.
@@ -115,8 +118,10 @@ def iswt(
Args:
coeffs: The wavelet coefficient sequence produced by the forward transform
:func:`swt`. See :data:`ptwt.constants.WaveletCoeff1d`.
+
wavelet (Wavelet or str): A pywt wavelet compatible object or
the name of a pywt wavelet, as used in the forward transform.
+
axis (int): Compute the transform over this axis of the `data` tensor.
Defaults to -1.
diff --git a/src/ptwt/wavelets_learnable.py b/src/ptwt/wavelets_learnable.py
index c28b541b..48c5fb1c 100644
--- a/src/ptwt/wavelets_learnable.py
+++ b/src/ptwt/wavelets_learnable.py
@@ -12,9 +12,8 @@
class WaveletFilter(ABC):
"""Interface for learnable wavelets.
- Each wavelet has a filter bank loss function
- and comes with functionality that tests the perfect
- reconstruction and anti-aliasing conditions.
+ Each wavelet has a filter bank loss function and comes with functionality that tests
+ the perfect reconstruction and antialiasing conditions.
"""
@property
@@ -44,13 +43,12 @@ def pf_alias_cancellation_loss(
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""Return the product filter-alias cancellation loss.
- See: Strang+Nguyen 105: $$F_0(z) = H_1(-z); F_1(z) = -H_0(-z)$$
- Alternating sign convention from 0 to N see Strang overview
- on the back of the cover.
+ See Strang+Nguyen, p. 105: $$F_0(z) = H_1(-z); F_1(z) = -H_0(-z)$$. For the alternating
+ sign convention from 0 to N, see the Strang overview on the back of the cover.
Returns:
- The numerical value of the alias cancellation loss,
- as well as both loss components for analysis.
+ The numerical value of the alias cancellation loss, as well as both loss
+ components for analysis.
"""
dec_lo, dec_hi, rec_lo, rec_hi = self.filter_bank
m1 = torch.tensor([-1], device=dec_lo.device, dtype=dec_lo.dtype)
@@ -78,13 +76,12 @@ def alias_cancellation_loss(
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""Return the alias cancellation loss.
- Implementation of the ac-loss as described
- on page 104 of Strang+Nguyen.
+ Implementation of the ac-loss as described on page 104 of Strang+Nguyen.
$$F_0(z)H_0(-z) + F_1(z)H_1(-z) = 0$$
Returns:
- The numerical value of the alias cancellation loss,
- as well as both loss components for analysis.
+ The numerical value of the alias cancellation loss, as well as both loss
+ components for analysis.
"""
dec_lo, dec_hi, rec_lo, rec_hi = self.filter_bank
m1 = torch.tensor([-1], device=dec_lo.device, dtype=dec_lo.dtype)
@@ -120,8 +117,8 @@ def perfect_reconstruction_loss(
"""Return the perfect reconstruction loss.
Returns:
- The numerical value of the alias cancellation loss,
- as well as both intermediate values for analysis.
+ The numerical value of the perfect reconstruction loss, as well as both
+ intermediate values for analysis.
"""
# Strang 107: Assuming alias cancellation holds:
# P(z) = F(z)H(z)
@@ -174,14 +171,14 @@ def __init__(
dec_hi: torch.Tensor,
rec_lo: torch.Tensor,
rec_hi: torch.Tensor,
- ):
+ ) -> None:
"""Create a Product filter object.
Args:
- dec_lo (torch.Tensor): Low pass analysis filter.
- dec_hi (torch.Tensor): High pass analysis filter.
- rec_lo (torch.Tensor): Low pass synthesis filter.
- rec_hi (torch.Tensor): High pass synthesis filter.
+ dec_lo: Low pass analysis filter.
+ dec_hi: High pass analysis filter.
+ rec_lo: Low pass synthesis filter.
+ rec_hi: High pass synthesis filter.
"""
super().__init__()
self.dec_lo = torch.nn.Parameter(dec_lo)
@@ -223,29 +220,11 @@ def wavelet_loss(self) -> torch.Tensor:
class SoftOrthogonalWavelet(ProductFilter, torch.nn.Module):
"""Orthogonal wavelets with a soft orthogonality constraint."""
- def __init__(
- self,
- dec_lo: torch.Tensor,
- dec_hi: torch.Tensor,
- rec_lo: torch.Tensor,
- rec_hi: torch.Tensor,
- ):
- """Create a SoftOrthogonalWavelet object.
-
- Args:
- dec_lo (torch.Tensor): Low pass analysis filter.
- dec_hi (torch.Tensor): High pass analysis filter.
- rec_lo (torch.Tensor): Low pass synthesis filter.
- rec_hi (torch.Tensor): High pass synthesis filter.
- """
- super().__init__(dec_lo, dec_hi, rec_lo, rec_hi)
-
def rec_lo_orthogonality_loss(self) -> torch.Tensor:
"""Return a Strang inspired soft orthogonality loss.
- See Strang p. 148/149 or Harbo p. 80.
- Since L is a convolution matrix, LL^T can be evaluated
- trough convolution.
+ See Strang p. 148/149 or Harbo p. 80. Since L is a convolution matrix, LL^T can
+ be evaluated through convolution.
Returns:
A tensor with the orthogonality constraint value.
@@ -276,10 +255,9 @@ def rec_lo_orthogonality_loss(self) -> torch.Tensor:
def filt_bank_orthogonality_loss(self) -> torch.Tensor:
"""Return a Jensen+Harbo inspired soft orthogonality loss.
- On Page 79 of the Book Ripples in Mathematics
- by Jensen la Cour-Harbo, the constraint
- g0[k] = h0[-k] and g1[k] = h1[-k] for orthogonal filters
- is presented. A measurement is implemented below.
+ On page 79 of the book Ripples in Mathematics by Jensen and la Cour-Harbo, the
+ constraint g0[k] = h0[-k] and g1[k] = h1[-k] for orthogonal filters is
+ presented. A measurement is implemented below.
Returns:
A tensor with the orthogonality constraint value.
diff --git a/tests/test_jit.py b/tests/test_jit.py
index 419cbf53..7c9d4894 100644
--- a/tests/test_jit.py
+++ b/tests/test_jit.py
@@ -101,9 +101,7 @@ def test_conv_fwt_jit_2d() -> None:
)
jit_ptcoeff = jit_wavedec2(data, wavelet)
# unstack the lists.
- jit_waverec = torch.jit.trace(
- _to_jit_waverec_2, (jit_ptcoeff, wavelet)
- ) # type: ignore
+ jit_waverec = torch.jit.trace(_to_jit_waverec_2, (jit_ptcoeff, wavelet)) # type: ignore
rec = jit_waverec(jit_ptcoeff, wavelet)
assert np.allclose(rec.squeeze(1).numpy(), data.numpy(), atol=1e-7)
@@ -155,9 +153,7 @@ def test_conv_fwt_jit_3d() -> None:
)
jit_ptcoeff = jit_wavedec3(data, wavelet)
# unstack the lists.
- jit_waverec = torch.jit.trace(
- _to_jit_waverec_3, (jit_ptcoeff, wavelet)
- ) # type: ignore
+ jit_waverec = torch.jit.trace(_to_jit_waverec_3, (jit_ptcoeff, wavelet)) # type: ignore
rec = jit_waverec(jit_ptcoeff, wavelet)
assert np.allclose(rec.squeeze(1).numpy(), data.numpy(), atol=1e-7)
diff --git a/tests/test_sparse_math.py b/tests/test_sparse_math.py
index bf26f750..bfdcbcda 100644
--- a/tests/test_sparse_math.py
+++ b/tests/test_sparse_math.py
@@ -22,8 +22,7 @@
def test_kron() -> None:
"""Test the implementation by evaluation.
- The example is taken from
- https://de.wikipedia.org/wiki/Kronecker-Produkt
+ The example is taken from https://de.wikipedia.org/wiki/Kronecker-Produkt
"""
a = torch.tensor([[1, 2], [3, 2], [5, 6]]).to_sparse()
b = torch.tensor([[7, 8], [9, 0]]).to_sparse()