|
1 |
| -# Copyright (C) 2023-2024 Intel Corporation |
| 1 | +# Copyright (C) 2023-2025 Intel Corporation |
2 | 2 | # SPDX-License-Identifier: Apache-2.0
|
3 | 3 |
|
4 | 4 | """DSR - A Dual Subspace Re-Projection Network for Surface Anomaly Detection.
|
|
39 | 39 |
|
40 | 40 | import torch
|
41 | 41 | from lightning.pytorch.utilities.types import STEP_OUTPUT, OptimizerLRScheduler
|
42 |
| -from torchvision.transforms.v2 import Compose, Resize |
| 42 | +from torchvision.transforms.v2 import Compose, Normalize, Resize |
43 | 43 |
|
44 | 44 | from anomalib import LearningType
|
45 | 45 | from anomalib.data import Batch
|
| 46 | +from anomalib.data.transforms.utils import extract_transforms_by_type |
46 | 47 | from anomalib.data.utils import DownloadInfo, download_and_extract
|
47 | 48 | from anomalib.data.utils.generators.perlin import PerlinAnomalyGenerator
|
48 | 49 | from anomalib.metrics import Evaluator
|
@@ -182,7 +183,19 @@ def configure_optimizers(
|
182 | 183 | return ({"optimizer": optimizer_d, "lr_scheduler": scheduler_d}, {"optimizer": optimizer_u})
|
183 | 184 |
|
184 | 185 | def on_train_start(self) -> None:
|
185 |
| - """Load pretrained weights of the discrete model when starting training.""" |
| 186 | + """Set up model before training begins. |
| 187 | +
|
| 188 | + Performs the following steps: |
| 189 | + 1. Validates that pre_processor uses no normalization |
| 190 | + 2. Load pretrained weights of the discrete model |
| 191 | +
|
| 192 | + Raises: |
| 193 | + ValueError: If transforms contain normalization. |
| 194 | + """ |
| 195 | + if self.pre_processor and extract_transforms_by_type(self.pre_processor.transform, Normalize): |
| 196 | + msg = "Transforms for DSR should not contain Normalize." |
| 197 | + raise ValueError(msg) |
| 198 | + |
186 | 199 | ckpt: Path = self.prepare_pretrained_model()
|
187 | 200 | self.model.load_pretrained_discrete_model_weights(ckpt, self.device)
|
188 | 201 |
|
|
0 commit comments