Merged
99 commits
75822ec
wip commit
liopeer Aug 7, 2025
78153fd
wip commit
liopeer Aug 8, 2025
d2e86b6
wrap up dataset
liopeer Aug 8, 2025
14d3c2d
address review
liopeer Aug 8, 2025
631e420
add TaskTransforms to pydantic classes
liopeer Aug 13, 2025
bb29ee2
Merge branch 'lionel-trn-1394-save-image-size-and-norm-info-in-model'…
liopeer Aug 13, 2025
06a9eb9
add classvars to trainmodel
liopeer Aug 14, 2025
8cab886
Merge branch 'main' into lionel-trn-1394-save-image-size-and-norm-inf…
liopeer Aug 14, 2025
ae5859b
small correction
liopeer Aug 14, 2025
6bbcbb4
fix model evaluation
liopeer Aug 14, 2025
6b80a9d
make transform args available in transform interface
liopeer Aug 14, 2025
7aef0c4
Merge branch 'lionel-trn-1394-save-image-size-and-norm-info-in-model'…
liopeer Aug 14, 2025
702aad9
wip commit
liopeer Aug 14, 2025
96814bf
address review
liopeer Aug 14, 2025
608d685
Merge branch 'lionel-trn-1394-save-image-size-and-norm-info-in-model'…
liopeer Aug 14, 2025
9d7bd67
wip commit
liopeer Aug 14, 2025
be6e7d9
add unit tests
liopeer Aug 14, 2025
e903042
Merge branch 'main' into lionel-trn-758-add-object-detection-dataset
liopeer Aug 18, 2025
743c131
fix merge
liopeer Aug 18, 2025
ae0a286
remove unnecessary file
liopeer Aug 18, 2025
4b22be6
remove classification files
liopeer Aug 18, 2025
d36fec5
remove is_supported_model from TrainModel
liopeer Aug 18, 2025
d543562
remove is_supported_model from TrainModel
liopeer Aug 18, 2025
d6e3f8b
Merge branch 'lionel-trn-758-add-object-detection-dataset' of github.…
liopeer Aug 18, 2025
15141dc
fix typing
liopeer Aug 18, 2025
df817b1
fix unit tests
liopeer Aug 18, 2025
803ba59
Merge branch 'main' into lionel-trn-758-add-object-detection-dataset
liopeer Aug 22, 2025
e23b305
handle empty lines in yolo labels
liopeer Aug 22, 2025
37863dd
explicit dtypes for arrays
liopeer Aug 22, 2025
102d508
move logic to dataargs
liopeer Aug 22, 2025
aee9ac0
pydantic convention adherence
liopeer Aug 22, 2025
ffc33ea
add obj detection skeleton
liopeer Aug 22, 2025
1e76743
fix failing unit tests
liopeer Aug 25, 2025
42f0c92
fix formatting
liopeer Aug 25, 2025
e6ef891
address review
liopeer Aug 27, 2025
c654d0d
fix small error
liopeer Aug 27, 2025
1fb79e1
replace only first "images" occurrence
liopeer Aug 27, 2025
c1a60ea
Merge branch 'main' into lionel-trn-758-add-object-detection-dataset
liopeer Aug 27, 2025
438f935
small error
liopeer Aug 27, 2025
f29a377
small fix
liopeer Aug 27, 2025
c51158b
from future import annotations
liopeer Aug 27, 2025
ac35b6c
use typing_extensions instead of typing
liopeer Aug 27, 2025
6e4b67b
fix unit tests
liopeer Aug 27, 2025
6a5254e
DINOv3EoMTSemanticSegmentation
liopeer Sep 12, 2025
e001c6f
fix merge
liopeer Sep 12, 2025
4c09c46
add photometric distortion transform
liopeer Sep 12, 2025
2a9a874
add random zoom out
liopeer Sep 12, 2025
c21ad54
fix typing in tests
liopeer Sep 12, 2025
c2c419a
fix typing
liopeer Sep 12, 2025
e05ed19
add future annotations
liopeer Sep 12, 2025
bedec74
add custom RandomOrder
liopeer Sep 15, 2025
c8997ac
fix py38 import of ToTensorv2
liopeer Sep 15, 2025
b4101b8
Merge branch 'main' into lionel-trn-759-implement-object-detection-tr…
liopeer Sep 15, 2025
7808cc0
less restrictive typing
liopeer Sep 15, 2025
35bbaf0
don't allow 0 probability in compositions due to incompatibility with…
liopeer Sep 15, 2025
8ece5c3
handle different albumentations versions
liopeer Sep 15, 2025
2ce0a6e
remove p0 test since not compatible with older albumentations versions
liopeer Sep 15, 2025
f1cd827
add tests for RandomOrder transform
liopeer Sep 15, 2025
8ec3261
Merge branch 'lionel-trn-759-implement-object-detection-transforms' o…
liopeer Sep 15, 2025
d56ca2d
delete empty file
liopeer Sep 15, 2025
65aa004
fix formatting
liopeer Sep 15, 2025
e490818
fix tests failing on bbox tuples
liopeer Sep 16, 2025
f0b53fd
address issues with randomorder
liopeer Sep 16, 2025
fad4b8c
make randomzoomout work with all albumentation versions
liopeer Sep 16, 2025
75f4d8a
fix typing
liopeer Sep 16, 2025
39f2f16
implement obj det transform
liopeer Sep 16, 2025
2826cee
Merge branch 'main' into lionel-trn-1520-implement-object-detection-t…
liopeer Sep 16, 2025
0829369
merge
liopeer Sep 16, 2025
902bcd5
fix merge
liopeer Sep 16, 2025
3f2c72b
use step instead of epoch
liopeer Sep 16, 2025
43d3559
readd TODO
liopeer Sep 16, 2025
c183fe8
fix import
liopeer Sep 16, 2025
a033f25
fix unit tests
liopeer Sep 16, 2025
1c09bb1
fix typing
liopeer Sep 18, 2025
82b0547
simplify tests
liopeer Sep 18, 2025
d46a963
wip commit
liopeer Sep 22, 2025
7d8cebe
use scalejitter in collation instead
liopeer Sep 25, 2025
abc7ba8
add comment about step increase
liopeer Sep 25, 2025
c7e5225
use sharedstep in transform
liopeer Sep 25, 2025
7bf7c30
remove unused list of collate functions
liopeer Sep 25, 2025
f7f1c5d
add comment
liopeer Sep 25, 2025
49d8c68
formatting
liopeer Sep 25, 2025
56fd793
Merge branch 'main' into lionel-trn-1520-implement-object-detection-t…
liopeer Sep 25, 2025
ba99a38
fix import
liopeer Sep 25, 2025
23bc8b7
remove markdown formatting
liopeer Sep 25, 2025
06b5fa4
fix typing in 3.8
liopeer Sep 25, 2025
04e5751
change to mp.Manager
liopeer Sep 28, 2025
4ca6116
remove shared step
liopeer Sep 29, 2025
5f71376
remove stoppolicy test cases
liopeer Sep 29, 2025
dd714cf
small refactors
liopeer Sep 29, 2025
78d4a08
revert formatting
liopeer Sep 29, 2025
ef2706d
use fixed scales for obj det transform
liopeer Sep 29, 2025
2bf7fd9
Merge branch 'main' into lionel-trn-1520-implement-object-detection-t…
liopeer Sep 29, 2025
d046fc0
add todo
liopeer Sep 29, 2025
3c2918b
Merge branch 'main' into lionel-trn-1520-implement-object-detection-t…
liopeer Sep 29, 2025
0b9ba5a
remove step seeding
liopeer Sep 29, 2025
8aef613
Merge branch 'lionel-trn-1520-implement-object-detection-transforms-p…
liopeer Sep 29, 2025
883d9d5
remove stop_policy from LT-DETR
liopeer Sep 29, 2025
bf5c3d5
Merge branch 'main' into lionel-trn-1520-implement-object-detection-t…
liopeer Sep 29, 2025
18 changes: 16 additions & 2 deletions src/lightly_train/_commands/train_task.py
@@ -27,6 +27,7 @@
from lightly_train._data.mask_semantic_segmentation_dataset import (
MaskSemanticSegmentationDataArgs,
)
from lightly_train._data.task_dataset import TaskDataset
from lightly_train._loggers.task_logger_args import TaskLoggerArgs
from lightly_train._task_checkpoint import TaskSaveCheckpointArgs
from lightly_train._task_models.train_model import TrainModelArgs
@@ -247,13 +248,13 @@ def _train_task_from_config(config: TrainTaskConfig) -> None:
) as train_mmap_filepath, helpers.get_dataset_temp_mmap_path(
fabric=fabric, data=config.data.val.images
) as val_mmap_filepath:
train_dataset = helpers.get_dataset(
train_dataset: TaskDataset = helpers.get_dataset(
fabric=fabric,
dataset_args=config.data.get_train_args(),
transform=train_transform,
mmap_filepath=train_mmap_filepath,
)
val_dataset = helpers.get_dataset(
val_dataset: TaskDataset = helpers.get_dataset(
fabric=fabric,
dataset_args=config.data.get_val_args(),
transform=val_transform,
@@ -294,6 +295,7 @@ def _train_task_from_config(config: TrainTaskConfig) -> None:
train_dataloader = helpers.get_train_dataloader(
fabric=fabric,
dataset=train_dataset,
transform_args=train_transform_args,
batch_size=config.batch_size,
num_workers=config.num_workers,
loader_args=config.loader_args,
@@ -302,6 +304,7 @@ def _train_task_from_config(config: TrainTaskConfig) -> None:
val_dataloader = helpers.get_val_dataloader(
fabric=fabric,
dataset=val_dataset,
transform_args=val_transform_args,
batch_size=config.batch_size,
num_workers=config.num_workers,
loader_args=config.loader_args,
@@ -353,6 +356,13 @@ def _train_task_from_config(config: TrainTaskConfig) -> None:
if config.resume_interrupted:
helpers.load_checkpoint(fabric=fabric, out_dir=out_dir, state=state)

# Set the global_step in the transform (has to be done after loading potential
# checkpoint).
assert isinstance(train_dataloader.dataset, TaskDataset)
assert isinstance(val_dataloader.dataset, TaskDataset)
train_dataloader.dataset.transform.global_step = state["step"]
val_dataloader.dataset.transform.global_step = state["step"]

# TODO(Guarin, 07/25): Replace with infinite batch sampler instead to avoid
# reloading dataloader after every epoch? Is this preferred over persistent workers?
infinite_train_dataloader = InfiniteCycleIterator(iterable=train_dataloader)
@@ -393,6 +403,10 @@ def _train_task_from_config(config: TrainTaskConfig) -> None:
optimizer.zero_grad()
scheduler.step()

# Update the global step in the transform.
train_dataloader.dataset.transform.global_step = state["step"]
val_dataloader.dataset.transform.global_step = state["step"]

if is_log_step or is_last_step:
train_log_dict = helpers.compute_metrics(train_result.log_dict)
helpers.log_step(
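Note on the `global_step` wiring above: the trainer pushes `state["step"]` into the dataset transforms twice, once after a potential checkpoint restore and again after every optimizer step, so step-dependent augmentations stay in sync with training progress. A minimal sketch of a transform that consumes such a `global_step` attribute (illustrative names only, not the actual lightly_train API):

```python
import numpy as np


class StepAwareTransform:
    def __init__(self) -> None:
        # The training loop assigns the current optimizer step here.
        self.global_step = 0

    def __call__(self, image: np.ndarray) -> np.ndarray:
        # Seed per-step randomness so behavior is reproducible per step.
        rng = np.random.default_rng(self.global_step)
        # Example: anneal the flip probability from 0.5 to 0 over 10k steps.
        p = 0.5 * max(0.0, 1.0 - self.global_step / 10_000)
        if rng.random() < p:
            image = np.flip(image, axis=1).copy()
        return image
```

One caveat: plain attribute assignment only reaches dataloader workers that share memory with the main process, which is presumably why the commit history experiments with a shared-memory step (`WorkerSharedStep`, `mp.Manager`) alongside this approach.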
46 changes: 15 additions & 31 deletions src/lightly_train/_commands/train_task_helpers.py
@@ -11,18 +11,16 @@
import hashlib
import json
import logging
from functools import partial
from json import JSONEncoder
from pathlib import Path
from typing import Any, Generator, Iterable, Literal, Mapping

import torch
from filelock import FileLock
from lightning_fabric import Fabric
from lightning_fabric import utilities as fabric_utilities
from lightning_fabric.loggers.logger import Logger as FabricLogger
from torch import Tensor
from torch.utils.data import DataLoader, Dataset
from torch.utils.data import DataLoader

from lightly_train._configs import validate
from lightly_train._data import cache
@@ -32,9 +30,9 @@
Primitive,
)
from lightly_train._data.mask_semantic_segmentation_dataset import (
MaskSemanticSegmentationDataset,
MaskSemanticSegmentationDatasetArgs,
)
from lightly_train._data.task_dataset import TaskDataset
from lightly_train._env import Env
from lightly_train._loggers.mlflow import MLFlowLogger
from lightly_train._loggers.task_logger_args import TaskLoggerArgs
@@ -62,8 +60,6 @@
TaskTransformArgs,
)
from lightly_train.types import (
MaskSemanticSegmentationBatch,
MaskSemanticSegmentationDatasetItem,
PathLike,
TaskDatasetItem,
)
@@ -414,7 +410,7 @@ def get_dataset(
dataset_args: MaskSemanticSegmentationDatasetArgs,
transform: TaskTransform,
mmap_filepath: Path,
) -> MaskSemanticSegmentationDataset:
) -> TaskDataset:
image_info = dataset_args.list_image_info()

dataset_cls = dataset_args.get_dataset_cls()
@@ -431,43 +427,27 @@
)


# TODO(Guarin, 08/25): Move this function to the _data module.
def collate_fn(
batch: list[MaskSemanticSegmentationDatasetItem], split: str
) -> MaskSemanticSegmentationBatch:
# Prepare the batch without any stacking.
images = [item["image"] for item in batch]
masks = [item["mask"] for item in batch]

out: MaskSemanticSegmentationBatch = {
"image_path": [item["image_path"] for item in batch],
# Stack images during training as they all have the same shape.
# During validation every image can have a different shape.
"image": torch.stack(images) if split == "train" else images,
"mask": torch.stack(masks) if split == "train" else masks,
"binary_masks": [item["binary_masks"] for item in batch],
}

return out


def get_train_dataloader(
fabric: Fabric,
dataset: Dataset[TaskDatasetItem],
dataset: TaskDataset,
transform_args: TaskTransformArgs,
batch_size: int,
num_workers: int,
loader_args: dict[str, Any] | None = None,
) -> DataLoader[TaskDatasetItem]:
timeout = Env.LIGHTLY_TRAIN_DATALOADER_TIMEOUT_SEC.value if num_workers > 0 else 0
# TODO(Guarin, 07/25): Persistent workers by default?
collate_fn = dataset.batch_collate_fn_cls(
split="train", transform_args=transform_args
)
dataloader_kwargs: dict[str, Any] = dict(
dataset=dataset,
batch_size=batch_size // fabric.world_size,
shuffle=True,
num_workers=num_workers,
drop_last=True,
timeout=timeout,
collate_fn=partial(collate_fn, split="train"),
collate_fn=collate_fn,
)
if loader_args is not None:
logger.debug(f"Using additional dataloader arguments {loader_args}.")
@@ -481,20 +461,24 @@ def get_train_dataloader(

def get_val_dataloader(
fabric: Fabric,
dataset: Dataset[TaskDatasetItem],
dataset: TaskDataset,
transform_args: TaskTransformArgs,
batch_size: int,
num_workers: int,
loader_args: dict[str, Any] | None = None,
) -> DataLoader[TaskDatasetItem]:
timeout = Env.LIGHTLY_TRAIN_DATALOADER_TIMEOUT_SEC.value if num_workers > 0 else 0
collate_fn = dataset.batch_collate_fn_cls(
split="val", transform_args=transform_args
)
dataloader_kwargs: dict[str, Any] = dict(
dataset=dataset,
batch_size=batch_size // fabric.world_size,
shuffle=False,
num_workers=num_workers,
drop_last=False,
timeout=timeout,
collate_fn=partial(collate_fn, split="validation"),
collate_fn=collate_fn,
)
if loader_args is not None:
logger.debug(f"Using additional dataloader arguments {loader_args}.")
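The collation refactor above replaces the free-standing `collate_fn` plus `functools.partial` with a collate class that each dataset names via `batch_collate_fn_cls`, instantiated with the split and the transform args. As a side effect, the `Literal["train", "val"]` split type rules out the free-form `"validation"` string the old code passed. A sketch of the resulting wiring, assuming the interfaces shown in this diff:

```python
from typing import Any

from torch.utils.data import DataLoader


def build_train_loader(
    dataset: Any, transform_args: Any, batch_size: int
) -> DataLoader:
    # The dataset class decides how its items are collated; the helper
    # only supplies the split and the transform arguments.
    collate_fn = dataset.batch_collate_fn_cls(
        split="train", transform_args=transform_args
    )
    return DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=True,
        drop_last=True,
        collate_fn=collate_fn,
    )
```

This keeps task-specific stacking logic next to the dataset instead of in the command helpers, so a new task (here: object detection) only has to ship its own collate class.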
27 changes: 27 additions & 0 deletions src/lightly_train/_data/dataloader_helpers.py
@@ -0,0 +1,27 @@
#
# Copyright (c) Lightly AG and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
#
from multiprocessing import Value


class WorkerSharedStep:
"""A class to share the current step between dataloader workers."""

def __init__(self, step: int) -> None:
self._step_value = Value("i", step)

@property
def step(self) -> int:
with self._step_value.get_lock():
val = self._step_value.value
assert isinstance(val, int)
return val

@step.setter
def step(self, step: int) -> None:
with self._step_value.get_lock():
self._step_value.value = step
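`WorkerSharedStep` wraps a `multiprocessing.Value` so the main process and forked dataloader workers observe a single step counter. A usage sketch, assuming the object is created before the workers start (so they inherit the shared value):

```python
from lightly_train._data.dataloader_helpers import WorkerSharedStep

shared = WorkerSharedStep(step=0)

# Main process: publish the new step after each optimizer update.
shared.step = 100

# Any dataloader worker: read the current step when transforming.
assert shared.step == 100
```

Note that the getter and setter each take the lock separately, so a read-modify-write such as `shared.step += 1` is not atomic across processes; the intended pattern is a single writer publishing the step. The later commits ("change to mp.Manager", "remove shared step") suggest this helper was ultimately superseded.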
14 changes: 11 additions & 3 deletions src/lightly_train/_data/mask_semantic_segmentation_dataset.py
@@ -15,12 +15,16 @@
import torch
from pydantic import AliasChoices, Field, TypeAdapter, field_validator
from torch import Tensor
from torch.utils.data import Dataset

from lightly_train._configs.config import PydanticConfig
from lightly_train._data import file_helpers
from lightly_train._data.file_helpers import ImageMode
from lightly_train._data.task_batch_collation import (
BaseCollateFunction,
MaskSemanticSegmentationCollateFunction,
)
from lightly_train._data.task_data_args import TaskDataArgs
from lightly_train._data.task_dataset import TaskDataset
from lightly_train._env import Env
from lightly_train._transforms.semantic_segmentation_transform import (
SemanticSegmentationTransform,
@@ -55,16 +59,20 @@ class MultiChannelClassInfo(PydanticConfig):
ClassInfo = Union[MultiChannelClassInfo, SingleChannelClassInfo]


class MaskSemanticSegmentationDataset(Dataset[MaskSemanticSegmentationDatasetItem]):
class MaskSemanticSegmentationDataset(TaskDataset):
batch_collate_fn_cls: ClassVar[type[BaseCollateFunction]] = (
MaskSemanticSegmentationCollateFunction
)

def __init__(
self,
dataset_args: MaskSemanticSegmentationDatasetArgs,
image_info: Sequence[dict[str, str]],
transform: SemanticSegmentationTransform,
):
super().__init__(transform=transform)
self.args = dataset_args
self.filepaths = image_info
self.transform = transform
self.ignore_index = dataset_args.ignore_index

# Get the class mapping.
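The dataset now derives from `TaskDataset` and declares its collate class as a `ClassVar`, which is what lets `get_dataset` return the generic `TaskDataset` type and lets `train_task.py` set `dataset.transform.global_step` without knowing the concrete task. The base class itself is not part of this diff; a hypothetical sketch of the contract it implies:

```python
# Hypothetical sketch of the TaskDataset contract implied by this diff;
# the real class lives in lightly_train/_data/task_dataset.py.
from typing import Any, ClassVar

from torch.utils.data import Dataset


class TaskDataset(Dataset):
    # Subclasses point this at the collate class that batches their items.
    batch_collate_fn_cls: ClassVar[type]

    def __init__(self, transform: Any) -> None:
        self.transform = transform
```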
134 changes: 134 additions & 0 deletions src/lightly_train/_data/task_batch_collation.py
@@ -0,0 +1,134 @@
#
# Copyright (c) Lightly AG and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
#
from __future__ import annotations

from typing import Literal

import numpy as np
import torch
from torch import Tensor

from lightly_train._transforms.object_detection_transform import (
ObjectDetectionTransformArgs,
)
from lightly_train._transforms.scale_jitter import ScaleJitter
from lightly_train._transforms.task_transform import TaskTransformArgs
from lightly_train.types import (
MaskSemanticSegmentationBatch,
MaskSemanticSegmentationDatasetItem,
ObjectDetectionBatch,
ObjectDetectionDatasetItem,
)


class BaseCollateFunction:
def __init__(
self, split: Literal["train", "val"], transform_args: TaskTransformArgs
):
self.split = split
self.transform_args = transform_args


class MaskSemanticSegmentationCollateFunction(BaseCollateFunction):
def __call__(
self, batch: list[MaskSemanticSegmentationDatasetItem]
) -> MaskSemanticSegmentationBatch:
# Prepare the batch without any stacking.
images = [item["image"] for item in batch]
masks = [item["mask"] for item in batch]

out: MaskSemanticSegmentationBatch = {
"image_path": [item["image_path"] for item in batch],
# Stack images during training as they all have the same shape.
# During validation every image can have a different shape.
"image": torch.stack(images) if self.split == "train" else images,
"mask": torch.stack(masks) if self.split == "train" else masks,
"binary_masks": [item["binary_masks"] for item in batch],
}

return out


class ObjectDetectionCollateFunction(BaseCollateFunction):
def __init__(
self, split: Literal["train", "val"], transform_args: TaskTransformArgs
):
super().__init__(split, transform_args)
assert isinstance(transform_args, ObjectDetectionTransformArgs)
self.scale_jitter: ScaleJitter | None
if transform_args.scale_jitter is not None:
self.scale_jitter = ScaleJitter(
target_size=transform_args.image_size,
scale_range=(
transform_args.scale_jitter.min_scale,
transform_args.scale_jitter.max_scale,
),
num_scales=transform_args.scale_jitter.num_scales,
divisible_by=transform_args.scale_jitter.divisible_by,
p=transform_args.scale_jitter.prob,
step_seeding=transform_args.scale_jitter.step_seeding,
seed_offset=transform_args.scale_jitter.seed_offset,
)
else:
self.scale_jitter = None

def __call__(self, batch: list[ObjectDetectionDatasetItem]) -> ObjectDetectionBatch:
if self.scale_jitter is not None:
# Turn into numpy again.
batch_np = [
{
"image_path": item["image_path"],
"image": item["image"].numpy(),
"bboxes": item["bboxes"].numpy(),
"classes": item["classes"].numpy(),
}
for item in batch
]

# Apply transform.
seed = np.random.randint(0, 1_000_000)
self.scale_jitter.global_step = seed
images: list[Tensor] = []
bboxes: list[Tensor] = []
classes: list[Tensor] = []
for item in batch_np:
out = self.scale_jitter(
image=item["image"],
bboxes=item["bboxes"],
class_labels=item["classes"],
)
images.append(out["image"])
bboxes.append(out["bboxes"])
classes.append(out["class_labels"])

# Turn back into torch tensors.
images = [torch.from_numpy(img).to(torch.float32) for img in images]
bboxes = [torch.from_numpy(bbox).to(torch.float32) for bbox in bboxes]
classes = [torch.from_numpy(cls).to(torch.int64) for cls in classes]

out_: ObjectDetectionBatch = {
"image_path": [item["image_path"] for item in batch],
"image": torch.stack(images),
"bboxes": bboxes,
"classes": classes,
}
return out_
else:
out_ = {
"image_path": [item["image_path"] for item in batch],
"image": torch.stack([item["image"] for item in batch]),
"bboxes": [item["bboxes"] for item in batch],
"classes": [item["classes"] for item in batch],
}
return out_


COLLATE_FN_CLS_LIST = [
MaskSemanticSegmentationCollateFunction,
ObjectDetectionCollateFunction,
]
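The object-detection collate function applies `ScaleJitter` per batch rather than per sample: one seed is drawn per batch and written to `scale_jitter.global_step`, so every image in the batch is resized to the same randomly chosen scale and `torch.stack(images)` succeeds, while the scale still varies across batches. A sketch of how it would be wired into a loader, assuming a `dataset` and an `ObjectDetectionTransformArgs` configured as in this diff:

```python
from torch.utils.data import DataLoader

# `dataset` and `transform_args` are assumed to exist; transform_args is
# an ObjectDetectionTransformArgs with scale_jitter configured.
collate_fn = ObjectDetectionCollateFunction(
    split="train", transform_args=transform_args
)
loader = DataLoader(dataset, batch_size=16, collate_fn=collate_fn)
# Each batch of 16 images shares one jittered target size; the next
# batch draws a new seed and may pick a different size.
```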