Skip to content

Commit 6116d8b

Browse files
Multimodal: Rename pydantic classes from Sample* to Image* (#8)
1 parent ac4e299 commit 6116d8b

File tree

62 files changed

+1625
-413
lines changed

Some content is hidden

Large commits have some content hidden by default. Use the search box below to find content that may be hidden.

62 files changed

+1625
-413
lines changed

docs/coding-guidelines/backend.md

Lines changed: 23 additions & 23 deletions
Original file line number | Diff line number | Diff line change
@@ -104,8 +104,8 @@ from lightly_studio.api.dependencies.samples import get_sample_service
104104
@router.post("/samples")
105105
def create_sample(
106106
service: SampleService = Depends(get_sample_service),
107-
sample: Annotated[SampleCreate, Body()]
108-
) -> SampleView:
107+
sample: Annotated[ImageCreate, Body()]
108+
) -> ImageView:
109109
return service.create_sample(sample)
110110
```
111111

@@ -120,7 +120,7 @@ We cover the most commonly used patterns below:
120120
@router.get("/samples/{sample_id}")
121121
def get_sample(
122122
sample_id: Annotated[UUID, Path(title="Sample Id")],
123-
) -> SampleView:
123+
) -> ImageView:
124124
...
125125
```
126126

@@ -133,7 +133,7 @@ class QueryParams(BaseModel):
133133
@router.get("/samples")
134134
def get_samples(
135135
query_params: Annotated[QueryParams, Query()]
136-
) -> list[SampleView]:
136+
) -> list[ImageView]:
137137
...
138138
```
139139

@@ -142,8 +142,8 @@ def get_samples(
142142
```python
143143
@router.post("/samples")
144144
def create_sample(
145-
sample: Annotated[SampleCreate, Body()]
146-
) -> SampleView:
145+
sample: Annotated[ImageCreate, Body()]
146+
) -> ImageView:
147147
...
148148
```
149149

@@ -177,7 +177,7 @@ Follow these conventions for API endpoints:
177177

178178
```python
179179
from fastapi import APIRouter, Depends
180-
from lightly_studio.api.models.samples import SampleCreate, SampleView
180+
from lightly_studio.api.models.samples import ImageCreate, ImageView
181181
from lightly_studio.services.samples import SampleService
182182

183183
# src/lightly_studio/api/routes/samples.py
@@ -187,16 +187,16 @@ router = APIRouter()
187187
@router.post("/samples")
188188
def create_sample(
189189
service: SampleService = Depends(get_sample_service),
190-
sample: Annotated[SampleCreate, Body()],
191-
) -> SampleView:
190+
sample: Annotated[ImageCreate, Body()],
191+
) -> ImageView:
192192
"""Create a new sample."""
193193
return service.create_sample(sample)
194194

195195
@router.get("/samples/{sample_id}")
196196
def get_sample_by_id(
197197
service: SampleService = Depends(get_sample_service),
198198
sample_id: Annotated[UUID, Path()],
199-
) -> SampleView:
199+
) -> ImageView:
200200
"""Get a sample by ID."""
201201
return service.get_sample_by_id(sample_id)
202202
```
@@ -213,19 +213,19 @@ The service layer encapsulates the business logic of the application. It process
213213
# src/lightly_studio/services/samples.py
214214
from sqlmodel import Session
215215
from lightly_studio.resolvers import samples_resolver
216-
from lightly_studio.api.v1.models.samples import SampleCreate, SampleView
216+
from lightly_studio.api.v1.models.samples import ImageCreate, ImageView
217217

218218
class SampleService:
219219
"""Service class for sample operations."""
220220

221221
def __init__(self, session: Session) -> None:
222222
self.session = session
223223

224-
def create_sample(self, sample: SampleCreate) -> SampleView:
224+
def create_sample(self, sample: ImageCreate) -> ImageView:
225225
"""Create a new sample."""
226226
return samples_resolver.create_sample(session=self.session, sample_create=sample)
227227

228-
def create_sample_with_metadata(self, sample: SampleCreate) -> SampleView:
228+
def create_sample_with_metadata(self, sample: ImageCreate) -> ImageView:
229229
"""Create a new sample with computed metadata."""
230230
# Put business logic here
231231
metadata = _compute_metadata_for_sample(sample)
@@ -237,7 +237,7 @@ class SampleService:
237237
)
238238
return new_sample, new_metadata
239239

240-
def get_sample_by_id(self, sample_id: UUID) -> SampleView:
240+
def get_sample_by_id(self, sample_id: UUID) -> ImageView:
241241
"""Get a sample by ID."""
242242
return samples_resolver.get_sample_by_id(session=self.session, sample_id=sample_id)
243243
```
@@ -258,9 +258,9 @@ The resolvers layer interacts with the database. It handles database operations
258258
from __future__ import annotations
259259

260260
from sqlmodel import Session, select
261-
from lightly_studio.models.sample import Sample, SampleCreate
261+
from lightly_studio.models.sample import Sample, ImageCreate
262262

263-
def create_sample(session: Session, sample_create: SampleCreate) -> Sample:
263+
def create_sample(session: Session, sample_create: ImageCreate) -> Sample:
264264
"""Create a new sample in the database."""
265265
db_sample = Sample.model_validate(sample_create)
266266
session.add(db_sample)
@@ -274,10 +274,10 @@ Models are used to define the structure of data in the application. They are use
274274

275275
We store models in the `src/lightly_studio/models` directory. The models are organized into different files based on their functionality. For example, we have a `sample.py` file to define the models for samples such as:
276276
- `SampleBase` - the base model describing common fields for all sample-related models.
277-
- `SampleTable` - the main model for samples defining table structure and relationships.
278-
- `SampleCreate` - a model used when creating a new sample, which may contain only a subset of fields.
277+
- `ImageTable` - the main model for samples defining table structure and relationships.
278+
- `ImageCreate` - a model used when creating a new sample, which may contain only a subset of fields.
279279
- `SampleUpdate` - a model used for updating existing samples, which may contain only the fields that can be updated.
280-
- `SampleView` - a model used for viewing samples, which may contain additional fields like timestamps or relationships.
280+
- `ImageView` - a model used for viewing samples, which may contain additional fields like timestamps or relationships.
281281
- `SampleLink` - a model used for linking samples, which may contain fields like `sample_id`, `linked_sample_id`, and `relationship_type`. Typically only needed for many-to-many relationships.
282282

283283
Usually only "Base", "Table", "Create", "Update", and "View" models are needed for each entity.
@@ -321,13 +321,13 @@ class Sample(SampleBase, table=True):
321321
embeddings: list["SampleEmbedding"] = Relationship(back_populates="sample")
322322

323323
# This model is used to validate and describe the input when creating a new sample.
324-
class SampleCreate(SampleBase):
324+
class ImageCreate(SampleBase):
325325
"""Sample create model."""
326326
pass
327327

328328
# This model is used to validate and describe the output, e.g. in the API response.
329329
# It may contain reduced fields compared to the Sample model to optimize response size.
330-
class SampleView(SampleBase):
330+
class ImageView(SampleBase):
331331
"""Sample view model."""
332332
sample_id: UUID
333333

@@ -395,7 +395,7 @@ from lightly_studio.services.exceptions import ServiceError
395395
from lightly_studio.services.errors import STATUS_NOT_FOUND
396396

397397
class SampleService:
398-
def get_sample_by_id(self, sample_id: UUID) -> SampleView:
398+
def get_sample_by_id(self, sample_id: UUID) -> ImageView:
399399
"""Get a sample by ID."""
400400
statement = select(Sample).where(Sample.sample_id == sample_id)
401401
sample = self.session.exec(statement).first()
@@ -404,7 +404,7 @@ class SampleService:
404404
code=STATUS_NOT_FOUND,
405405
message=f"Sample with id {sample_id} not found"
406406
)
407-
return SampleView.model_validate(sample)
407+
return ImageView.model_validate(sample)
408408

409409
# src/lightly_studio/services/exceptions.py
410410
"""Custom exceptions for the lightly_studio service layer."""

lightly_studio/src/lightly_studio/api/db_tables.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -16,7 +16,7 @@
1616
SampleMetadataTable, # noqa: F401, required for SQLModel to work properly
1717
)
1818
from lightly_studio.models.sample import (
19-
SampleTable, # noqa: F401, required for SQLModel to work properly
19+
ImageTable, # noqa: F401, required for SQLModel to work properly
2020
)
2121
from lightly_studio.models.sample_embedding import (
2222
SampleEmbeddingTable, # noqa: F401, required for SQLModel to work properly

lightly_studio/src/lightly_studio/api/routes/api/sample.py

Lines changed: 12 additions & 12 deletions
Original file line number | Diff line number | Diff line change
@@ -17,10 +17,10 @@
1717
from lightly_studio.db_manager import SessionDep
1818
from lightly_studio.models.dataset import DatasetTable
1919
from lightly_studio.models.sample import (
20-
SampleCreate,
21-
SampleTable,
22-
SampleView,
23-
SampleViewsWithCount,
20+
ImageCreate,
21+
ImageTable,
22+
ImageView,
23+
ImageViewsWithCount,
2424
)
2525
from lightly_studio.resolvers import (
2626
sample_resolver,
@@ -34,11 +34,11 @@
3434
samples_router = APIRouter(prefix="/datasets/{dataset_id}", tags=["samples"])
3535

3636

37-
@samples_router.post("/samples", response_model=SampleView)
37+
@samples_router.post("/samples", response_model=ImageView)
3838
def create_sample(
3939
session: SessionDep,
40-
input_sample: SampleCreate,
41-
) -> SampleTable:
40+
input_sample: ImageCreate,
41+
) -> ImageTable:
4242
"""Create a new sample in the database."""
4343
return sample_resolver.create(session=session, sample=input_sample)
4444

@@ -54,7 +54,7 @@ class ReadSamplesRequest(BaseModel):
5454
)
5555

5656

57-
@samples_router.post("/samples/list", response_model=SampleViewsWithCount)
57+
@samples_router.post("/samples/list", response_model=ImageViewsWithCount)
5858
def read_samples(
5959
session: SessionDep,
6060
dataset_id: Annotated[UUID, Path(title="Dataset Id")],
@@ -98,12 +98,12 @@ def get_sample_dimensions(
9898
)
9999

100100

101-
@samples_router.get("/samples/{sample_id}", response_model=SampleView)
101+
@samples_router.get("/samples/{sample_id}", response_model=ImageView)
102102
def read_sample(
103103
session: SessionDep,
104104
dataset_id: Annotated[UUID, Path(title="Dataset Id", description="The ID of the dataset")],
105105
sample_id: Annotated[UUID, Path(title="Sample Id")],
106-
) -> SampleTable:
106+
) -> ImageTable:
107107
"""Retrieve a single sample from the database."""
108108
sample = sample_resolver.get_by_id(session=session, dataset_id=dataset_id, sample_id=sample_id)
109109
if not sample:
@@ -115,8 +115,8 @@ def read_sample(
115115
def update_sample(
116116
session: SessionDep,
117117
sample_id: Annotated[UUID, Path(title="Sample Id")],
118-
sample_input: SampleCreate,
119-
) -> SampleTable:
118+
sample_input: ImageCreate,
119+
) -> ImageTable:
120120
"""Update an existing sample in the database."""
121121
sample = sample_resolver.update(session=session, sample_id=sample_id, sample_data=sample_input)
122122
if not sample:

lightly_studio/src/lightly_studio/api/routes/images.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -34,7 +34,7 @@ async def serve_image_by_sample_id(
3434
HTTPException: If the sample is not found or the file is not accessible.
3535
"""
3636
# Retrieve the sample from the database.
37-
sample_record = session.get(sample.SampleTable, sample_id)
37+
sample_record = session.get(sample.ImageTable, sample_id)
3838
if not sample_record:
3939
raise HTTPException(
4040
status_code=status.HTTP_STATUS_NOT_FOUND,

lightly_studio/src/lightly_studio/core/add_samples.py

Lines changed: 12 additions & 12 deletions
Original file line number | Diff line number | Diff line change
@@ -29,7 +29,7 @@
2929
from lightly_studio.models.annotation.annotation_base import AnnotationCreate
3030
from lightly_studio.models.annotation_label import AnnotationLabelCreate
3131
from lightly_studio.models.caption import CaptionCreate
32-
from lightly_studio.models.sample import SampleCreate, SampleTable
32+
from lightly_studio.models.sample import ImageCreate, ImageTable
3333
from lightly_studio.resolvers import (
3434
annotation_label_resolver,
3535
annotation_resolver,
@@ -82,7 +82,7 @@ def load_into_dataset_from_paths(
8282
Returns:
8383
A list of UUIDs of the created samples.
8484
"""
85-
samples_to_create: list[SampleCreate] = []
85+
samples_to_create: list[ImageCreate] = []
8686
created_sample_ids: list[UUID] = []
8787

8888
logging_context = _LoadingLoggingContext(
@@ -105,7 +105,7 @@ def load_into_dataset_from_paths(
105105
except (FileNotFoundError, PIL.UnidentifiedImageError, OSError):
106106
continue
107107

108-
sample = SampleCreate(
108+
sample = ImageCreate(
109109
file_name=Path(image_path).name,
110110
file_path_abs=image_path,
111111
width=width,
@@ -163,15 +163,15 @@ def load_into_dataset_from_labelformat(
163163
label_map = _create_label_map(session=session, input_labels=input_labels)
164164

165165
annotations_to_create: list[AnnotationCreate] = []
166-
samples_to_create: list[SampleCreate] = []
166+
samples_to_create: list[ImageCreate] = []
167167
created_sample_ids: list[UUID] = []
168168
image_path_to_anno_data: dict[str, ImageInstanceSegmentation | ImageObjectDetection] = {}
169169

170170
for image_data in tqdm(input_labels.get_labels(), desc="Processing images", unit=" images"):
171171
image: Image = image_data.image # type: ignore[attr-defined]
172172

173173
typed_image_data: ImageInstanceSegmentation | ImageObjectDetection = image_data # type: ignore[assignment]
174-
sample = SampleCreate(
174+
sample = ImageCreate(
175175
file_name=str(image.filename),
176176
file_path_abs=str(images_path / image.filename),
177177
width=image.width,
@@ -266,7 +266,7 @@ def load_into_dataset_from_coco_captions(
266266
)
267267

268268
captions_to_create: list[CaptionCreate] = []
269-
samples_to_create: list[SampleCreate] = []
269+
samples_to_create: list[ImageCreate] = []
270270
created_sample_ids: list[UUID] = []
271271
image_path_to_captions: dict[str, list[str]] = {}
272272

@@ -279,7 +279,7 @@ def load_into_dataset_from_coco_captions(
279279

280280
width = image_info["width"] if isinstance(image_info["width"], int) else 0
281281
height = image_info["height"] if isinstance(image_info["height"], int) else 0
282-
sample = SampleCreate(
282+
sample = ImageCreate(
283283
file_name=file_name_raw,
284284
file_path_abs=str(images_path / file_name_raw),
285285
width=width,
@@ -345,16 +345,16 @@ def _log_loading_results(
345345

346346

347347
def _create_batch_samples(
348-
session: Session, samples: list[SampleCreate]
349-
) -> tuple[list[SampleTable], list[str]]:
348+
session: Session, samples: list[ImageCreate]
349+
) -> tuple[list[ImageTable], list[str]]:
350350
"""Create the batch samples.
351351
352352
Args:
353353
session: The database session.
354354
samples: The samples to create.
355355
356356
Returns:
357-
created_samples: A list of created SampleTable objects,
357+
created_samples: A list of created ImageTable objects,
358358
existing_file_paths: A list of file paths that already existed in the database,
359359
"""
360360
file_paths_abs_mapping = {sample.file_path_abs: sample for sample in samples}
@@ -449,7 +449,7 @@ def _process_instance_segmentation_annotations(
449449

450450
def _process_batch_annotations( # noqa: PLR0913
451451
session: Session,
452-
stored_samples: list[SampleTable],
452+
stored_samples: list[ImageTable],
453453
image_path_to_anno_data: dict[str, ImageInstanceSegmentation | ImageObjectDetection],
454454
dataset_id: UUID,
455455
label_map: dict[int, UUID],
@@ -486,7 +486,7 @@ def _process_batch_annotations( # noqa: PLR0913
486486
def _process_batch_captions(
487487
session: Session,
488488
dataset_id: UUID,
489-
stored_samples: list[SampleTable],
489+
stored_samples: list[ImageTable],
490490
image_path_to_captions: dict[str, list[str]],
491491
captions_to_create: list[CaptionCreate],
492492
) -> None:

lightly_studio/src/lightly_studio/core/dataset.py

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -34,7 +34,7 @@
3434
AnnotationType,
3535
)
3636
from lightly_studio.models.dataset import DatasetCreate, DatasetTable
37-
from lightly_studio.models.sample import SampleTable
37+
from lightly_studio.models.sample import ImageTable
3838
from lightly_studio.resolvers import (
3939
dataset_resolver,
4040
embedding_model_resolver,
@@ -143,7 +143,7 @@ def load_or_create(name: str | None = None) -> Dataset:
143143
def __iter__(self) -> Iterator[Sample]:
144144
"""Iterate over samples in the dataset."""
145145
for sample in self.session.exec(
146-
select(SampleTable).where(SampleTable.dataset_id == self.dataset_id)
146+
select(ImageTable).where(ImageTable.dataset_id == self.dataset_id)
147147
):
148148
yield Sample(inner=sample)
149149

@@ -154,7 +154,7 @@ def get_sample(self, sample_id: UUID) -> Sample:
154154
sample_id: The UUID of the sample to retrieve.
155155
156156
Returns:
157-
A single SampleTable object.
157+
A single ImageTable object.
158158
159159
Raises:
160160
IndexError: If no sample is found with the given sample_id.

0 commit comments

Comments
 (0)