Add Obj Det Transform Pt 2 #300
Merged: liopeer merged 99 commits into main from lionel-trn-1520-implement-object-detection-transforms-pt2 on Sep 29, 2025
Changes from 79 commits
Commits
75822ec wip commit (liopeer)
78153fd wip commit (liopeer)
d2e86b6 wrap up dataset (liopeer)
14d3c2d address review (liopeer)
631e420 add TaskTransforms to pydantic classes (liopeer)
bb29ee2 Merge branch 'lionel-trn-1394-save-image-size-and-norm-info-in-model'… (liopeer)
06a9eb9 add classvars to trainmodel (liopeer)
8cab886 Merge branch 'main' into lionel-trn-1394-save-image-size-and-norm-inf… (liopeer)
ae5859b small correction (liopeer)
6bbcbb4 fix model evaluation (liopeer)
6b80a9d make transform args available in transform interface (liopeer)
7aef0c4 Merge branch 'lionel-trn-1394-save-image-size-and-norm-info-in-model'… (liopeer)
702aad9 wip commit (liopeer)
96814bf address review (liopeer)
608d685 Merge branch 'lionel-trn-1394-save-image-size-and-norm-info-in-model'… (liopeer)
9d7bd67 wip commit (liopeer)
be6e7d9 add unit tests (liopeer)
e903042 Merge branch 'main' into lionel-trn-758-add-object-detection-dataset (liopeer)
743c131 fix merge (liopeer)
ae0a286 remove unnecessary file (liopeer)
4b22be6 remove classification files (liopeer)
d36fec5 remove is_supported_model from TrainModel (liopeer)
d543562 remove is_supported_model from TrainModel (liopeer)
d6e3f8b Merge branch 'lionel-trn-758-add-object-detection-dataset' of github.… (liopeer)
15141dc fix typing (liopeer)
df817b1 fix unit tests (liopeer)
803ba59 Merge branch 'main' into lionel-trn-758-add-object-detection-dataset (liopeer)
e23b305 handle empty lines in yolo labels (liopeer)
37863dd explicit dtypes for arrays (liopeer)
102d508 move logic to dataargs (liopeer)
aee9ac0 pydantic convention adherence (liopeer)
ffc33ea add obj detection skeleton (liopeer)
1e76743 fix failing unit tests (liopeer)
42f0c92 fix formatting (liopeer)
e6ef891 address review (liopeer)
c654d0d fix small error (liopeer)
1fb79e1 replace only first "images" occurrence (liopeer)
c1a60ea Merge branch 'main' into lionel-trn-758-add-object-detection-dataset (liopeer)
438f935 small error (liopeer)
f29a377 small fix (liopeer)
c51158b from future import annotations (liopeer)
ac35b6c use typing_extensions instead of typing (liopeer)
6e4b67b fix unit tests (liopeer)
6a5254e DINOv3EoMTSemanticSegmentation (liopeer)
e001c6f fix merge (liopeer)
4c09c46 add photometric distortion transform (liopeer)
2a9a874 add random zoom out (liopeer)
c21ad54 fix typing in tests (liopeer)
c2c419a fix typing (liopeer)
e05ed19 add future annotations (liopeer)
bedec74 add custom RandomOrder (liopeer)
c8997ac fix py38 import of ToTensorv2 (liopeer)
b4101b8 Merge branch 'main' into lionel-trn-759-implement-object-detection-tr… (liopeer)
7808cc0 less restrictive typing (liopeer)
35bbaf0 don't allow 0 probability in compositions due to incompatibility with… (liopeer)
8ece5c3 handle different albumentations versions (liopeer)
2ce0a6e remove p0 test since not compatible with older albumentations versions (liopeer)
f1cd827 add tests for RandomOrder transform (liopeer)
8ec3261 Merge branch 'lionel-trn-759-implement-object-detection-transforms' o… (liopeer)
d56ca2d delete empty file (liopeer)
65aa004 fix formatting (liopeer)
e490818 fix tests failing on bbox tuples (liopeer)
f0b53fd address issues with randomorder (liopeer)
fad4b8c make randomzoomout work with all albumentation versions (liopeer)
75f4d8a fix typing (liopeer)
39f2f16 implement obj det transform (liopeer)
2826cee Merge branch 'main' into lionel-trn-1520-implement-object-detection-t… (liopeer)
0829369 merge (liopeer)
902bcd5 fix merge (liopeer)
3f2c72b use step instead of epoch (liopeer)
43d3559 readd TODO (liopeer)
c183fe8 fix import (liopeer)
a033f25 fix unit tests (liopeer)
1c09bb1 fix typing (liopeer)
82b0547 simplify tests (liopeer)
d46a963 wip commit (liopeer)
7d8cebe use scalejitter in collation instead (liopeer)
abc7ba8 add comment about step increase (liopeer)
c7e5225 use sharedstep in transform (liopeer)
7bf7c30 remove unused list of collate functions (liopeer)
f7f1c5d add comment (liopeer)
49d8c68 formatting (liopeer)
56fd793 Merge branch 'main' into lionel-trn-1520-implement-object-detection-t… (liopeer)
ba99a38 fix import (liopeer)
23bc8b7 remove markdown formatting (liopeer)
06b5fa4 fix typing in 3.8 (liopeer)
04e5751 change to mp.Manager (liopeer)
4ca6116 remove shared step (liopeer)
5f71376 remove stoppolicy test cases (liopeer)
dd714cf small refactors (liopeer)
78d4a08 revert formatting (liopeer)
ef2706d use fixed scales for obj det transform (liopeer)
2bf7fd9 Merge branch 'main' into lionel-trn-1520-implement-object-detection-t… (liopeer)
d046fc0 add todo (liopeer)
3c2918b Merge branch 'main' into lionel-trn-1520-implement-object-detection-t… (liopeer)
0b9ba5a remove step seeding (liopeer)
8aef613 Merge branch 'lionel-trn-1520-implement-object-detection-transforms-p… (liopeer)
883d9d5 remove stop_policy from LT-DETR (liopeer)
bf5c3d5 Merge branch 'main' into lionel-trn-1520-implement-object-detection-t… (liopeer)
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,27 @@ | ||
| # | ||
| # Copyright (c) Lightly AG and affiliates. | ||
| # All rights reserved. | ||
| # | ||
| # This source code is licensed under the license found in the | ||
| # LICENSE file in the root directory of this source tree. | ||
| # | ||
| from multiprocessing import Value | ||
| | ||
| | ||
| class WorkerSharedStep: | ||
|     """A class to share the current step between dataloader workers.""" | ||
| | ||
|     def __init__(self, step: int) -> None: | ||
|         self._step_value = Value("i", step) | ||
| | ||
|     @property | ||
|     def step(self) -> int: | ||
|         with self._step_value.get_lock(): | ||
|             val = self._step_value.value | ||
|             assert isinstance(val, int) | ||
|             return val | ||
| | ||
|     @step.setter | ||
|     def step(self, step: int) -> None: | ||
|         with self._step_value.get_lock(): | ||
|             self._step_value.value = step | ||
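
For readers following the diff, a minimal sketch of how a lock-guarded `multiprocessing.Value`, like the one wrapped by `WorkerSharedStep`, can be read from another process. The names `SharedStep`, `_report_step`, and `read_step_in_child` are illustrative and not part of this PR. Later commits in the list above are titled "change to mp.Manager" and "remove shared step", so the merged code may no longer contain this class.

```python
from multiprocessing import Process, Queue, Value


class SharedStep:
    """Illustrative stand-in that mirrors WorkerSharedStep from the diff."""

    def __init__(self, step: int) -> None:
        # "i" allocates a signed C int in shared memory, guarded by a lock.
        self._step_value = Value("i", step)

    @property
    def step(self) -> int:
        with self._step_value.get_lock():
            return int(self._step_value.value)

    @step.setter
    def step(self, step: int) -> None:
        with self._step_value.get_lock():
            self._step_value.value = step


def _report_step(shared: SharedStep, queue) -> None:
    # Runs in the child process and sends back the step it observes.
    queue.put(shared.step)


def read_step_in_child(shared: SharedStep) -> int:
    """Start one child process and return the step value it sees."""
    queue = Queue()
    proc = Process(target=_report_step, args=(shared, queue))
    proc.start()
    value = queue.get()
    proc.join()
    return value
```

After `shared = SharedStep(0)` and `shared.step = 42` in the parent, `read_step_in_child(shared)` should return 42, because the child reads the same shared memory rather than a copy. On platforms that start processes with spawn, call it from under an `if __name__ == "__main__":` guard.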
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,134 @@ | ||
| # | ||
| # Copyright (c) Lightly AG and affiliates. | ||
| # All rights reserved. | ||
| # | ||
| # This source code is licensed under the license found in the | ||
| # LICENSE file in the root directory of this source tree. | ||
| # | ||
| from __future__ import annotations | ||
| | ||
| from typing import Literal | ||
| | ||
| import numpy as np | ||
| import torch | ||
| from torch import Tensor | ||
| | ||
| from lightly_train._transforms.object_detection_transform import ( | ||
|     ObjectDetectionTransformArgs, | ||
| ) | ||
| from lightly_train._transforms.scale_jitter import ScaleJitter | ||
| from lightly_train._transforms.task_transform import TaskTransformArgs | ||
| from lightly_train.types import ( | ||
|     MaskSemanticSegmentationBatch, | ||
|     MaskSemanticSegmentationDatasetItem, | ||
|     ObjectDetectionBatch, | ||
|     ObjectDetectionDatasetItem, | ||
| ) | ||
| | ||
| | ||
| class BaseCollateFunction: | ||
|     def __init__( | ||
|         self, split: Literal["train", "val"], transform_args: TaskTransformArgs | ||
|     ): | ||
|         self.split = split | ||
|         self.transform_args = transform_args | ||
| | ||
| | ||
| class MaskSemanticSegmentationCollateFunction(BaseCollateFunction): | ||
|     def __call__( | ||
|         self, batch: list[MaskSemanticSegmentationDatasetItem] | ||
|     ) -> MaskSemanticSegmentationBatch: | ||
|         # Prepare the batch without any stacking. | ||
|         images = [item["image"] for item in batch] | ||
|         masks = [item["mask"] for item in batch] | ||
| | ||
|         out: MaskSemanticSegmentationBatch = { | ||
|             "image_path": [item["image_path"] for item in batch], | ||
|             # Stack images during training as they all have the same shape. | ||
|             # During validation every image can have a different shape. | ||
|             "image": torch.stack(images) if self.split == "train" else images, | ||
|             "mask": torch.stack(masks) if self.split == "train" else masks, | ||
|             "binary_masks": [item["binary_masks"] for item in batch], | ||
|         } | ||
| | ||
|         return out | ||
| | ||
| | ||
| class ObjectDetectionCollateFunction(BaseCollateFunction): | ||
|     def __init__( | ||
|         self, split: Literal["train", "val"], transform_args: TaskTransformArgs | ||
|     ): | ||
|         super().__init__(split, transform_args) | ||
|         assert isinstance(transform_args, ObjectDetectionTransformArgs) | ||
|         self.scale_jitter: ScaleJitter | None | ||
|         if transform_args.scale_jitter is not None: | ||
|             self.scale_jitter = ScaleJitter( | ||
|                 target_size=transform_args.image_size, | ||
|                 scale_range=( | ||
|                     transform_args.scale_jitter.min_scale, | ||
|                     transform_args.scale_jitter.max_scale, | ||
|                 ), | ||
|                 num_scales=transform_args.scale_jitter.num_scales, | ||
|                 divisible_by=transform_args.scale_jitter.divisible_by, | ||
|                 p=transform_args.scale_jitter.prob, | ||
|                 step_seeding=transform_args.scale_jitter.step_seeding, | ||
|                 seed_offset=transform_args.scale_jitter.seed_offset, | ||
|             ) | ||
|         else: | ||
|             self.scale_jitter = None | ||
| | ||
|     def __call__(self, batch: list[ObjectDetectionDatasetItem]) -> ObjectDetectionBatch: | ||
|         if self.scale_jitter is not None: | ||
|             # Turn into numpy again. | ||
|             batch_np = [ | ||
|                 { | ||
|                     "image_path": item["image_path"], | ||
|                     "image": item["image"].numpy(), | ||
|                     "bboxes": item["bboxes"].numpy(), | ||
|                     "classes": item["classes"].numpy(), | ||
|                 } | ||
|                 for item in batch | ||
|             ] | ||
| | ||
|             # Apply transform. | ||
|             seed = np.random.randint(0, 1_000_000) | ||
|             self.scale_jitter.global_step = seed | ||
|             images: list[Tensor] = [] | ||
|             bboxes: list[Tensor] = [] | ||
|             classes: list[Tensor] = [] | ||
|             for item in batch_np: | ||
|                 out = self.scale_jitter( | ||
|                     image=item["image"], | ||
|                     bboxes=item["bboxes"], | ||
|                     class_labels=item["classes"], | ||
|                 ) | ||
|                 images.append(out["image"]) | ||
|                 bboxes.append(out["bboxes"]) | ||
|                 classes.append(out["class_labels"]) | ||
| | ||
|             # Turn back into torch tensors. | ||
|             images = [torch.from_numpy(img).to(torch.float32) for img in images] | ||
|             bboxes = [torch.from_numpy(bbox).to(torch.float32) for bbox in bboxes] | ||
|             classes = [torch.from_numpy(cls).to(torch.int64) for cls in classes] | ||
| | ||
|             out_: ObjectDetectionBatch = { | ||
|                 "image_path": [item["image_path"] for item in batch], | ||
|                 "image": torch.stack(images), | ||
|                 "bboxes": bboxes, | ||
|                 "classes": classes, | ||
|             } | ||
|             return out_ | ||
|         else: | ||
|             out_ = { | ||
|                 "image_path": [item["image_path"] for item in batch], | ||
|                 "image": torch.stack([item["image"] for item in batch]), | ||
|                 "bboxes": [item["bboxes"] for item in batch], | ||
|                 "classes": [item["classes"] for item in batch], | ||
|             } | ||
|             return out_ | ||
| | ||
| | ||
| COLLATE_FN_CLS_LIST = [ | ||
|     MaskSemanticSegmentationCollateFunction, | ||
|     ObjectDetectionCollateFunction, | ||
| ] | ||
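
`ObjectDetectionCollateFunction.__call__` draws one seed per batch, assigns it to `scale_jitter.global_step`, and then transforms every item before calling `torch.stack`. Presumably this makes all images in a batch land on the same output size. The sketch below illustrates that idea in plain Python under stated assumptions; it is not the `ScaleJitter` implementation from this PR. The functions `candidate_sizes` and `pick_size` are hypothetical, as are the evenly spaced scales and the assumption that `target_size` is a `(height, width)` tuple.

```python
import random


def candidate_sizes(target_size, scale_range, num_scales, divisible_by):
    """Return num_scales (height, width) pairs spread evenly over scale_range.

    Assumption: sides are rounded to the nearest multiple of divisible_by.
    """
    min_scale, max_scale = scale_range
    if num_scales == 1:
        scales = [min_scale]
    else:
        step = (max_scale - min_scale) / (num_scales - 1)
        scales = [min_scale + i * step for i in range(num_scales)]

    height, width = target_size
    sizes = []
    for scale in scales:
        # Keep each side at least one multiple of divisible_by.
        h = max(divisible_by, round(height * scale / divisible_by) * divisible_by)
        w = max(divisible_by, round(width * scale / divisible_by) * divisible_by)
        sizes.append((h, w))
    return sizes


def pick_size(sizes, global_step, seed_offset=0):
    """Pick one size deterministically from the step, so repeated calls agree."""
    rng = random.Random(global_step + seed_offset)
    return rng.choice(sizes)
```

Because `pick_size` depends only on `global_step` and `seed_offset`, every item transformed with the same step value gets the same size, which is the property stacking needs. With `candidate_sizes((640, 640), (0.5, 1.0), 3, 32)` the candidates are `(320, 320)`, `(480, 480)`, and `(640, 640)`.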