Skip to content

Commit 2318bc6

Browse files
Multimodal: Link caption to Sample instead of Image (#36)
1 parent 4894408 commit 2318bc6

File tree

7 files changed

+24
-29
lines changed

7 files changed

+24
-29
lines changed

lightly_studio/src/lightly_studio/api/routes/api/sample.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,7 @@ def read_samples(
7979
sample_id=image.sample_id,
8080
dataset_id=image.dataset_id,
8181
annotations=image.annotations,
82-
captions=image.captions,
82+
captions=image.sample.captions,
8383
tags=image.sample.tags,
8484
metadata_dict=image.sample.metadata_dict,
8585
width=image.width,
@@ -128,7 +128,7 @@ def read_sample(
128128
sample_id=image.sample_id,
129129
dataset_id=image.dataset_id,
130130
annotations=image.annotations,
131-
captions=image.captions,
131+
captions=image.sample.captions,
132132
tags=image.sample.tags,
133133
metadata_dict=image.sample.metadata_dict,
134134
width=image.width,

lightly_studio/src/lightly_studio/models/caption.py

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
from sqlmodel import Field, Relationship, SQLModel
1111

1212
if TYPE_CHECKING:
13-
from lightly_studio.models.image import ImageTable
13+
from lightly_studio.models.sample import SampleTable
1414

1515

1616
class CaptionTable(SQLModel, table=True):
@@ -22,9 +22,9 @@ class CaptionTable(SQLModel, table=True):
2222

2323
caption_id: UUID = Field(default_factory=uuid4, primary_key=True)
2424
dataset_id: UUID = Field(foreign_key="dataset.dataset_id")
25-
sample_id: UUID = Field(foreign_key="image.sample_id")
25+
sample_id: UUID = Field(foreign_key="sample.sample_id")
2626

27-
sample: Mapped[Optional["ImageTable"]] = Relationship(
27+
sample: Mapped["SampleTable"] = Relationship(
2828
back_populates="captions",
2929
sa_relationship_kwargs={"lazy": "select"},
3030
)
@@ -40,11 +40,10 @@ class CaptionCreate(SQLModel):
4040
text: str
4141

4242

43-
class CaptionImageView(SQLModel):
43+
class CaptionSampleView(SQLModel):
4444
"""Sample class for caption view."""
4545

46-
file_path_abs: str
47-
file_name: str
46+
# TODO(Michal, 10/2025): Remove this class and use CaptionView instead.
4847
dataset_id: UUID
4948
sample_id: UUID
5049

@@ -61,7 +60,7 @@ class CaptionView(SQLModel):
6160
class CaptionDetailsView(CaptionView):
6261
"""Response model for caption."""
6362

64-
sample: CaptionImageView
63+
sample: CaptionSampleView
6564

6665

6766
class CaptionsListView(BaseModel):

lightly_studio/src/lightly_studio/models/image.py

Lines changed: 1 addition & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -16,16 +16,12 @@
1616
from lightly_studio.models.annotation.annotation_base import (
1717
AnnotationBaseTable,
1818
)
19-
from lightly_studio.models.caption import CaptionTable
2019
from lightly_studio.models.metadata import (
21-
SampleMetadataTable,
2220
SampleMetadataView,
2321
)
2422
from lightly_studio.models.sample import SampleTable
2523
else:
2624
AnnotationBaseTable = object
27-
CaptionTable = object
28-
SampleMetadataTable = object
2925
SampleTable = object
3026
SampleMetadataView = object
3127

@@ -65,9 +61,6 @@ class ImageTable(ImageBase, table=True):
6561
annotations: Mapped[List["AnnotationBaseTable"]] = Relationship(
6662
back_populates="sample",
6763
)
68-
captions: Mapped[List["CaptionTable"]] = Relationship(
69-
back_populates="sample",
70-
)
7164

7265
sample: Mapped["SampleTable"] = Relationship()
7366

@@ -96,13 +89,13 @@ class ImageViewTag(SQLModel):
9689
sample_id: UUID
9790
dataset_id: UUID
9891
annotations: List["AnnotationView"]
99-
captions: List[CaptionView] = []
10092
width: int
10193
height: int
10294

10395
# TODO(Michal, 10/2025): Add SampleView to ImageView, don't expose these fields directly.
10496
tags: List[ImageViewTag]
10597
metadata_dict: Optional["SampleMetadataView"] = None
98+
captions: List[CaptionView] = []
10699

107100

108101
class ImageViewsWithCount(BaseModel):

lightly_studio/src/lightly_studio/models/sample.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
from lightly_studio.resolvers import metadata_resolver
1111

1212
if TYPE_CHECKING:
13+
from lightly_studio.models.caption import CaptionTable, CaptionView
1314
from lightly_studio.models.metadata import (
1415
SampleMetadataTable,
1516
SampleMetadataView,
@@ -21,6 +22,8 @@
2122
SampleEmbeddingTable = object
2223
SampleMetadataTable = object
2324
SampleMetadataView = object
25+
CaptionTable = object
26+
CaptionView = object
2427

2528

2629
class SampleTagLinkTable(SQLModel, table=True):
@@ -58,6 +61,7 @@ class SampleTable(SampleBase, table=True):
5861
)
5962
embeddings: Mapped[List["SampleEmbeddingTable"]] = Relationship(back_populates="sample")
6063
metadata_dict: "SampleMetadataTable" = Relationship(back_populates="sample")
64+
captions: Mapped[List["CaptionTable"]] = Relationship(back_populates="sample")
6165

6266
# TODO(Michal, 9/2025): Remove this function in favour of Sample.metadata.
6367
def __getitem__(self, key: str) -> Any:
@@ -118,3 +122,4 @@ class SampleView(SampleBase):
118122

119123
tags: List["TagTable"] = []
120124
metadata_dict: Optional["SampleMetadataView"] = None
125+
captions: List[CaptionView] = []

lightly_studio/src/lightly_studio/resolvers/image_resolver.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -156,8 +156,8 @@ def get_all_by_dataset_id( # noqa: PLR0913
156156
joinedload(SampleTable.tags),
157157
# Ignore type checker error below as it's a false positive caused by TYPE_CHECKING.
158158
joinedload(SampleTable.metadata_dict), # type: ignore[arg-type]
159+
selectinload(SampleTable.captions),
159160
),
160-
selectinload(ImageTable.captions),
161161
)
162162
.where(ImageTable.dataset_id == dataset_id)
163163
)

lightly_studio/tests/core/test_add_samples.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
from sqlmodel import Session
1414

1515
from lightly_studio.core import add_samples
16-
from lightly_studio.models.image import ImageTable
16+
from lightly_studio.models.sample import SampleTable
1717
from lightly_studio.resolvers import caption_resolver, image_resolver
1818
from tests.helpers_resolvers import create_dataset
1919

@@ -123,13 +123,13 @@ def test_load_into_dataset_from_coco_captions(db_session: Session, tmp_path: Pat
123123
assert captions_result.next_cursor is None
124124
# Collect all the filename x caption pairs and assert they are as expected
125125
assert {
126-
(c.sample.file_name, c.text)
126+
(c.sample.sample_id, c.text)
127127
for c in captions_result.captions
128-
if isinstance(c.sample, ImageTable)
128+
if isinstance(c.sample, SampleTable)
129129
} == {
130-
("image1.jpg", "Caption 1 of image 1"),
131-
("image1.jpg", "Caption 2 of image 1"),
132-
("image2.jpg", "Caption 1 of image 2"),
130+
(samples[0].sample_id, "Caption 1 of image 1"),
131+
(samples[0].sample_id, "Caption 2 of image 1"),
132+
(samples[1].sample_id, "Caption 1 of image 2"),
133133
}
134134

135135

lightly_studio/tests/core/test_dataset__coco_caption.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@
77
from PIL import Image
88

99
from lightly_studio import Dataset
10-
from lightly_studio.models.image import ImageTable
1110
from lightly_studio.resolvers import caption_resolver
1211

1312

@@ -32,7 +31,6 @@ def test_add_samples_from_coco_caption__details_valid(
3231
)
3332
assert dataset.name == "test_dataset"
3433
samples = dataset._inner.get_samples()
35-
samples = sorted(samples, key=lambda sample: sample.file_path_abs)
3634

3735
assert len(samples) == 2
3836
assert {s.file_name for s in samples} == {"image1.jpg", "image2.jpg"}
@@ -45,11 +43,11 @@ def test_add_samples_from_coco_caption__details_valid(
4543
assert len(captions_result.captions) == 3
4644
assert captions_result.total_count == 3
4745
assert captions_result.next_cursor is None
46+
4847
# Collect all the filename x caption pairs and assert they are as expected
48+
sample_id_to_file_path = {s.sample.sample_id: s.file_name for s in samples}
4949
assert {
50-
(c.sample.file_name, c.text)
51-
for c in captions_result.captions
52-
if isinstance(c.sample, ImageTable)
50+
(sample_id_to_file_path[c.sample.sample_id], c.text) for c in captions_result.captions
5351
} == {
5452
("image1.jpg", "Caption 1 of image 1"),
5553
("image1.jpg", "Caption 2 of image 1"),

0 commit comments

Comments
 (0)