diff --git a/src/lightly_train/_commands/train_task.py b/src/lightly_train/_commands/train_task.py index 7b73c98a5..de7e74aaa 100644 --- a/src/lightly_train/_commands/train_task.py +++ b/src/lightly_train/_commands/train_task.py @@ -27,6 +27,7 @@ from lightly_train._data.mask_semantic_segmentation_dataset import ( MaskSemanticSegmentationDataArgs, ) +from lightly_train._data.task_dataset import TaskDataset from lightly_train._loggers.task_logger_args import TaskLoggerArgs from lightly_train._task_checkpoint import TaskSaveCheckpointArgs from lightly_train._task_models.train_model import TrainModelArgs @@ -247,10 +248,12 @@ def _train_task_from_config(config: TrainTaskConfig) -> None: ignore_index=config.data.ignore_index, ) train_transform = helpers.get_train_transform( - train_model_cls=train_model_cls, train_transform_args=train_transform_args + train_model_cls=train_model_cls, + train_transform_args=train_transform_args, ) val_transform = helpers.get_val_transform( - train_model_cls=train_model_cls, val_transform_args=val_transform_args + train_model_cls=train_model_cls, + val_transform_args=val_transform_args, ) with helpers.get_dataset_temp_mmap_path( @@ -258,13 +261,13 @@ def _train_task_from_config(config: TrainTaskConfig) -> None: ) as train_mmap_filepath, helpers.get_dataset_temp_mmap_path( fabric=fabric, data=config.data.val.images ) as val_mmap_filepath: - train_dataset = helpers.get_dataset( + train_dataset: TaskDataset = helpers.get_dataset( fabric=fabric, dataset_args=config.data.get_train_args(), transform=train_transform, mmap_filepath=train_mmap_filepath, ) - val_dataset = helpers.get_dataset( + val_dataset: TaskDataset = helpers.get_dataset( fabric=fabric, dataset_args=config.data.get_val_args(), transform=val_transform, @@ -305,6 +308,7 @@ def _train_task_from_config(config: TrainTaskConfig) -> None: train_dataloader = helpers.get_train_dataloader( fabric=fabric, dataset=train_dataset, + transform_args=train_transform_args, batch_size=config.batch_size, num_workers=config.num_workers, loader_args=config.loader_args, @@ -313,6 +317,7 @@ def _train_task_from_config(config: TrainTaskConfig) -> None: val_dataloader = helpers.get_val_dataloader( fabric=fabric, dataset=val_dataset, + transform_args=val_transform_args, batch_size=config.batch_size, num_workers=config.num_workers, loader_args=config.loader_args, diff --git a/src/lightly_train/_commands/train_task_helpers.py b/src/lightly_train/_commands/train_task_helpers.py index 55c721f99..fee8669a0 100644 --- a/src/lightly_train/_commands/train_task_helpers.py +++ b/src/lightly_train/_commands/train_task_helpers.py @@ -11,7 +11,6 @@ import hashlib import json import logging -from functools import partial from json import JSONEncoder from pathlib import Path from typing import Any, Generator, Iterable, Literal, Mapping @@ -22,7 +21,7 @@ from lightning_fabric import utilities as fabric_utilities from lightning_fabric.loggers.logger import Logger as FabricLogger from torch import Tensor -from torch.utils.data import DataLoader, Dataset +from torch.utils.data import DataLoader from lightly_train._configs import validate from lightly_train._data import cache @@ -32,9 +31,9 @@ Primitive, ) from lightly_train._data.mask_semantic_segmentation_dataset import ( - MaskSemanticSegmentationDataset, MaskSemanticSegmentationDatasetArgs, ) +from lightly_train._data.task_dataset import TaskDataset from lightly_train._env import Env from lightly_train._loggers.mlflow import MLFlowLogger from lightly_train._loggers.task_logger_args import 
TaskLoggerArgs @@ -62,8 +61,6 @@ TaskTransformArgs, ) from lightly_train.types import ( - MaskSemanticSegmentationBatch, - MaskSemanticSegmentationDatasetItem, PathLike, TaskDatasetItem, ) @@ -414,7 +411,7 @@ def get_dataset( dataset_args: MaskSemanticSegmentationDatasetArgs, transform: TaskTransform, mmap_filepath: Path, -) -> MaskSemanticSegmentationDataset: +) -> TaskDataset: image_info = dataset_args.list_image_info() dataset_cls = dataset_args.get_dataset_cls() @@ -431,35 +428,19 @@ def get_dataset( ) -# TODO(Guarin, 08/25): Move this function to the _data module. -def collate_fn( - batch: list[MaskSemanticSegmentationDatasetItem], split: str -) -> MaskSemanticSegmentationBatch: - # Prepare the batch without any stacking. - images = [item["image"] for item in batch] - masks = [item["mask"] for item in batch] - - out: MaskSemanticSegmentationBatch = { - "image_path": [item["image_path"] for item in batch], - # Stack images during training as they all have the same shape. - # During validation every image can have a different shape. - "image": torch.stack(images) if split == "train" else images, - "mask": torch.stack(masks) if split == "train" else masks, - "binary_masks": [item["binary_masks"] for item in batch], - } - - return out - - def get_train_dataloader( fabric: Fabric, - dataset: Dataset[TaskDatasetItem], + dataset: TaskDataset, + transform_args: TaskTransformArgs, batch_size: int, num_workers: int, loader_args: dict[str, Any] | None = None, ) -> DataLoader[TaskDatasetItem]: timeout = Env.LIGHTLY_TRAIN_DATALOADER_TIMEOUT_SEC.value if num_workers > 0 else 0 # TODO(Guarin, 07/25): Persistent workers by default? + collate_fn = dataset.batch_collate_fn_cls( + split="train", transform_args=transform_args + ) dataloader_kwargs: dict[str, Any] = dict( dataset=dataset, batch_size=batch_size // fabric.world_size, @@ -467,7 +448,7 @@ def get_train_dataloader( num_workers=num_workers, drop_last=True, timeout=timeout, - collate_fn=partial(collate_fn, split="train"), + collate_fn=collate_fn, ) if loader_args is not None: logger.debug(f"Using additional dataloader arguments {loader_args}.") @@ -481,12 +462,16 @@ def get_train_dataloader( def get_val_dataloader( fabric: Fabric, - dataset: Dataset[TaskDatasetItem], + dataset: TaskDataset, + transform_args: TaskTransformArgs, batch_size: int, num_workers: int, loader_args: dict[str, Any] | None = None, ) -> DataLoader[TaskDatasetItem]: timeout = Env.LIGHTLY_TRAIN_DATALOADER_TIMEOUT_SEC.value if num_workers > 0 else 0 + collate_fn = dataset.batch_collate_fn_cls( + split="val", transform_args=transform_args + ) dataloader_kwargs: dict[str, Any] = dict( dataset=dataset, batch_size=batch_size // fabric.world_size, @@ -494,7 +479,7 @@ def get_val_dataloader( num_workers=num_workers, drop_last=False, timeout=timeout, - collate_fn=partial(collate_fn, split="validation"), + collate_fn=collate_fn, ) if loader_args is not None: logger.debug(f"Using additional dataloader arguments {loader_args}.") diff --git a/src/lightly_train/_data/mask_semantic_segmentation_dataset.py b/src/lightly_train/_data/mask_semantic_segmentation_dataset.py index 8022b4581..c06825859 100644 --- a/src/lightly_train/_data/mask_semantic_segmentation_dataset.py +++ b/src/lightly_train/_data/mask_semantic_segmentation_dataset.py @@ -15,12 +15,16 @@ import torch from pydantic import AliasChoices, Field, TypeAdapter, field_validator from torch import Tensor -from torch.utils.data import Dataset from lightly_train._configs.config import PydanticConfig from lightly_train._data import 
file_helpers from lightly_train._data.file_helpers import ImageMode +from lightly_train._data.task_batch_collation import ( + BaseCollateFunction, + MaskSemanticSegmentationCollateFunction, +) from lightly_train._data.task_data_args import TaskDataArgs +from lightly_train._data.task_dataset import TaskDataset from lightly_train._env import Env from lightly_train._transforms.semantic_segmentation_transform import ( SemanticSegmentationTransform, @@ -55,16 +59,20 @@ class MultiChannelClassInfo(PydanticConfig): ClassInfo = Union[MultiChannelClassInfo, SingleChannelClassInfo] -class MaskSemanticSegmentationDataset(Dataset[MaskSemanticSegmentationDatasetItem]): +class MaskSemanticSegmentationDataset(TaskDataset): + batch_collate_fn_cls: ClassVar[type[BaseCollateFunction]] = ( + MaskSemanticSegmentationCollateFunction + ) + def __init__( self, dataset_args: MaskSemanticSegmentationDatasetArgs, image_info: Sequence[dict[str, str]], transform: SemanticSegmentationTransform, ): + super().__init__(transform=transform) self.args = dataset_args self.filepaths = image_info - self.transform = transform self.ignore_index = dataset_args.ignore_index # Get the class mapping. diff --git a/src/lightly_train/_data/task_batch_collation.py b/src/lightly_train/_data/task_batch_collation.py new file mode 100644 index 000000000..8b541830f --- /dev/null +++ b/src/lightly_train/_data/task_batch_collation.py @@ -0,0 +1,143 @@ +# +# Copyright (c) Lightly AG and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +# +from __future__ import annotations + +from typing import Literal + +import numpy as np +import torch + +from lightly_train._transforms.object_detection_transform import ( + ObjectDetectionTransformArgs, +) +from lightly_train._transforms.scale_jitter import ScaleJitter +from lightly_train._transforms.task_transform import TaskTransformArgs +from lightly_train.types import ( + MaskSemanticSegmentationBatch, + MaskSemanticSegmentationDatasetItem, + ObjectDetectionBatch, + ObjectDetectionDatasetItem, +) + + +class BaseCollateFunction: + def __init__( + self, split: Literal["train", "val"], transform_args: TaskTransformArgs + ): + self.split = split + self.transform_args = transform_args + + +class MaskSemanticSegmentationCollateFunction(BaseCollateFunction): + def __call__( + self, batch: list[MaskSemanticSegmentationDatasetItem] + ) -> MaskSemanticSegmentationBatch: + # Prepare the batch without any stacking. + images = [item["image"] for item in batch] + masks = [item["mask"] for item in batch] + + out: MaskSemanticSegmentationBatch = { + "image_path": [item["image_path"] for item in batch], + # Stack images during training as they all have the same shape. + # During validation every image can have a different shape. 
+ "image": torch.stack(images) if self.split == "train" else images, + "mask": torch.stack(masks) if self.split == "train" else masks, + "binary_masks": [item["binary_masks"] for item in batch], + } + + return out + + +class ObjectDetectionCollateFunction(BaseCollateFunction): + def __init__( + self, split: Literal["train", "val"], transform_args: TaskTransformArgs + ): + super().__init__(split, transform_args) + assert isinstance(transform_args, ObjectDetectionTransformArgs) + self.scale_jitter: ScaleJitter | None + if transform_args.scale_jitter is not None: + if ( + transform_args.scale_jitter.min_scale is None + or transform_args.scale_jitter.max_scale is None + ): + scale_range = None + else: + scale_range = ( + transform_args.scale_jitter.min_scale, + transform_args.scale_jitter.max_scale, + ) + self.scale_jitter = ScaleJitter( + sizes=transform_args.scale_jitter.sizes, + target_size=transform_args.image_size, + scale_range=scale_range, + num_scales=transform_args.scale_jitter.num_scales, + divisible_by=transform_args.scale_jitter.divisible_by, + p=transform_args.scale_jitter.prob, + ) + else: + self.scale_jitter = None + + def __call__(self, batch: list[ObjectDetectionDatasetItem]) -> ObjectDetectionBatch: + if self.scale_jitter is not None: + # Turn into numpy again. + batch_np = [ + { + "image_path": item["image_path"], + "image": item["image"].numpy(), + "bboxes": item["bboxes"].numpy(), + "classes": item["classes"].numpy(), + } + for item in batch + ] + + # Apply transform. + seed = np.random.randint(0, 1_000_000) + self.scale_jitter.global_step = seed + images = [] + bboxes = [] + classes = [] + for item in batch_np: + out = self.scale_jitter( + image=item["image"], + bboxes=item["bboxes"], + class_labels=item["classes"], + ) + images.append(out["image"]) + bboxes.append(out["bboxes"]) + classes.append(out["class_labels"]) + + # Old versions of albumentations return classes/boxes as a list. + bboxes = [ + bbox if isinstance(bbox, np.ndarray) else np.array(bbox) + for bbox in bboxes + ] + classes = [ + cls_ if isinstance(cls_, np.ndarray) else np.array(cls_) + for cls_ in classes + ] + + # Turn back into torch tensors. + images = [torch.from_numpy(img).to(torch.float32) for img in images] + bboxes = [torch.from_numpy(bbox).to(torch.float32) for bbox in bboxes] + classes = [torch.from_numpy(cls).to(torch.int64) for cls in classes] + + out_: ObjectDetectionBatch = { + "image_path": [item["image_path"] for item in batch], + "image": torch.stack(images), + "bboxes": bboxes, + "classes": classes, + } + return out_ + else: + out_ = { + "image_path": [item["image_path"] for item in batch], + "image": torch.stack([item["image"] for item in batch]), + "bboxes": [item["bboxes"] for item in batch], + "classes": [item["classes"] for item in batch], + } + return out_ diff --git a/src/lightly_train/_data/task_dataset.py b/src/lightly_train/_data/task_dataset.py new file mode 100644 index 000000000..f398fdbe0 --- /dev/null +++ b/src/lightly_train/_data/task_dataset.py @@ -0,0 +1,33 @@ +# +# Copyright (c) Lightly AG and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. 
+# +from __future__ import annotations + +from typing import ClassVar + +from torch.utils.data import Dataset + +from lightly_train._data.task_batch_collation import BaseCollateFunction +from lightly_train._transforms.task_transform import TaskTransform +from lightly_train.types import TaskDatasetItem + + +class TaskDataset(Dataset[TaskDatasetItem]): + batch_collate_fn_cls: ClassVar[type[BaseCollateFunction]] = BaseCollateFunction + + def __init__(self, transform: TaskTransform) -> None: + self._transform = transform + + @property + def transform(self) -> TaskTransform: + return self._transform + + def __len__(self) -> int: + raise NotImplementedError() + + def __getitem__(self, index: int) -> TaskDatasetItem: + raise NotImplementedError() diff --git a/src/lightly_train/_data/yolo_object_detection_dataset.py b/src/lightly_train/_data/yolo_object_detection_dataset.py index 18806628e..ae724e652 100644 --- a/src/lightly_train/_data/yolo_object_detection_dataset.py +++ b/src/lightly_train/_data/yolo_object_detection_dataset.py @@ -8,30 +8,38 @@ from __future__ import annotations from pathlib import Path -from typing import Literal, Sequence +from typing import ClassVar, Literal, Sequence import numpy as np import pydantic import torch -from torch.utils.data import Dataset from lightly_train._configs.config import PydanticConfig from lightly_train._data import file_helpers +from lightly_train._data.task_batch_collation import ( + BaseCollateFunction, + ObjectDetectionCollateFunction, +) from lightly_train._data.task_data_args import TaskDataArgs +from lightly_train._data.task_dataset import TaskDataset from lightly_train._transforms.task_transform import TaskTransform from lightly_train.types import ImageFilename, ObjectDetectionDatasetItem, PathLike -class YOLOObjectDetectionDataset(Dataset[ObjectDetectionDatasetItem]): +class YOLOObjectDetectionDataset(TaskDataset): + batch_collate_fn_cls: ClassVar[type[BaseCollateFunction]] = ( + ObjectDetectionCollateFunction + ) + def __init__( self, dataset_args: YOLOObjectDetectionDatasetArgs, image_filenames: Sequence[ImageFilename], transform: TaskTransform, ) -> None: + super().__init__(transform=transform) self.args = dataset_args self.image_filenames = image_filenames - self.transform = transform def __len__(self) -> int: return len(self.image_filenames) diff --git a/src/lightly_train/_task_models/dinov2_eomt_semantic_segmentation/transforms.py b/src/lightly_train/_task_models/dinov2_eomt_semantic_segmentation/transforms.py index f433e7346..2b71c04d5 100644 --- a/src/lightly_train/_task_models/dinov2_eomt_semantic_segmentation/transforms.py +++ b/src/lightly_train/_task_models/dinov2_eomt_semantic_segmentation/transforms.py @@ -7,7 +7,7 @@ # from __future__ import annotations -from typing import Literal +from typing import Literal, Sequence from pydantic import Field @@ -40,10 +40,13 @@ class DINOv2EoMTSemanticSegmentationColorJitterArgs(ColorJitterArgs): class DINOv2EoMTSemanticSegmentationScaleJitterArgs(ScaleJitterArgs): - min_scale: float = 0.5 - max_scale: float = 2.0 - num_scales: int = 20 + sizes: Sequence[tuple[int, int]] | None = None + min_scale: float | None = 0.5 + max_scale: float | None = 2.0 + num_scales: int | None = 20 prob: float = 1.0 + # TODO: Lionel(09/25): This is currently not used. 
+ divisible_by: int | None = None class DINOv2EoMTSemanticSegmentationSmallestMaxSizeArgs(SmallestMaxSizeArgs): diff --git a/src/lightly_train/_task_models/dinov2_linear_semantic_segmentation/transforms.py b/src/lightly_train/_task_models/dinov2_linear_semantic_segmentation/transforms.py index 2e56525e0..13e01262f 100644 --- a/src/lightly_train/_task_models/dinov2_linear_semantic_segmentation/transforms.py +++ b/src/lightly_train/_task_models/dinov2_linear_semantic_segmentation/transforms.py @@ -7,7 +7,7 @@ # from __future__ import annotations -from typing import Literal +from typing import Literal, Sequence from pydantic import Field @@ -40,10 +40,13 @@ class DINOv2LinearSemanticSegmentationColorJitterArgs(ColorJitterArgs): class DINOv2LinearSemanticSegmentationScaleJitterArgs(ScaleJitterArgs): - min_scale: float = 0.5 - max_scale: float = 2.0 - num_scales: int = 20 + sizes: Sequence[tuple[int, int]] | None = None + min_scale: float | None = 0.5 + max_scale: float | None = 2.0 + num_scales: int | None = 20 prob: float = 1.0 + # TODO: Lionel(09/25): This is currently not used. + divisible_by: int | None = None class DINOv2LinearSemanticSegmentationSmallestMaxSizeArgs(SmallestMaxSizeArgs): diff --git a/src/lightly_train/_task_models/dinov2_ltdetr_object_detection/transforms.py b/src/lightly_train/_task_models/dinov2_ltdetr_object_detection/transforms.py index 7180a13d3..12624c8f1 100644 --- a/src/lightly_train/_task_models/dinov2_ltdetr_object_detection/transforms.py +++ b/src/lightly_train/_task_models/dinov2_ltdetr_object_detection/transforms.py @@ -7,17 +7,21 @@ # from __future__ import annotations +from typing import Literal, Sequence + from albumentations import BboxParams from pydantic import Field from lightly_train._transforms.object_detection_transform import ( + ObjectDetectionTransform, ObjectDetectionTransformArgs, ) from lightly_train._transforms.transform import ( RandomFlipArgs, RandomPhotometricDistortArgs, RandomZoomOutArgs, - ResizeArgs, + ScaleJitterArgs, + StopPolicyArgs, ) @@ -42,23 +46,47 @@ class DINOv2LTDetrObjectDetectionRandomFlipArgs(RandomFlipArgs): vertical_prob: float = 0.0 -class DINOv2LTDetrObjectDetectionResizeArgs(ResizeArgs): - height: int = 644 - width: int = 644 +class DINOv2LTDetrObjectDetectionScaleJitterArgs(ScaleJitterArgs): + sizes: Sequence[tuple[int, int]] | None = [ + (490, 490), + (518, 518), + (546, 546), + (588, 588), + (616, 616), + (644, 644), + (644, 644), + (644, 644), + (686, 686), + (714, 714), + (742, 742), + (770, 770), + (812, 812), + ] + min_scale: float | None = 0.76 + max_scale: float | None = 1.27 + num_scales: int | None = 13 + prob: float = 1.0 + # The model is patch 14. 
+ divisible_by: int | None = 14 -class DINOv2LTDetrObjectDetectionTransformArgs(ObjectDetectionTransformArgs): - photometric_distort: DINOv2LTDetrObjectDetectionRandomPhotometricDistortArgs = ( - Field(default_factory=DINOv2LTDetrObjectDetectionRandomPhotometricDistortArgs) - ) - random_zoom_out: DINOv2LTDetrObjectDetectionRandomZoomOutArgs = Field( +class DINOv2LTDetrObjectDetectionTrainTransformArgs(ObjectDetectionTransformArgs): + channel_drop: None = None + num_channels: int | Literal["auto"] = "auto" + photometric_distort: ( + DINOv2LTDetrObjectDetectionRandomPhotometricDistortArgs | None + ) = Field(default_factory=DINOv2LTDetrObjectDetectionRandomPhotometricDistortArgs) + random_zoom_out: DINOv2LTDetrObjectDetectionRandomZoomOutArgs | None = Field( default_factory=DINOv2LTDetrObjectDetectionRandomZoomOutArgs ) - random_flip: DINOv2LTDetrObjectDetectionRandomFlipArgs = Field( + random_flip: DINOv2LTDetrObjectDetectionRandomFlipArgs | None = Field( default_factory=DINOv2LTDetrObjectDetectionRandomFlipArgs ) - resize: DINOv2LTDetrObjectDetectionResizeArgs = Field( - default_factory=DINOv2LTDetrObjectDetectionResizeArgs + image_size: tuple[int, int] = (644, 644) + # TODO: Lionel (09/25): Remove None, once the stop policy is implemented. + stop_policy: StopPolicyArgs | None = None + scale_jitter: ScaleJitterArgs | None = Field( + default_factory=DINOv2LTDetrObjectDetectionScaleJitterArgs ) # We use the YOLO format internally for now. bbox_params: BboxParams = Field( @@ -66,3 +94,26 @@ class DINOv2LTDetrObjectDetectionTransformArgs(ObjectDetectionTransformArgs): format="yolo", label_fields=["class_labels"], min_width=0.0, min_height=0.0 ), ) + + +class DINOv2LTDetrObjectDetectionValTransformArgs(ObjectDetectionTransformArgs): + channel_drop: None = None + num_channels: int | Literal["auto"] = "auto" + photometric_distort: None = None + random_zoom_out: None = None + random_flip: None = None + image_size: tuple[int, int] = (644, 644) + stop_policy: None = None + bbox_params: BboxParams = Field( + default_factory=lambda: BboxParams( + format="yolo", label_fields=["class_labels"], min_width=0.0, min_height=0.0 + ), + ) + + +class DINOv2LTDetrObjectDetectionTrainTransform(ObjectDetectionTransform): + transform_args_cls = DINOv2LTDetrObjectDetectionTrainTransformArgs + + +class DINOv2LTDetrObjectDetectionValTransform(ObjectDetectionTransform): + transform_args_cls = DINOv2LTDetrObjectDetectionValTransformArgs diff --git a/src/lightly_train/_task_models/dinov3_eomt_semantic_segmentation/transforms.py b/src/lightly_train/_task_models/dinov3_eomt_semantic_segmentation/transforms.py index 8d0954373..d2978a06a 100644 --- a/src/lightly_train/_task_models/dinov3_eomt_semantic_segmentation/transforms.py +++ b/src/lightly_train/_task_models/dinov3_eomt_semantic_segmentation/transforms.py @@ -7,7 +7,7 @@ # from __future__ import annotations -from typing import Literal +from typing import Literal, Sequence from pydantic import Field @@ -40,10 +40,13 @@ class DINOv3EoMTSemanticSegmentationColorJitterArgs(ColorJitterArgs): class DINOv3EoMTSemanticSegmentationScaleJitterArgs(ScaleJitterArgs): - min_scale: float = 0.5 - max_scale: float = 2.0 - num_scales: int = 20 + sizes: Sequence[tuple[int, int]] | None = None + min_scale: float | None = 0.5 + max_scale: float | None = 2.0 + num_scales: int | None = 20 prob: float = 1.0 + # TODO: Lionel(09/25): This is currently not used. 
+ divisible_by: int | None = None class DINOv3EoMTSemanticSegmentationSmallestMaxSizeArgs(SmallestMaxSizeArgs): diff --git a/src/lightly_train/_transforms/object_detection_transform.py b/src/lightly_train/_transforms/object_detection_transform.py index 53a3039e4..fb835d278 100644 --- a/src/lightly_train/_transforms/object_detection_transform.py +++ b/src/lightly_train/_transforms/object_detection_transform.py @@ -7,12 +7,21 @@ # from __future__ import annotations +from typing import Literal + import numpy as np -from albumentations import BboxParams +from albumentations import BboxParams, Compose, HorizontalFlip, VerticalFlip +from albumentations.pytorch.transforms import ToTensorV2 from numpy.typing import NDArray +from pydantic import ConfigDict from torch import Tensor from typing_extensions import NotRequired +from lightly_train._transforms.channel_drop import ChannelDrop +from lightly_train._transforms.random_photometric_distort import ( + RandomPhotometricDistort, +) +from lightly_train._transforms.random_zoom_out import RandomZoomOut from lightly_train._transforms.task_transform import ( TaskTransform, TaskTransformArgs, @@ -20,10 +29,12 @@ TaskTransformOutput, ) from lightly_train._transforms.transform import ( + ChannelDropArgs, RandomFlipArgs, RandomPhotometricDistortArgs, RandomZoomOutArgs, - ResizeArgs, + ScaleJitterArgs, + StopPolicyArgs, ) from lightly_train.types import NDArrayImage @@ -41,17 +52,145 @@ class ObjectDetectionTransformOutput(TaskTransformOutput): class ObjectDetectionTransformArgs(TaskTransformArgs): + channel_drop: ChannelDropArgs | None + num_channels: int | Literal["auto"] photometric_distort: RandomPhotometricDistortArgs | None random_zoom_out: RandomZoomOutArgs | None + # TODO: Lionel (09/25): Add RandomIoUCrop random_flip: RandomFlipArgs | None - resize: ResizeArgs | None - bbox_params: BboxParams + image_size: tuple[int, int] + # TODO: Lionel (09/25): Add Normalize + stop_policy: StopPolicyArgs | None + scale_jitter: ScaleJitterArgs | None + bbox_params: BboxParams | None + + # Necessary for the StopPolicyArgs, which are not serializable by pydantic. + model_config = ConfigDict(arbitrary_types_allowed=True) + + def resolve_auto(self) -> None: + if self.num_channels == "auto": + if self.channel_drop is not None: + self.num_channels = self.channel_drop.num_channels_keep + else: + # TODO: Lionel (09/25): Get num_channels from normalization. + self.num_channels = 3 + + height, width = self.image_size + for field_name in self.__class__.model_fields: + field = getattr(self, field_name) + if hasattr(field, "resolve_auto"): + field.resolve_auto(height=height, width=width) + + def resolve_incompatible(self) -> None: + # TODO: Lionel (09/25): Add checks for incompatible args. + pass class ObjectDetectionTransform(TaskTransform): - transform_args_cls = ObjectDetectionTransformArgs + transform_args_cls: type[ObjectDetectionTransformArgs] = ( + ObjectDetectionTransformArgs + ) + + def __init__( + self, + transform_args: ObjectDetectionTransformArgs, + ) -> None: + super().__init__(transform_args=transform_args) + + self.transform_args: ObjectDetectionTransformArgs = transform_args + self.stop_step = ( + transform_args.stop_policy.stop_step if transform_args.stop_policy else None + ) + + # TODO: Lionel (09/25): Implement stopping of certain augmentations after some steps. + if self.stop_step is not None: + raise NotImplementedError( + "Stopping certain augmentations after some steps is not implemented yet." 
+ ) + self.global_step = 0 # Currently hardcoded, will be set from outside. + self.stop_ops = ( + transform_args.stop_policy.ops if transform_args.stop_policy else set() + ) + self.past_stop = False + + self.individual_transforms = [] - def __call__( # type: ignore[empty-body] + if transform_args.channel_drop is not None: + self.individual_transforms += [ + ChannelDrop( + num_channels_keep=transform_args.channel_drop.num_channels_keep, + weight_drop=transform_args.channel_drop.weight_drop, + ) + ] + + if transform_args.photometric_distort is not None: + self.individual_transforms += [ + RandomPhotometricDistort( + brightness=transform_args.photometric_distort.brightness, + contrast=transform_args.photometric_distort.contrast, + saturation=transform_args.photometric_distort.saturation, + hue=transform_args.photometric_distort.hue, + p=transform_args.photometric_distort.prob, + ) + ] + + if transform_args.random_zoom_out is not None: + self.individual_transforms += [ + RandomZoomOut( + fill=transform_args.random_zoom_out.fill, + side_range=transform_args.random_zoom_out.side_range, + p=transform_args.random_zoom_out.prob, + ) + ] + + if transform_args.random_flip is not None: + if transform_args.random_flip.horizontal_prob > 0.0: + self.individual_transforms += [ + HorizontalFlip(p=transform_args.random_flip.horizontal_prob) + ] + if transform_args.random_flip.vertical_prob > 0.0: + self.individual_transforms += [ + VerticalFlip(p=transform_args.random_flip.vertical_prob) + ] + + self.individual_transforms += [ + ToTensorV2(), + ] + + self.transform = Compose( + self.individual_transforms, + bbox_params=transform_args.bbox_params, + ) + + def __call__( self, input: ObjectDetectionTransformInput ) -> ObjectDetectionTransformOutput: - pass + # Adjust transform after stop_step is reached. + if ( + self.stop_step is not None + and self.global_step >= self.stop_step + and not self.past_stop + ): + self.individual_transforms = [ + t for t in self.individual_transforms if type(t) not in self.stop_ops + ] + self.transform = Compose( + self.individual_transforms, + bbox_params=self.transform_args.bbox_params, + ) + self.past_stop = True + + transformed = self.transform( + image=input["image"], + bboxes=input["bboxes"], + class_labels=input["class_labels"], + ) + + # TODO: Lionel (09/25): Remove in favor of Normalize transform. + transformed["image"] = transformed["image"] / 255.0 + + return { + "image": transformed["image"], + "bboxes": transformed["bboxes"], + "class_labels": transformed["class_labels"], + } diff --git a/src/lightly_train/_transforms/scale_jitter.py b/src/lightly_train/_transforms/scale_jitter.py new file mode 100644 index 000000000..fbbba0b81 --- /dev/null +++ b/src/lightly_train/_transforms/scale_jitter.py @@ -0,0 +1,109 @@ +# +# Copyright (c) Lightly AG and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. 
+# +from __future__ import annotations + +from typing import Any, Sequence + +import numpy as np +from albumentations import Resize +from albumentations.core.transforms_interface import DualTransform +from numpy.typing import NDArray + + +class ScaleJitter(DualTransform): # type: ignore[misc] + def __init__( + self, + *, + sizes: Sequence[tuple[int, int]] | None, + target_size: tuple[int, int] | None = None, + scale_range: tuple[float, float] | None = None, + num_scales: int | None = None, + divisible_by: int | None = None, + p: float = 1.0, + step_seeding: bool = False, + seed_offset: int = 0, + ): + super().__init__(p=1.0) + if sizes is not None and any( + [s is None for s in [target_size, scale_range, num_scales]] + ): + raise ValueError( + "If sizes is provided, target_size, scale_range, num_scales must be None." + ) + if sizes is None and any( + [s is None for s in [target_size, scale_range, num_scales]] + ): + raise ValueError( + "If sizes is not provided, target_size, scale_range and num_scales must be provided." + ) + self.sizes = sizes + self.target_size = target_size + self.scale_range = scale_range + self.divisible_by = divisible_by + self.p = p + self.seed_offset = seed_offset + self.step_seeding = step_seeding + + self._step = 0 + + if not sizes: + assert target_size is not None + assert scale_range is not None + assert num_scales is not None + factors = np.linspace( + start=scale_range[0], + stop=scale_range[1], + num=num_scales, + ) + self.heights = (factors * target_size[0]).astype(np.int64) + self.widths = (factors * target_size[1]).astype(np.int64) + else: + self.heights = np.array([s[0] for s in sizes], dtype=np.int64) + self.widths = np.array([s[1] for s in sizes], dtype=np.int64) + + if divisible_by is not None: + self.heights = ( + np.round(self.heights / divisible_by) * divisible_by + ).astype(np.int64) + self.widths = (np.round(self.widths / divisible_by) * divisible_by).astype( + np.int64 + ) + + self.transforms = [ + Resize(height=h, width=w) for h, w in zip(self.heights, self.widths) + ] + + @property + def step(self) -> int: + return self._step + + @step.setter + def step(self, step: int) -> None: + self._step = step + + def get_params(self) -> dict[str, Any]: + if self.step_seeding: + rng = np.random.default_rng(self.step + self.seed_offset) + return {"idx": rng.integers(0, len(self.transforms))} + else: + return {"idx": np.random.randint(0, len(self.transforms))} + + def apply( + self, img: NDArray[np.int64], idx: int, **params: Any + ) -> NDArray[np.int64]: + return self.transforms[idx].apply(img=img, **params) # type: ignore[no-any-return] + + def apply_to_bboxes( + self, bboxes: NDArray[np.float64], idx: int, **params: Any + ) -> NDArray[np.float64]: + return self.transforms[idx].apply_to_bboxes(bboxes, **params) # type: ignore[no-any-return] + + def apply_to_mask( + self, mask: NDArray[np.int64], idx: int, **params: Any + ) -> NDArray[np.int64]: + return self.transforms[idx].apply_to_mask(mask, **params) # type: ignore[no-any-return] diff --git a/src/lightly_train/_transforms/semantic_segmentation_transform.py b/src/lightly_train/_transforms/semantic_segmentation_transform.py index 810c64632..574ef72f4 100644 --- a/src/lightly_train/_transforms/semantic_segmentation_transform.py +++ b/src/lightly_train/_transforms/semantic_segmentation_transform.py @@ -22,6 +22,7 @@ RandomCrop, Resize, SmallestMaxSize, + VerticalFlip, ) from albumentations.pytorch import ToTensorV2 from torch import Tensor @@ -67,6 +68,7 @@ class 
SemanticSegmentationTransformArgs(TaskTransformArgs): normalize: NormalizeArgs random_flip: RandomFlipArgs | None color_jitter: ColorJitterArgs | None + # TODO: Lionel(09/25): These are currently not fully used. scale_jitter: ScaleJitterArgs | None smallest_max_size: SmallestMaxSizeArgs | None random_crop: RandomCropArgs | None @@ -123,8 +125,11 @@ class SemanticSegmentationTransform(TaskTransform): SemanticSegmentationTransformArgs ) - def __init__(self, transform_args: SemanticSegmentationTransformArgs) -> None: - super().__init__(transform_args) + def __init__( + self, + transform_args: SemanticSegmentationTransformArgs, + ) -> None: + super().__init__(transform_args=transform_args) # Initialize the list of transforms to apply. transform: list[BasicTransform] = [] @@ -138,8 +143,13 @@ def __init__(self, transform_args: SemanticSegmentationTransformArgs) -> None: ] if transform_args.scale_jitter is not None: + # TODO (Lionel, 09/25): Use our custom ScaleJitter transform. + # This follows recommendation on how to replace torchvision ScaleJitter with # albumentations: https://albumentations.ai/docs/torchvision-kornia2albumentations/ + assert transform_args.scale_jitter.min_scale is not None + assert transform_args.scale_jitter.max_scale is not None + assert transform_args.scale_jitter.num_scales is not None scales = np.linspace( start=transform_args.scale_jitter.min_scale, stop=transform_args.scale_jitter.max_scale, @@ -185,7 +195,12 @@ def __init__(self, transform_args: SemanticSegmentationTransformArgs) -> None: # Optionally apply random horizontal flip. if transform_args.random_flip is not None: - transform += [HorizontalFlip(p=transform_args.random_flip.horizontal_prob)] + if transform_args.random_flip.horizontal_prob > 0.0: + transform += [ + HorizontalFlip(p=transform_args.random_flip.horizontal_prob) + ] + if transform_args.random_flip.vertical_prob > 0.0: + transform += [VerticalFlip(p=transform_args.random_flip.vertical_prob)] # Optionally apply color jitter. 
if transform_args.color_jitter is not None: diff --git a/src/lightly_train/_transforms/task_transform.py b/src/lightly_train/_transforms/task_transform.py index 43cf49310..6dc881e1d 100644 --- a/src/lightly_train/_transforms/task_transform.py +++ b/src/lightly_train/_transforms/task_transform.py @@ -37,7 +37,10 @@ def resolve_incompatible(self) -> None: class TaskTransform: transform_args_cls: type[TaskTransformArgs] - def __init__(self, transform_args: TaskTransformArgs): + def __init__( + self, + transform_args: TaskTransformArgs, + ) -> None: if not isinstance(transform_args, self.transform_args_cls): raise TypeError( f"transform_args must be of type {self.transform_args_cls.__name__}, " diff --git a/src/lightly_train/_transforms/transform.py b/src/lightly_train/_transforms/transform.py index c81c5cc64..6438eaeff 100644 --- a/src/lightly_train/_transforms/transform.py +++ b/src/lightly_train/_transforms/transform.py @@ -11,13 +11,15 @@ from collections.abc import Sequence from typing import ( Literal, + Set, Type, TypeVar, ) import pydantic +from albumentations import BasicTransform from lightly.transforms.utils import IMAGENET_NORMALIZE -from pydantic import Field +from pydantic import ConfigDict, Field from lightly_train._configs.config import PydanticConfig from lightly_train._configs.validate import no_auto @@ -57,10 +59,10 @@ class RandomFlipArgs(PydanticConfig): class RandomPhotometricDistortArgs(PydanticConfig): - brightness: tuple[float, float] = Field(strict=False, ge=0.0) - contrast: tuple[float, float] = Field(strict=False, ge=0.0) - saturation: tuple[float, float] = Field(strict=False, ge=0.0) - hue: tuple[float, float] = Field(strict=False, ge=-0.5, le=0.5) + brightness: tuple[float, float] = Field(strict=False) + contrast: tuple[float, float] = Field(strict=False) + saturation: tuple[float, float] = Field(strict=False) + hue: tuple[float, float] = Field(strict=False) prob: float = Field(ge=0.0, le=1.0) @@ -72,7 +74,7 @@ class RandomRotationArgs(PydanticConfig): class RandomZoomOutArgs(PydanticConfig): prob: float = Field(ge=0.0, le=1.0) fill: float - side_range: tuple[float, float] = Field(strict=False, ge=1.0) + side_range: tuple[float, float] = Field(strict=False) class ColorJitterArgs(PydanticConfig): @@ -147,10 +149,19 @@ def from_dict(cls, config: dict[str, list[float]]) -> NormalizeArgs: class ScaleJitterArgs(PydanticConfig): - min_scale: float - max_scale: float - num_scales: int - prob: float + sizes: Sequence[tuple[int, int]] | None + min_scale: float | None + max_scale: float | None + num_scales: int | None + prob: float = Field(ge=0.0, le=1.0) + divisible_by: int | None + + +class StopPolicyArgs(PydanticConfig): + stop_step: int + ops: Set[type[BasicTransform]] + + model_config = ConfigDict(arbitrary_types_allowed=True) class SmallestMaxSizeArgs(PydanticConfig): diff --git a/src/lightly_train/types.py b/src/lightly_train/types.py index 4580103a9..ad64a16dd 100644 --- a/src/lightly_train/types.py +++ b/src/lightly_train/types.py @@ -103,6 +103,13 @@ class ObjectDetectionDatasetItem(TypedDict): classes: Tensor # Of shape (n_boxes,) with class labels. +class ObjectDetectionBatch(TypedDict): + image_path: list[ImageFilename] # length==batch_size + image: Tensor # Tensor with shape (batch_size, 3, H, W). + bboxes: list[Tensor] # One tensor per image, each of shape (n_boxes, 4). + classes: list[Tensor] # One tensor per image, each of shape (n_boxes,). + + # Replaces torch.optim.optimizer.ParamsT # as it is only available in torch>=v2.2. 
# Importing it conditionally cannot make typing work for both older diff --git a/tests/_data/test_yolo_object_detection_dataset.py b/tests/_data/test_yolo_object_detection_dataset.py index 149d0fd1f..eaace8704 100644 --- a/tests/_data/test_yolo_object_detection_dataset.py +++ b/tests/_data/test_yolo_object_detection_dataset.py @@ -8,9 +8,10 @@ from __future__ import annotations from pathlib import Path +from typing import Literal -from albumentations import BboxParams, Compose, Normalize, Resize -from albumentations.pytorch.transforms import ToTensorV2 +import torch +from albumentations import BboxParams from lightly_train._data.yolo_object_detection_dataset import ( YOLOObjectDetectionDataArgs, @@ -19,45 +20,28 @@ from lightly_train._transforms.object_detection_transform import ( ObjectDetectionTransform, ObjectDetectionTransformArgs, - ObjectDetectionTransformInput, - ObjectDetectionTransformOutput, ) from lightly_train._transforms.transform import ( + ChannelDropArgs, RandomFlipArgs, RandomPhotometricDistortArgs, RandomZoomOutArgs, - ResizeArgs, + ScaleJitterArgs, + StopPolicyArgs, ) from ..helpers import create_yolo_dataset -class DummyTransform(ObjectDetectionTransform): - transform_args_cls = ObjectDetectionTransformArgs - - def __init__(self, transform_args: ObjectDetectionTransformArgs): - super().__init__(transform_args=transform_args) - self.transform = Compose( - [ - Resize(32, 32), - Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)), - ToTensorV2(), - ], - bbox_params=transform_args.bbox_params, - ) - - def __call__( - self, input: ObjectDetectionTransformInput - ) -> ObjectDetectionTransformOutput: - output: ObjectDetectionTransformOutput = self.transform(**input) - return output - - class DummyTransformArgs(ObjectDetectionTransformArgs): + channel_drop: ChannelDropArgs | None = None + num_channels: int | Literal["auto"] = "auto" photometric_distort: RandomPhotometricDistortArgs | None = None random_zoom_out: RandomZoomOutArgs | None = None random_flip: RandomFlipArgs | None = None - resize: ResizeArgs | None = None + image_size: tuple[int, int] = (32, 32) + stop_policy: StopPolicyArgs | None = None + scale_jitter: ScaleJitterArgs | None = None bbox_params: BboxParams = BboxParams( format="yolo", label_fields=["class_labels"], @@ -80,23 +64,23 @@ def test__split_first(self, tmp_path: Path) -> None: train_dataset = YOLOObjectDetectionDataset( dataset_args=train_args, - transform=DummyTransform(DummyTransformArgs()), + transform=ObjectDetectionTransform(DummyTransformArgs()), image_filenames=["0.png", "1.png"], ) val_dataset = YOLOObjectDetectionDataset( dataset_args=val_args, - transform=DummyTransform(DummyTransformArgs()), + transform=ObjectDetectionTransform(DummyTransformArgs()), image_filenames=["0.png", "1.png"], ) sample = train_dataset[0] - assert sample["image"].shape == (3, 32, 32) + assert sample["image"].dtype == torch.float32 assert sample["bboxes"].shape == (1, 4) assert sample["classes"].shape == (1,) sample = val_dataset[0] - assert sample["image"].shape == (3, 32, 32) + assert sample["image"].dtype == torch.float32 assert sample["bboxes"].shape == (1, 4) assert sample["classes"].shape == (1,) @@ -115,21 +99,21 @@ def test__split_last(self, tmp_path: Path) -> None: train_dataset = YOLOObjectDetectionDataset( dataset_args=train_args, - transform=DummyTransform(DummyTransformArgs()), + transform=ObjectDetectionTransform(DummyTransformArgs()), image_filenames=["0.png", "1.png"], ) val_dataset = YOLOObjectDetectionDataset( dataset_args=val_args, - 
transform=DummyTransform(DummyTransformArgs()), + transform=ObjectDetectionTransform(DummyTransformArgs()), image_filenames=["0.png", "1.png"], ) sample = train_dataset[0] - assert sample["image"].shape == (3, 32, 32) + assert sample["image"].dtype == torch.float32 assert sample["bboxes"].shape == (1, 4) assert sample["classes"].shape == (1,) sample = val_dataset[0] - assert sample["image"].shape == (3, 32, 32) + assert sample["image"].dtype == torch.float32 assert sample["bboxes"].shape == (1, 4) diff --git a/tests/_transforms/test_object_detection_transform.py b/tests/_transforms/test_object_detection_transform.py new file mode 100644 index 000000000..f06a1a750 --- /dev/null +++ b/tests/_transforms/test_object_detection_transform.py @@ -0,0 +1,193 @@ +# +# Copyright (c) Lightly AG and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +# + +from __future__ import annotations + +import itertools + +import numpy as np +import pytest +import torch +from albumentations import BboxParams +from numpy.typing import NDArray + +from lightly_train._data.task_batch_collation import ObjectDetectionCollateFunction +from lightly_train._transforms.channel_drop import ChannelDrop +from lightly_train._transforms.object_detection_transform import ( + ObjectDetectionTransform, + ObjectDetectionTransformArgs, + ObjectDetectionTransformInput, +) +from lightly_train._transforms.transform import ( + ChannelDropArgs, + RandomFlipArgs, + RandomPhotometricDistortArgs, + RandomZoomOutArgs, + ScaleJitterArgs, + StopPolicyArgs, +) +from lightly_train.types import ObjectDetectionDatasetItem + + +def _get_channel_drop_args() -> ChannelDropArgs: + return ChannelDropArgs( + num_channels_keep=3, + weight_drop=(1.0, 1.0, 0.0, 0.0), + ) + + +def _get_random_flip_args() -> RandomFlipArgs: + return RandomFlipArgs(horizontal_prob=0.5, vertical_prob=0.5) + + +def _get_photometric_distort_args() -> RandomPhotometricDistortArgs: + return RandomPhotometricDistortArgs( + brightness=(0.8, 1.2), + contrast=(0.8, 1.2), + saturation=(0.8, 1.2), + hue=(-0.1, 0.1), + prob=0.5, + ) + + +def _get_random_zoom_out_args() -> RandomZoomOutArgs: + return RandomZoomOutArgs( + prob=0.5, + fill=0.0, + side_range=(1.0, 1.5), + ) + + +def _get_bbox_params() -> BboxParams: + return BboxParams( + format="pascal_voc", + label_fields=["class_labels"], + min_area=0, + min_visibility=0.0, + ) + + +def _get_stop_policy_args() -> StopPolicyArgs: + return StopPolicyArgs( + stop_step=500_000, + ops={ChannelDrop}, + ) + + +def _get_scale_jitter_args() -> ScaleJitterArgs: + return ScaleJitterArgs( + sizes=None, + min_scale=0.76, + max_scale=1.27, + num_scales=13, + prob=1.0, + divisible_by=14, + ) + + +def _get_image_size() -> tuple[int, int]: + return (64, 64) + + +PossibleArgsTuple = ( + [None, _get_channel_drop_args()], + [None, _get_photometric_distort_args()], + [None, _get_random_zoom_out_args()], + [None, _get_random_flip_args()], + # TODO: Lionel (09/25) Add StopPolicyArgs test cases. 
+ [None, _get_scale_jitter_args()], +) + +possible_tuples = list(itertools.product(*PossibleArgsTuple)) + + +class TestObjectDetectionTransform: + @pytest.mark.parametrize( + "channel_drop, photometric_distort, random_zoom_out, random_flip, scale_jitter", + possible_tuples, + ) + def test___all_args_combinations( + self, + channel_drop: ChannelDropArgs | None, + photometric_distort: RandomPhotometricDistortArgs | None, + random_zoom_out: RandomZoomOutArgs | None, + random_flip: RandomFlipArgs | None, + scale_jitter: ScaleJitterArgs | None, + ) -> None: + image_size = _get_image_size() + bbox_params = _get_bbox_params() + stop_policy = None # TODO: Lionel (09/25) Pass as function argument. + transform_args = ObjectDetectionTransformArgs( + channel_drop=channel_drop, + num_channels="auto", + photometric_distort=photometric_distort, + random_zoom_out=random_zoom_out, + random_flip=random_flip, + image_size=image_size, + bbox_params=bbox_params, + stop_policy=stop_policy, + scale_jitter=scale_jitter, + ) + transform_args.resolve_auto() + transform = ObjectDetectionTransform(transform_args) + + # Create a synthetic image and bounding boxes. + num_channels = transform_args.num_channels + assert num_channels != "auto" + img: NDArray[np.uint8] = np.random.randint( + 0, 256, (128, 128, num_channels), dtype=np.uint8 + ) + bboxes = np.array([[10, 10, 50, 50]], dtype=np.float64) + class_labels = np.array([1], dtype=np.int64) + + tr_input: ObjectDetectionTransformInput = { + "image": img, + "bboxes": bboxes, + "class_labels": class_labels, + } + tr_output = transform(tr_input) + assert isinstance(tr_output, dict) + out_img = tr_output["image"] + assert isinstance(out_img, torch.Tensor) + assert out_img.dtype == torch.float32 + assert "bboxes" in tr_output + assert "class_labels" in tr_output + + def test__collation(self) -> None: + transform_args = ObjectDetectionTransformArgs( + channel_drop=_get_channel_drop_args(), + num_channels="auto", + photometric_distort=_get_photometric_distort_args(), + random_zoom_out=_get_random_zoom_out_args(), + random_flip=_get_random_flip_args(), + image_size=_get_image_size(), + bbox_params=_get_bbox_params(), + stop_policy=_get_stop_policy_args(), + scale_jitter=_get_scale_jitter_args(), + ) + transform_args.resolve_auto() + collate_fn = ObjectDetectionCollateFunction( + split="train", transform_args=transform_args + ) + + sample1: ObjectDetectionDatasetItem = { + "image_path": "img1.png", + "image": torch.randn(3, 64, 64), + "bboxes": torch.tensor([[10.0, 10.0, 50.0, 50.0]]), + "classes": torch.tensor([1]), + } + sample2: ObjectDetectionDatasetItem = { + "image_path": "img2.png", + "image": torch.randn(3, 64, 64), + "bboxes": torch.tensor([[20.0, 20.0, 40.0, 40.0]]), + "classes": torch.tensor([2]), + } + batch = [sample1, sample2] + + out = collate_fn(batch) + assert isinstance(out, dict) diff --git a/tests/_transforms/test_scale_jitter.py b/tests/_transforms/test_scale_jitter.py new file mode 100644 index 000000000..552b84e92 --- /dev/null +++ b/tests/_transforms/test_scale_jitter.py @@ -0,0 +1,244 @@ +# +# Copyright (c) Lightly AG and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. 
+# +from __future__ import annotations + +import numpy as np +from albumentations import BboxParams, Compose + +from lightly_train._transforms.scale_jitter import ScaleJitter + + +class TestRandomScaleJitter: + def test__call__check_return_shapes_larger(self) -> None: + img_size = (16, 16) + img = np.random.randint(0, 255, size=(*img_size, 3), dtype=np.uint8) + mask = np.random.randint(0, 255, size=img_size, dtype=np.uint8) + bboxes = np.array([[4, 4, 12, 12]], dtype=np.float32) + classes = np.array([1], dtype=np.int32) + + transform = ScaleJitter( + sizes=None, + target_size=img_size, + scale_range=(2.0, 4.0), + num_scales=3, + p=1.0, + ) + bbox_params = BboxParams(format="pascal_voc", label_fields=["class_labels"]) + out = transform( + image=img, + mask=mask, + bboxes=bboxes, + bbox_params=bbox_params, + class_labels=classes, + ) + + assert all(out["image"].shape[i] > img.shape[i] for i in (0, 1)) + assert all(out["image"].shape[i] > mask.shape[i] for i in (0, 1)) + assert np.array(out["class_labels"]).shape == classes.shape + assert np.array(out["bboxes"]).shape == bboxes.shape + + def test__call__check_return_shapes_smaller(self) -> None: + img_size = (16, 16) + img = np.random.randint(0, 255, size=(*img_size, 3), dtype=np.uint8) + mask = np.random.randint(0, 255, size=img_size, dtype=np.uint8) + bboxes = np.array([[4, 4, 12, 12]], dtype=np.float32) + classes = np.array([1], dtype=np.int32) + + transform = ScaleJitter( + sizes=None, + target_size=img_size, + scale_range=(0.2, 0.7), + num_scales=3, + p=1.0, + ) + bbox_params = BboxParams(format="pascal_voc", label_fields=["class_labels"]) + out = transform( + image=img, + mask=mask, + bboxes=bboxes, + bbox_params=bbox_params, + class_labels=classes, + ) + + assert all(out["image"].shape[i] < img.shape[i] for i in (0, 1)) + assert all(out["image"].shape[i] < mask.shape[i] for i in (0, 1)) + assert np.array(out["class_labels"]).shape == classes.shape + assert np.array(out["bboxes"]).shape == bboxes.shape + + def test__call__check_return_shapes_in_sizes(self) -> None: + img_size = (16, 16) + img = np.random.randint(0, 255, size=(*img_size, 3), dtype=np.uint8) + mask = np.random.randint(0, 255, size=img_size, dtype=np.uint8) + bboxes = np.array([[4, 4, 12, 12]], dtype=np.float32) + classes = np.array([1], dtype=np.int32) + + sizes = [(8, 8), (12, 12), (20, 20)] + transform = ScaleJitter( + sizes=sizes, + target_size=img_size, + scale_range=(0.5, 2.0), + num_scales=3, + p=1.0, + ) + bbox_params = BboxParams(format="pascal_voc", label_fields=["class_labels"]) + out = transform( + image=img, + mask=mask, + bboxes=bboxes, + bbox_params=bbox_params, + class_labels=classes, + ) + + assert out["image"].shape in [(s[0], s[1], 3) for s in sizes] + assert out["mask"].shape in [s for s in sizes] + assert np.array(out["class_labels"]).shape == classes.shape + assert np.array(out["bboxes"]).shape == bboxes.shape + + def test__call__no_transform_when_p0(self) -> None: + img_size = (8, 8) + img = np.random.randint(0, 255, size=(*img_size, 3), dtype=np.uint8) + mask = np.random.randint(0, 255, size=img_size, dtype=np.uint8) + bboxes = np.array([[1, 1, 2, 2]], dtype=np.float32) + classes = np.array([1], dtype=np.int32) + + transform = ScaleJitter( + sizes=None, + target_size=img_size, + scale_range=(1.0, 2.0), + num_scales=3, + p=0.0, + ) + bbox_params = BboxParams(format="pascal_voc", label_fields=["class_labels"]) + out = transform( + image=img, + mask=mask, + bboxes=bboxes, + bbox_params=bbox_params, + class_labels=classes, + ) + + assert 
np.array_equal(out["image"], img) + assert np.array_equal(out["mask"], mask) + assert np.array_equal(out["bboxes"], bboxes) + assert np.array_equal(out["class_labels"], classes) + + def test__call__always_transform_when_p1(self) -> None: + img_size = (16, 16) + img = np.random.randint(0, 255, size=(*img_size, 3), dtype=np.uint8) + mask = np.random.randint(0, 255, size=(16, 16), dtype=np.uint8) + bboxes = np.array([[1, 1, 2, 2]], dtype=np.float32) + classes = np.array([1], dtype=np.int32) + bbox_params = BboxParams(format="pascal_voc", label_fields=["class_labels"]) + + transform = Compose( + [ + ScaleJitter( + sizes=None, + target_size=img_size, + scale_range=(2.0, 4.0), + num_scales=2, + p=1.0, + ) + ], + bbox_params=bbox_params, + ) + out = transform( + image=img, + mask=mask, + bboxes=bboxes, + class_labels=classes, + ) + assert out["image"].shape != img.shape + assert out["mask"].shape != mask.shape + assert np.array_equal(out["class_labels"], classes) + # With scale >=2.0 the bbox has to change in Pascal VOC format. + assert not np.array_equal(out["bboxes"], bboxes) + + def test__step_seeding__deterministic(self) -> None: + img_size = (8, 8) + img = np.random.randint(0, 255, size=(*img_size, 3), dtype=np.uint8) + mask = np.random.randint(0, 255, size=img_size, dtype=np.uint8) + bboxes = np.array([[1, 1, 2, 2]], dtype=np.float32) + classes = np.array([1], dtype=np.int32) + bbox_params = BboxParams(format="pascal_voc", label_fields=["class_labels"]) + + transform = Compose( + [ + ScaleJitter( + sizes=None, + target_size=img_size, + scale_range=(1.0, 10.0), + num_scales=10, + p=1.0, + step_seeding=True, + seed_offset=42, + ) + ], + bbox_params=bbox_params, + ) + # Set step and get deterministic idx + transform.transforms[0].step = 5 + out1 = transform( + image=img, + mask=mask, + bboxes=bboxes, + class_labels=classes, + ) + out2 = transform( + image=img, + mask=mask, + bboxes=bboxes, + class_labels=classes, + ) + assert np.array_equal(out1["image"], out2["image"]) + assert np.array_equal(out1["mask"], out2["mask"]) + assert np.array_equal(out1["bboxes"], out2["bboxes"]) + assert np.array_equal(out1["class_labels"], out2["class_labels"]) + + def test__step_seeding__different_steps(self) -> None: + img_size = (8, 8) + img = np.random.randint(0, 255, size=(*img_size, 3), dtype=np.uint8) + mask = np.random.randint(0, 255, size=img_size, dtype=np.uint8) + bboxes = np.array([[1, 1, 2, 2]], dtype=np.float64) + classes = np.array([1], dtype=np.int64) + bbox_params = BboxParams(format="pascal_voc", label_fields=["class_labels"]) + + transform = Compose( + [ + ScaleJitter( + sizes=None, + target_size=img_size, + scale_range=(1.0, 10.0), + num_scales=10, + p=1.0, + step_seeding=True, + seed_offset=42, + ) + ], + bbox_params=bbox_params, + ) + # Set step and get deterministic idx for first transform + transform.transforms[0].step = 5 + out1 = transform( + image=img, + mask=mask, + bboxes=bboxes, + class_labels=classes, + ) + # Change step and get deterministic idx for second transform + transform.transforms[0].step = 6 + out2 = transform( + image=img, + mask=mask, + bboxes=bboxes, + class_labels=classes, + ) + assert not np.array_equal(out1["image"], out2["image"]) + assert not np.array_equal(out1["mask"], out2["mask"]) + assert not np.array_equal(out1["bboxes"], out2["bboxes"]) + assert np.array_equal(out1["class_labels"], out2["class_labels"])
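
Usage sketch (reviewer note, not part of the patch): the refactor above moves collation out of `train_task_helpers.collate_fn` and onto the dataset classes via `TaskDataset.batch_collate_fn_cls`, so the dataloader helpers now build the collate function from the dataset class plus the transform args. The snippet below only illustrates that wiring under stated assumptions: `my_dataset` stands for any already-constructed `TaskDataset` subclass (e.g. `YOLOObjectDetectionDataset` or `MaskSemanticSegmentationDataset`), `transform_args` for the matching `TaskTransformArgs`, and the literal batch/worker values are placeholders rather than the values used in `get_train_dataloader`.

from torch.utils.data import DataLoader

# The collate function is instantiated from the class attribute declared on the
# dataset (BaseCollateFunction subclass). For ObjectDetectionCollateFunction the
# transform_args must be ObjectDetectionTransformArgs, as asserted in its __init__;
# for MaskSemanticSegmentationCollateFunction, split="train" stacks images/masks
# into single tensors while split="val" keeps per-image lists.
collate_fn = my_dataset.batch_collate_fn_cls(
    split="train",
    transform_args=transform_args,
)

# Plain DataLoader wiring; in train_task_helpers the batch size is additionally
# divided by fabric.world_size and a worker timeout is set.
train_dataloader = DataLoader(
    my_dataset,
    batch_size=16,
    shuffle=True,
    num_workers=4,
    drop_last=True,
    collate_fn=collate_fn,
)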