Describe the bug
As the title suggests. I'm using anomalib v1.2.0 (I've worked on this code for over a year, so I didn't want to migrate to v2.x.x), and there is a difference between the metrics computed by `engine.test()` and the metrics I calculate manually from predictions obtained with `engine.predict()`.
What's even weirder is that this happens for four "standard" metrics: Precision, Recall, Accuracy, and F1Score. It does NOT happen for e.g. the AUROC metric.
There is a minimal code example below. The example uses the Visa dataset. I've reduced the number of images so that training and inference are faster. Ultimately, it doesn't matter which exact images I'm using - I'm just interested in whether I'm calculating the metrics manually in the correct way, i.e. whether anomalib and my implementation yield the same numbers.
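For context on why AUROC could behave differently from the other four metrics: as far as I understand, the torchmetrics binary classification metrics binarize floating-point inputs at a fixed 0.5 threshold (with a strict `>`), while AUROC only depends on the ranking of the scores. A minimal standalone sketch (not part of my pipeline) of how a single borderline score can shift the thresholded metrics without affecting AUROC:

import torch
from torchmetrics.classification import BinaryAccuracy, BinaryAUROC

target = torch.tensor([1, 0, 1, 0])
scores = torch.tensor([0.5, 0.1, 0.9, 0.2])  # 0.5 sits exactly on the boundary

# Float inputs are binarized with a strict `scores > 0.5`, so 0.5 becomes 0
print(BinaryAccuracy()(scores, target))                      # tensor(0.7500)
# Integer inputs are treated as hard labels and used as-is
print(BinaryAccuracy()(torch.tensor([1, 0, 1, 0]), target))  # tensor(1.)
# AUROC only looks at the ranking of the scores, so it is unaffected
print(BinaryAUROC()(scores, target))                         # tensor(1.)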
CODE EXAMPLE
# %%
import os
import random
import numpy as np
import torch
from anomalib.data import Visa
from anomalib.engine import Engine
from anomalib.metrics import AUROC
from anomalib.models import Dfm
from torchmetrics.classification import Accuracy, F1Score, Precision, Recall
# %%
ALL_IMAGE_METRICS: dict[str, dict] = {
    "AUROC": {
        "class_path": "anomalib.metrics.AUROC",
        "init_args": {},
    },
    "F1Score": {
        "class_path": "torchmetrics.F1Score",
        "init_args": {"task": "binary"},
    },
    "Precision": {
        "class_path": "torchmetrics.Precision",
        "init_args": {"task": "binary"},
    },
    "Recall": {
        "class_path": "torchmetrics.Recall",
        "init_args": {"task": "binary"},
    },
    "Accuracy": {
        "class_path": "torchmetrics.Accuracy",
        "init_args": {"task": "binary"},
    },
}
ALL_PIXEL_METRICS: dict[str, dict] = {
    "AUROC": {
        "class_path": "anomalib.metrics.AUROC",
        "init_args": {},
    },
}
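# %%
# Hedged aside (not part of the original pipeline): the mappings above follow
# the jsonargparse-style "class_path"/"init_args" convention, so each entry
# should resolve to a metric instance roughly like this (helper name made up):
from importlib import import_module

def _instantiate_metric(cfg: dict):
    module_path, cls_name = cfg["class_path"].rsplit(".", 1)
    return getattr(import_module(module_path), cls_name)(**cfg["init_args"])

# e.g. _instantiate_metric(ALL_IMAGE_METRICS["Precision"]) -> BinaryPrecision()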
# %%
def seed_everything(seed: int) -> None:
    random.seed(seed)
    os.environ["PYTHONHASHSEED"] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
# %%
seed_everything(0)
visa = Visa(
    root="/mnt/extra-storage/thesis/data/visa_local",
    category="macaroni2",
    train_batch_size=8,
    eval_batch_size=8,
    task="segmentation",
    image_size=(512, 512),
    seed=0,
)
model = Dfm()
# %%
engine = Engine(
    default_root_dir="/home/luka/Desktop/thesis/debug-metrics",
    task="segmentation",
    image_metrics=ALL_IMAGE_METRICS,
    pixel_metrics=ALL_PIXEL_METRICS,
    accelerator="gpu",
    devices=1,
    max_epochs=1,
    logger=False,
)
engine.fit(model=model, datamodule=visa)
# %%
engine.test(model=model, datamodule=visa)
# %%
predictions = engine.predict(model=model, datamodule=visa)
# %%
image_gt = []
image_preds = []
image_scores = []
pixel_anomaly_maps = []
pixel_gt = []
for batch in predictions:
    # Image level
    image_gt.extend(batch["label"].numpy().flatten())
    pred_labels = batch["pred_labels"].long()
    image_preds.extend(pred_labels.numpy().flatten())
    pred_scores = batch["pred_scores"].float()
    image_scores.extend(pred_scores.numpy().flatten())
    # Pixel level
    pixel_anomaly_maps.extend(batch["anomaly_maps"].float().numpy())
    pixel_gt.extend(batch["mask"].long().numpy())
# %%
print(f"{image_gt = }")
print(f"{image_preds = }")
print(f"{image_scores = }")
# %%
precision = Precision(task="binary")
recall = Recall(task="binary")
accuracy = Accuracy(task="binary")
f1 = F1Score(task="binary")
image_auroc = AUROC()
pixel_auroc = AUROC()
precision.update(torch.tensor(image_preds), torch.tensor(image_gt))
recall.update(torch.tensor(image_preds), torch.tensor(image_gt))
accuracy.update(torch.tensor(image_preds), torch.tensor(image_gt))
f1.update(torch.tensor(image_preds), torch.tensor(image_gt))
image_auroc.update(torch.tensor(image_scores), torch.tensor(image_gt))
pixel_auroc.update(torch.tensor(pixel_anomaly_maps), torch.tensor(pixel_gt))
print(f"Image AUROC: {image_auroc.compute().item()}")
print(f"Accuracy: {accuracy.compute().item()}")
print(f"F1Score: {f1.compute().item()}")
print(f"Precision: {precision.compute().item()}")
print(f"Recall: {recall.compute().item()}")
print(f"Pixel AUROC: {pixel_auroc.compute().item()}")
I have tried and disregarded the following options:
- Different models (Fastflow, EfficientAd)
- Verifying that `datamodule.test_dataloader()` and `datamodule.predict_dataloader()` are the same
- Seeing if there is a difference between using the existing `engine.trainer.image_metrics` and manually instantiating the metrics from `torchmetrics`. So, instead of using `image_metrics = engine.model.image_metrics`, I manually create, say, `accuracy = Accuracy(task="binary")` via `from torchmetrics.classification import Accuracy`
- Calling `accuracy.update(target, preds)` instead of `accuracy.update(preds, target)`. When you have no clue, you try everything :D
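One more sanity check that might help localize this (a sketch assuming anomalib v1.x min-max normalizes the scores so that its adaptive threshold maps to 0.5, and that torchmetrics re-thresholds float inputs internally): compare torchmetrics' implicit binarization of `image_scores` against the `pred_labels` anomalib produced.

scores_t = torch.tensor(image_scores)
print((scores_t > 0.5).long())   # what torchmetrics would use internally for float inputs
print(torch.tensor(image_preds)) # what anomalib's post-processing produced

If these two disagree (e.g. on a score of exactly 0.5), that would explain why only the thresholded metrics differ while AUROC matches.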
Dataset
Other (please specify in the text field below)
Model
DFM
Steps to reproduce the behavior
- Install anomalib 1.2.0, torch 2.3.1, torchmetrics 1.6.1
- Optionally delete most of the images from Visa, leaving only a few, to speed up training and inference
- Run the minimal code example from above
OS information:
- OS: Ubuntu 24.04
- Python version: 3.10.15
- Anomalib version: 1.2.0
- PyTorch version: 2.3.1
- CUDA/cuDNN version: 12.1
- GPU models and configuration: 1x GTX 1660
- Any other relevant information: Using the Visa dataset from anomalib, just with most of the images removed so training and inference go faster. I originally discovered the problem on the full Visa dataset with all images.
Expected behavior
The metrics from `engine.test()` and the metrics calculated manually on `engine.predict()` predictions should be exactly the same.
Screenshots
Not applicable.
Pip/GitHub
pip
What version/branch did you use?
1.2.0
Configuration YAML
Not using YAML.
Logs
### THESE ARE LOGS FROM THE MINIMAL EXAMPLE CODE PROVIDED AT THE START
### NOTE THAT THE TRAINING AND INFERENCE ARE SHORT BECAUSE IT'S NOT FULL VISA
INFO:anomalib.models.components.base.anomaly_module:Initializing Dfm model.
INFO:timm.models._builder:Loading pretrained weights from Hugging Face hub (timm/resnet50.a1_in1k)
INFO:timm.models._hub:[timm/resnet50.a1_in1k] Safe alternative available for 'pytorch_model.bin' (as 'model.safetensors'). Loading weights using safetensors.
INFO:timm.models._builder:Missing keys (fc.weight, fc.bias) discovered while loading pretrained weights. This is expected if model is being adapted.
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
WARNING:anomalib.models.components.base.anomaly_module:No implementation of `configure_transforms` was provided in the Lightning model. Using default transforms from the base class. This may not be suitable for your use case. Please override `configure_transforms` in your model.
INFO:anomalib.data.image.visa:Found the dataset and train/test split.
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
/home/luka/Desktop/thesis/.venv/lib/python3.10/site-packages/lightning/pytorch/core/optimizer.py:183: `LightningModule.configure_optimizers` returned `None`, this fit will run with no optimizer
| Name | Type | Params | Mode
---------------------------------------------------------------------------
0 | model | DFMModel | 8.5 M | train
1 | _transform | Compose | 0 | train
2 | normalization_metrics | MetricCollection | 0 | train
3 | image_threshold | F1AdaptiveThreshold | 0 | train
4 | pixel_threshold | F1AdaptiveThreshold | 0 | train
5 | image_metrics | AnomalibMetricCollection | 0 | train
6 | pixel_metrics | AnomalibMetricCollection | 0 | train
---------------------------------------------------------------------------
8.5 M Trainable params
0 Non-trainable params
8.5 M Total params
34.173 Total estimated model params size (MB)
18 Modules in train mode
174 Modules in eval mode
Epoch 0: 0%| | 0/6 [00:00<?, ?it/s]
/home/luka/Desktop/thesis/.venv/lib/python3.10/site-packages/lightning/pytorch/loops/optimization/automatic.py:134: `training_step` returned `None`. If this was on purpose, ignore this warning...
Epoch 0: 100%|██████████| 6/6 [00:01<00:00, 4.05it/s]
INFO:anomalib.models.image.dfm.lightning_model:Aggregating the embedding extracted from the training set.
INFO:anomalib.models.image.dfm.lightning_model:Fitting a PCA and a Gaussian model to dataset.
Epoch 0: 100%|██████████| 6/6 [00:03<00:00, 1.57it/s, pixel_AUROC=0.500]
`Trainer.fit` stopped: `max_epochs=1` reached.
Epoch 0: 100%|██████████| 6/6 [00:04<00:00, 1.36it/s, pixel_AUROC=0.500]
INFO:anomalib.callbacks.timer:Training took 4.77 seconds
INFO:anomalib.data.image.visa:Found the dataset and train/test split.
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
Testing DataLoader 0: 100%|██████████| 3/3 [00:06<00:00, 0.50it/s]
INFO:anomalib.callbacks.timer:Testing took 7.424871206283569 seconds
Throughput (batch_size=8) : 2.6936494175244774 FPS
Testing DataLoader 0: 100%|██████████| 3/3 [00:06<00:00, 0.43it/s]
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓
┃ Test metric ┃ DataLoader 0 ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩
│ image_AUROC │ 0.9800000190734863 │
│ image_Accuracy │ 0.8999999761581421 │
│ image_F1Score │ 0.8999999761581421 │
│ image_Precision │ 0.8999999761581421 │
│ image_Recall │ 0.8999999761581421 │
│ pixel_AUROC │ 0.8799090385437012 │
└───────────────────────────┴───────────────────────────┘
WARNING:anomalib.engine.engine:ckpt_path is not provided. Model weights will not be loaded.
INFO:anomalib.data.image.visa:Found the dataset and train/test split.
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
Predicting DataLoader 0: 100%|██████████| 3/3 [00:05<00:00, 0.58it/s]
image_gt = [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
image_preds = [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1]
image_scores = [0.8788439, 0.5246862, 1.0, 0.7724421, 0.98713803, 0.6087804, 0.70329773, 0.9312367, 0.5, 0.88264537, 0.22918293, 0.40289715, 0.4770017, 0.17081264, 0.08232936, 0.3444077, 0.17680398, 0.15485987, 0.3225487, 0.5389916]
/home/luka/Desktop/thesis/notebooks/minimal_example.py:129: UserWarning: Creating a tensor from a list of numpy.ndarrays is extremely slow. Please consider converting the list to a single numpy.ndarray with numpy.array() before converting to a tensor. (Triggered internally at ../torch/csrc/utils/tensor_new.cpp:274.)
pixel_auroc.update(torch.tensor(pixel_anomaly_maps), torch.tensor(pixel_gt))
Image AUROC: 0.9800000190734863
Accuracy: 0.949999988079071
F1Score: 0.9523809552192688
Precision: 0.9090909361839294
Recall: 1.0
Pixel AUROC: 0.8799090385437012
Code of Conduct
- I agree to follow this project's Code of Conduct