Skip to content

Commit bc7dd96

Browse files
authored
🔧 chore(model): check if no normalization is in transforms for DRAEM, DSR (#2867)
* chore(model): check if no normalization is in transforms for DRAEM and DSR * updated year in the copyright notice --------- Signed-off-by: Aimira Baitieva <[email protected]>
1 parent 78cd8b1 commit bc7dd96

File tree

2 files changed

+28
-4
lines changed

2 files changed

+28
-4
lines changed

‎src/anomalib/models/image/draem/lightning_model.py‎

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,10 +20,11 @@
2020
import torch
2121
from lightning.pytorch.utilities.types import STEP_OUTPUT
2222
from torch import nn
23-
from torchvision.transforms.v2 import Compose, Resize
23+
from torchvision.transforms.v2 import Compose, Normalize, Resize
2424

2525
from anomalib import LearningType
2626
from anomalib.data import Batch
27+
from anomalib.data.transforms.utils import extract_transforms_by_type
2728
from anomalib.data.utils import DownloadInfo, download_and_extract
2829
from anomalib.data.utils.generators.perlin import PerlinAnomalyGenerator
2930
from anomalib.metrics import Evaluator
@@ -145,6 +146,16 @@ def hook(_, __, output: torch.Tensor) -> None: # noqa: ANN001
145146
self.model.reconstructive_subnetwork.encoder.mp4.register_forward_hook(get_activation("input"))
146147
self.model.reconstructive_subnetwork.encoder.block5.register_forward_hook(get_activation("output"))
147148

149+
def on_train_start(self) -> None:
150+
"""Validates transforms before training begins.
151+
152+
Raises:
153+
ValueError: If transforms contain normalization.
154+
"""
155+
if self.pre_processor and extract_transforms_by_type(self.pre_processor.transform, Normalize):
156+
msg = "Transforms for DRÆM should not contain Normalize."
157+
raise ValueError(msg)
158+
148159
def training_step(self, batch: Batch, *args, **kwargs) -> STEP_OUTPUT:
149160
"""Perform training step for DRAEM.
150161

‎src/anomalib/models/image/dsr/lightning_model.py‎

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# Copyright (C) 2023-2024 Intel Corporation
1+
# Copyright (C) 2023-2025 Intel Corporation
22
# SPDX-License-Identifier: Apache-2.0
33

44
"""DSR - A Dual Subspace Re-Projection Network for Surface Anomaly Detection.
@@ -39,10 +39,11 @@
3939

4040
import torch
4141
from lightning.pytorch.utilities.types import STEP_OUTPUT, OptimizerLRScheduler
42-
from torchvision.transforms.v2 import Compose, Resize
42+
from torchvision.transforms.v2 import Compose, Normalize, Resize
4343

4444
from anomalib import LearningType
4545
from anomalib.data import Batch
46+
from anomalib.data.transforms.utils import extract_transforms_by_type
4647
from anomalib.data.utils import DownloadInfo, download_and_extract
4748
from anomalib.data.utils.generators.perlin import PerlinAnomalyGenerator
4849
from anomalib.metrics import Evaluator
@@ -182,7 +183,19 @@ def configure_optimizers(
182183
return ({"optimizer": optimizer_d, "lr_scheduler": scheduler_d}, {"optimizer": optimizer_u})
183184

184185
def on_train_start(self) -> None:
185-
"""Load pretrained weights of the discrete model when starting training."""
186+
"""Set up model before training begins.
187+
188+
Performs the following steps:
189+
1. Validates that pre_processor uses no normalization
190+
2. Load pretrained weights of the discrete model
191+
192+
Raises:
193+
ValueError: If transforms contain normalization.
194+
"""
195+
if self.pre_processor and extract_transforms_by_type(self.pre_processor.transform, Normalize):
196+
msg = "Transforms for DSR should not contain Normalize."
197+
raise ValueError(msg)
198+
186199
ckpt: Path = self.prepare_pretrained_model()
187200
self.model.load_pretrained_discrete_model_weights(ckpt, self.device)
188201

0 commit comments

Comments
 (0)