|
12 | 12 | class WaveletFilter(ABC): |
13 | 13 | """Interface for learnable wavelets. |
14 | 14 |
|
15 | | - Each wavelet has a filter bank loss function |
16 | | - and comes with functionality that tests the perfect |
17 | | - reconstruction and anti-aliasing conditions. |
| 15 | + Each wavelet has a filter bank loss function and comes with functionality that tests |
| 16 | + the perfect reconstruction and antialiasing conditions. |
18 | 17 | """ |
19 | 18 |
|
20 | 19 | @property |
@@ -44,13 +43,12 @@ def pf_alias_cancellation_loss( |
44 | 43 | ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: |
45 | 44 | """Return the product filter-alias cancellation loss. |
46 | 45 |
|
47 | | - See: Strang+Nguyen 105: $$F_0(z) = H_1(-z); F_1(z) = -H_0(-z)$$ |
48 | | - Alternating sign convention from 0 to N see Strang overview |
49 | | - on the back of the cover. |
| 46 | + See: Strang+Nguyen 105: $$F_0(z) = H_1(-z); F_1(z) = -H_0(-z)$$ Alternating sign |
| 47 | + convention from 0 to N see Strang overview on the back of the cover. |
50 | 48 |
|
51 | 49 | Returns: |
52 | | - The numerical value of the alias cancellation loss, |
53 | | - as well as both loss components for analysis. |
| 50 | + The numerical value of the alias cancellation loss, as well as both loss |
| 51 | + components for analysis. |
54 | 52 | """ |
55 | 53 | dec_lo, dec_hi, rec_lo, rec_hi = self.filter_bank |
56 | 54 | m1 = torch.tensor([-1], device=dec_lo.device, dtype=dec_lo.dtype) |
@@ -78,13 +76,12 @@ def alias_cancellation_loss( |
78 | 76 | ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: |
79 | 77 | """Return the alias cancellation loss. |
80 | 78 |
|
81 | | - Implementation of the ac-loss as described |
82 | | - on page 104 of Strang+Nguyen. |
| 79 | + Implementation of the ac-loss as described on page 104 of Strang+Nguyen. |
83 | 80 | $$F_0(z)H_0(-z) + F_1(z)H_1(-z) = 0$$ |
84 | 81 |
|
85 | 82 | Returns: |
86 | | - The numerical value of the alias cancellation loss, |
87 | | - as well as both loss components for analysis. |
| 83 | + The numerical value of the alias cancellation loss, as well as both loss |
| 84 | + components for analysis. |
88 | 85 | """ |
89 | 86 | dec_lo, dec_hi, rec_lo, rec_hi = self.filter_bank |
90 | 87 | m1 = torch.tensor([-1], device=dec_lo.device, dtype=dec_lo.dtype) |
@@ -120,8 +117,8 @@ def perfect_reconstruction_loss( |
120 | 117 | """Return the perfect reconstruction loss. |
121 | 118 |
|
122 | 119 | Returns: |
123 | | - The numerical value of the alias cancellation loss, |
124 | | - as well as both intermediate values for analysis. |
| 120 | + The numerical value of the alias cancellation loss, as well as both |
| 121 | + intermediate values for analysis. |
125 | 122 | """ |
126 | 123 | # Strang 107: Assuming alias cancellation holds: |
127 | 124 | # P(z) = F(z)H(z) |
@@ -174,14 +171,14 @@ def __init__( |
174 | 171 | dec_hi: torch.Tensor, |
175 | 172 | rec_lo: torch.Tensor, |
176 | 173 | rec_hi: torch.Tensor, |
177 | | - ): |
| 174 | + ) -> None: |
178 | 175 | """Create a Product filter object. |
179 | 176 |
|
180 | 177 | Args: |
181 | | - dec_lo (torch.Tensor): Low pass analysis filter. |
182 | | - dec_hi (torch.Tensor): High pass analysis filter. |
183 | | - rec_lo (torch.Tensor): Low pass synthesis filter. |
184 | | - rec_hi (torch.Tensor): High pass synthesis filter. |
| 178 | + dec_lo : Low pass analysis filter. |
| 179 | + dec_hi : High pass analysis filter. |
| 180 | + rec_lo : Low pass synthesis filter. |
| 181 | + rec_hi : High pass synthesis filter. |
185 | 182 | """ |
186 | 183 | super().__init__() |
187 | 184 | self.dec_lo = torch.nn.Parameter(dec_lo) |
@@ -223,29 +220,11 @@ def wavelet_loss(self) -> torch.Tensor: |
223 | 220 | class SoftOrthogonalWavelet(ProductFilter, torch.nn.Module): |
224 | 221 | """Orthogonal wavelets with a soft orthogonality constraint.""" |
225 | 222 |
|
226 | | - def __init__( |
227 | | - self, |
228 | | - dec_lo: torch.Tensor, |
229 | | - dec_hi: torch.Tensor, |
230 | | - rec_lo: torch.Tensor, |
231 | | - rec_hi: torch.Tensor, |
232 | | - ): |
233 | | - """Create a SoftOrthogonalWavelet object. |
234 | | -
|
235 | | - Args: |
236 | | - dec_lo (torch.Tensor): Low pass analysis filter. |
237 | | - dec_hi (torch.Tensor): High pass analysis filter. |
238 | | - rec_lo (torch.Tensor): Low pass synthesis filter. |
239 | | - rec_hi (torch.Tensor): High pass synthesis filter. |
240 | | - """ |
241 | | - super().__init__(dec_lo, dec_hi, rec_lo, rec_hi) |
242 | | - |
243 | 223 | def rec_lo_orthogonality_loss(self) -> torch.Tensor: |
244 | 224 | """Return a Strang inspired soft orthogonality loss. |
245 | 225 |
|
246 | | - See Strang p. 148/149 or Harbo p. 80. |
247 | | - Since L is a convolution matrix, LL^T can be evaluated |
248 | | - trough convolution. |
| 226 | + See Strang p. 148/149 or Harbo p. 80. Since L is a convolution matrix, LL^T can |
| 227 | + be evaluated trough convolution. |
249 | 228 |
|
250 | 229 | Returns: |
251 | 230 | A tensor with the orthogonality constraint value. |
@@ -276,10 +255,9 @@ def rec_lo_orthogonality_loss(self) -> torch.Tensor: |
276 | 255 | def filt_bank_orthogonality_loss(self) -> torch.Tensor: |
277 | 256 | """Return a Jensen+Harbo inspired soft orthogonality loss. |
278 | 257 |
|
279 | | - On Page 79 of the Book Ripples in Mathematics |
280 | | - by Jensen la Cour-Harbo, the constraint |
281 | | - g0[k] = h0[-k] and g1[k] = h1[-k] for orthogonal filters |
282 | | - is presented. A measurement is implemented below. |
| 258 | + On Page 79 of the Book Ripples in Mathematics by Jensen la Cour-Harbo, the |
| 259 | + constraint g0[k] = h0[-k] and g1[k] = h1[-k] for orthogonal filters is |
| 260 | + presented. A measurement is implemented below. |
283 | 261 |
|
284 | 262 | Returns: |
285 | 263 | A tensor with the orthogonality constraint value. |
|
0 commit comments