Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions README.rst
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ convolution. Consider the following example:
import torch
import numpy as np
import pywt
import ptwt # use "from src import ptwt" for a cloned the repo
import ptwt

# generate an input of even length.
data = np.array([0, 1, 2, 3, 4, 5, 6, 7, 7, 6, 5, 4, 3, 2, 1, 0])
Expand Down Expand Up @@ -133,7 +133,7 @@ Reconsidering the 1d case, try:

import torch
import pywt
import ptwt # use "from src import ptwt" for a cloned the repo
import ptwt

# generate an input of even length.
data = torch.arange(16, dtype=torch.float32)
Expand Down
21 changes: 10 additions & 11 deletions docs/citation.rst
Original file line number Diff line number Diff line change
Expand Up @@ -7,17 +7,16 @@ If you use this work in a scientific context, please cite the following paper:

.. code-block::

@article{JMLR:v25:23-0636,
author = {Moritz Wolter and Felix Blanke and Jochen Garcke and Charles Tapley Hoyt},
title = {ptwt - The PyTorch Wavelet Toolbox},
journal = {Journal of Machine Learning Research},
year = {2024},
volume = {25},
number = {80},
pages = {1--7},
url = {http://jmlr.org/papers/v25/23-0636.html}
}

@article{JMLR:v25:23-0636,
author = {Moritz Wolter and Felix Blanke and Jochen Garcke and Charles Tapley Hoyt},
title = {ptwt - The PyTorch Wavelet Toolbox},
journal = {Journal of Machine Learning Research},
year = {2024},
volume = {25},
number = {80},
pages = {1--7},
url = {http://jmlr.org/papers/v25/23-0636.html}
}

This work builds upon `PyWavelets <https://pywavelets.readthedocs.io/en/latest/>`_
please consider citing them as well.
3 changes: 2 additions & 1 deletion docs/examples.rst
Original file line number Diff line number Diff line change
Expand Up @@ -3,4 +3,5 @@
Wavelet transforms by example
=============================

Worked examples are available in the examples folder of the `GitHub repository <https://github.com/v0lta/PyTorch-Wavelet-Toolbox/tree/main/examples>`_ .
Worked examples are available in the examples folder of the `GitHub repository
<https://github.com/v0lta/PyTorch-Wavelet-Toolbox/tree/main/examples>`_ .
26 changes: 13 additions & 13 deletions docs/ref/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -4,17 +4,17 @@ The ptwt package -- API reference
=================================

.. toctree::
:maxdepth: 2
:maxdepth: 2

conv-fwt
conv-inverse-fwt
matrix-fwt
matrix-inverse-fwt
packets
stationary
cwt
return-types
boundary
wavelets-learnable
sparse-math
other
conv-fwt
conv-inverse-fwt
matrix-fwt
matrix-inverse-fwt
packets
stationary
cwt
return-types
boundary
wavelets-learnable
sparse-math
other
33 changes: 17 additions & 16 deletions docs/ref/matrix-fwt.rst
Original file line number Diff line number Diff line change
Expand Up @@ -9,34 +9,35 @@ Sparse-matrix based Fast Wavelet Transform (FWT)
---------------------------------------------

.. autoclass:: MatrixWavedec
:members:
:special-members: __call__
:undoc-members:
:show-inheritance:
:members:
:special-members: __call__
:undoc-members:
:show-inheritance:

2d decomposition using :class:`MatrixWavedec2`
----------------------------------------------

.. autoclass:: MatrixWavedec2
:members:
:special-members: __call__
:undoc-members:
:show-inheritance:
:members:
:special-members: __call__
:undoc-members:
:show-inheritance:

3d decomposition using :class:`MatrixWavedec3`
----------------------------------------------

.. autoclass:: MatrixWavedec3
:members:
:special-members: __call__
:undoc-members:
:show-inheritance:

:members:
:special-members: __call__
:undoc-members:
:show-inheritance:

Sparse-matrix FWT base class
----------------------------
All sparse-matrix decomposition classes extend :class:`ptwt.matmul_transform.BaseMatrixWaveDec`.

All sparse-matrix decomposition classes extend
:class:`ptwt.matmul_transform.BaseMatrixWaveDec`.

.. autoclass:: ptwt.matmul_transform.BaseMatrixWaveDec
:members:
:undoc-members:
:members:
:undoc-members:
18 changes: 9 additions & 9 deletions docs/ref/matrix-inverse-fwt.rst
Original file line number Diff line number Diff line change
Expand Up @@ -9,22 +9,22 @@ Sparse-matrix based Inverse Fast Wavelet Transform (iFWT)
---------------------------------------------

.. autoclass:: MatrixWaverec
:members:
:special-members: __call__
:undoc-members:
:members:
:special-members: __call__
:undoc-members:

2d reconstrucion using :class:`MatrixWaverec2`
----------------------------------------------

.. autoclass:: MatrixWaverec2
:members:
:special-members: __call__
:undoc-members:
:members:
:special-members: __call__
:undoc-members:

3d reconstrucion using :class:`MatrixWaverec3`
----------------------------------------------

.. autoclass:: MatrixWaverec3
:members:
:special-members: __call__
:undoc-members:
:members:
:special-members: __call__
:undoc-members:
6 changes: 3 additions & 3 deletions docs/ref/other.rst
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,6 @@ Version information
-------------------

.. automodule:: ptwt.version
:members:
:undoc-members:
:show-inheritance:
:members:
:undoc-members:
:show-inheritance:
14 changes: 7 additions & 7 deletions docs/ref/packets.rst
Original file line number Diff line number Diff line change
Expand Up @@ -3,23 +3,23 @@
.. currentmodule:: ptwt

Wavelet Packet Transform (WPT)
====================================
==============================

Packets in 1d using :class:`WaveletPacket`
------------------------------------------

.. autoclass:: WaveletPacket
:members:
:special-members: __getitem__
:undoc-members:
:members:
:special-members: __getitem__
:undoc-members:

Packets in 2d using :class:`WaveletPacket2D`
--------------------------------------------

.. autoclass:: WaveletPacket2D
:members:
:special-members: __getitem__
:undoc-members:
:members:
:special-members: __getitem__
:undoc-members:

Node ordering
-------------
Expand Down
2 changes: 0 additions & 2 deletions docs/ref/return-types.rst
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@ Transforms in one dimension

.. autoclass:: WaveletCoeff1d


Transforms in two dimensions
----------------------------

Expand All @@ -22,7 +21,6 @@ Transforms in two dimensions
:show-inheritance:
:member-order: bysource


Transforms in N dimensions
--------------------------

Expand Down
6 changes: 3 additions & 3 deletions docs/ref/sparse-math.rst
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,6 @@ Sparse-matrix backend functions
===============================

.. automodule:: ptwt.sparse_math
:members:
:undoc-members:
:show-inheritance:
:members:
:undoc-members:
:show-inheritance:
8 changes: 4 additions & 4 deletions docs/ref/wavelets-learnable.rst
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
.. _ref-wavelets-learnable:

Learnable adaptive wavelets
-------------------------------
===========================

.. automodule:: ptwt.wavelets_learnable
:members:
:undoc-members:
:show-inheritance:
:members:
:undoc-members:
:show-inheritance:
3 changes: 2 additions & 1 deletion docs/release_notes.rst
Original file line number Diff line number Diff line change
Expand Up @@ -3,4 +3,5 @@
Release Notes
=============

We publish releases via GitHub. The notes are available at the `GitHub release page <https://github.com/v0lta/PyTorch-Wavelet-Toolbox/releases>`_ .
We publish releases via GitHub. The notes are available at the `GitHub release page
<https://github.com/v0lta/PyTorch-Wavelet-Toolbox/releases>`_ .
1 change: 0 additions & 1 deletion examples/deepfake_analysis/packet_plot.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import os
from itertools import product

import matplotlib.pyplot as plt
import numpy as np
Expand Down
4 changes: 2 additions & 2 deletions examples/network_compression/mnist_compression.py
Original file line number Diff line number Diff line change
Expand Up @@ -239,7 +239,7 @@ def main():
),
batch_size=args.batch_size,
shuffle=True,
**kwargs
**kwargs,
)
test_loader = torch.utils.data.DataLoader(
datasets.MNIST(
Expand All @@ -251,7 +251,7 @@ def main():
),
batch_size=args.test_batch_size,
shuffle=True,
**kwargs
**kwargs,
)

if args.compression == "Wavelet":
Expand Down
6 changes: 3 additions & 3 deletions examples/network_compression/wavelet_linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,9 +70,9 @@ def wavelet_analysis(self, x):
c_lst = wavedec(x.unsqueeze(1), self.wavelet, level=self.scales)
shape_lst = [c_el.shape[-1] for c_el in c_lst]
c_tensor = torch.cat([c for c in c_lst], -1)
assert (
shape_lst == self.coefficient_len_lst[::-1]
), "Wavelet shape assumptions false. This is a bug."
assert shape_lst == self.coefficient_len_lst[::-1], (
"Wavelet shape assumptions false. This is a bug."
)
return c_tensor

def wavelet_reconstruction(self, x):
Expand Down
1 change: 0 additions & 1 deletion examples/speed_tests/timeitconv_1d.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import time
from typing import NamedTuple

import matplotlib.pyplot as plt
import numpy as np
Expand Down
1 change: 0 additions & 1 deletion examples/speed_tests/timeitconv_2d.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import time
from typing import NamedTuple

import matplotlib.pyplot as plt
import numpy as np
Expand Down
1 change: 0 additions & 1 deletion examples/speed_tests/timeitconv_2d_separable.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import time
from typing import NamedTuple

import matplotlib.pyplot as plt
import numpy as np
Expand Down
1 change: 0 additions & 1 deletion examples/speed_tests/timeitconv_3d.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import time
from typing import NamedTuple

import matplotlib.pyplot as plt
import numpy as np
Expand Down
2 changes: 0 additions & 2 deletions examples/wavelet_packet_chirp_analysis/chirp_analysis.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,8 @@
import matplotlib.pyplot as plt
import numpy as np
import pywt
import scipy.signal
import torch

# use from src.ptwt.packets if you cloned the repo instead of using pip.
from ptwt import WaveletPacket

fs = 1000
Expand Down
7 changes: 7 additions & 0 deletions noxfile.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,13 @@ def run_test_fast(session):
session.run("pytest", "-m", "not slow")


@nox.session(name="doctests")
def run_doctests(session):
"""Run tests in docstrings."""
session.install(".[tests]", "xdoctest", "pygments", "matplotlib")
session.run("xdoctest", "-m", "ptwt", "--quiet")


@nox.session(name="lint")
def lint(session):
"""Check code conventions."""
Expand Down
31 changes: 31 additions & 0 deletions src/ptwt/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,3 +19,34 @@
from .packets import WaveletPacket, WaveletPacket2D
from .separable_conv_transform import fswavedec2, fswavedec3, fswaverec2, fswaverec3
from .stationary_transform import iswt, swt

__all__ = [
"Wavelet",
"WaveletDetailTuple2d",
"WaveletCoeff2d",
"WaveletCoeff2dSeparable",
"WaveletCoeffNd",
"WaveletDetailDict",
"WaveletTensorTuple",
"cwt",
"wavedec",
"waverec",
"wavedec2",
"waverec2",
"wavedec3",
"waverec3",
"MatrixWavedec",
"MatrixWaverec",
"MatrixWavedec2",
"MatrixWaverec2",
"MatrixWavedec3",
"MatrixWaverec3",
"WaveletPacket",
"WaveletPacket2D",
"fswavedec2",
"fswavedec3",
"fswaverec2",
"fswaverec3",
"iswt",
"swt",
]
1 change: 1 addition & 0 deletions src/ptwt/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
"WaveletCoeffNd",
"WaveletDetailDict",
"WaveletTensorTuple",
"WaveletCoeff1d",
]

SUPPORTED_DTYPES = {torch.float32, torch.float64}
Expand Down
Loading