Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 23 additions & 23 deletions docs/coding-guidelines/backend.md
Original file line number Diff line number Diff line change
Expand Up @@ -104,8 +104,8 @@ from lightly_studio.api.dependencies.samples import get_sample_service
@router.post("/samples")
def create_sample(
service: SampleService = Depends(get_sample_service),
sample: Annotated[SampleCreate, Body()]
) -> SampleView:
sample: Annotated[ImageCreate, Body()]
) -> ImageView:
return service.create_sample(sample)
```

Expand All @@ -120,7 +120,7 @@ We cover the most commonly used patterns below:
@router.get("/samples/{sample_id}")
def get_sample(
sample_id: Annotated[UUID, Path(title="Sample Id")],
) -> SampleView:
) -> ImageView:
...
```

Expand All @@ -133,7 +133,7 @@ class QueryParams(BaseModel):
@router.get("/samples")
def get_samples(
query_params: Annotated[QueryParams, Query()]
) -> list[SampleView]:
) -> list[ImageView]:
...
```

Expand All @@ -142,8 +142,8 @@ def get_samples(
```python
@router.post("/samples")
def create_sample(
sample: Annotated[SampleCreate, Body()]
) -> SampleView:
sample: Annotated[ImageCreate, Body()]
) -> ImageView:
...
```

Expand Down Expand Up @@ -177,7 +177,7 @@ Follow these conventions for API endpoints:

```python
from fastapi import APIRouter, Depends
from lightly_studio.api.models.samples import SampleCreate, SampleView
from lightly_studio.api.models.samples import ImageCreate, ImageView
from lightly_studio.services.samples import SampleService

# src/lightly_studio/api/routes/samples.py
Expand All @@ -187,16 +187,16 @@ router = APIRouter()
@router.post("/samples")
def create_sample(
service: SampleService = Depends(get_sample_service),
sample: Annotated[SampleCreate, Body()],
) -> SampleView:
sample: Annotated[ImageCreate, Body()],
) -> ImageView:
"""Create a new sample."""
return service.create_sample(sample)

@router.get("/samples/{sample_id}")
def get_sample_by_id(
service: SampleService = Depends(get_sample_service),
sample_id: Annotated[UUID, Path()],
) -> SampleView:
) -> ImageView:
"""Get a sample by ID."""
return service.get_sample_by_id(sample_id)
```
Expand All @@ -213,19 +213,19 @@ The service layer encapsulates the business logic of the application. It process
# src/lightly_studio/services/samples.py
from sqlmodel import Session
from lightly_studio.resolvers import samples_resolver
from lightly_studio.api.v1.models.samples import SampleCreate, SampleView
from lightly_studio.api.v1.models.samples import ImageCreate, ImageView

class SampleService:
"""Service class for sample operations."""

def __init__(self, session: Session) -> None:
self.session = session

def create_sample(self, sample: SampleCreate) -> SampleView:
def create_sample(self, sample: ImageCreate) -> ImageView:
"""Create a new sample."""
return samples_resolver.create_sample(session=self.session, sample_create=sample)

def create_sample_with_metadata(self, sample: SampleCreate) -> SampleView:
def create_sample_with_metadata(self, sample: ImageCreate) -> ImageView:
"""Create a new sample with computed metadata."""
# Put business logic here
metadata = _compute_metadata_for_sample(sample)
Expand All @@ -237,7 +237,7 @@ class SampleService:
)
return new_sample, new_metadata

def get_sample_by_id(self, sample_id: UUID) -> SampleView:
def get_sample_by_id(self, sample_id: UUID) -> ImageView:
"""Get a sample by ID."""
return samples_resolver.get_sample_by_id(session=self.session, sample_id=sample_id)
```
Expand All @@ -258,9 +258,9 @@ The resolvers layer interacts with the database. It handles database operations
from __future__ import annotations

from sqlmodel import Session, select
from lightly_studio.models.sample import Sample, SampleCreate
from lightly_studio.models.sample import Sample, ImageCreate

def create_sample(session: Session, sample_create: SampleCreate) -> Sample:
def create_sample(session: Session, sample_create: ImageCreate) -> Sample:
"""Create a new sample in the database."""
db_sample = Sample.model_validate(sample_create)
session.add(db_sample)
Expand All @@ -274,10 +274,10 @@ Models are used to define the structure of data in the application. They are use

We store models in the `src/lightly_studio/models` directory. The models are organized into different files based on their functionality. For example, we have a `sample.py` file to define the models for samples such as:
- `SampleBase` - the base model describing common fields for all sample-related models.
- `SampleTable` - the main model for samples defining table structure and relationships.
- `SampleCreate` - a model used when creating a new sample, which may contain only a subset of fields.
- `ImageTable` - the main model for samples defining table structure and relationships.
- `ImageCreate` - a model used when creating a new sample, which may contain only a subset of fields.
- `SampleUpdate` - a model used for updating existing samples, which may contain only the fields that can be updated.
- `SampleView` - a model used for viewing samples, which may contain additional fields like timestamps or relationships.
- `ImageView` - a model used for viewing samples, which may contain additional fields like timestamps or relationships.
- `SampleLink` - a model used for linking samples, which may contain fields like `sample_id`, `linked_sample_id`, and `relationship_type`. Typically only needed for many-to-many relationships.

Usually only "Base", "Table", "Create", "Update", and "View" models are needed for each entity.
Expand Down Expand Up @@ -321,13 +321,13 @@ class Sample(SampleBase, table=True):
embeddings: list["SampleEmbedding"] = Relationship(back_populates="sample")

# This model is used to validate and describe the input when creating a new sample.
class SampleCreate(SampleBase):
class ImageCreate(SampleBase):
"""Sample create model."""
pass

# This model is used to validate and describe the output, e.g. in the API response.
# It may contain reduced fields compared to the Sample model to optimize response size.
class SampleView(SampleBase):
class ImageView(SampleBase):
"""Sample view model."""
sample_id: UUID

Expand Down Expand Up @@ -395,7 +395,7 @@ from lightly_studio.services.exceptions import ServiceError
from lightly_studio.services.errors import STATUS_NOT_FOUND

class SampleService:
def get_sample_by_id(self, sample_id: UUID) -> SampleView:
def get_sample_by_id(self, sample_id: UUID) -> ImageView:
"""Get a sample by ID."""
statement = select(Sample).where(Sample.sample_id == sample_id)
sample = self.session.exec(statement).first()
Expand All @@ -404,7 +404,7 @@ class SampleService:
code=STATUS_NOT_FOUND,
message=f"Sample with id {sample_id} not found"
)
return SampleView.model_validate(sample)
return ImageView.model_validate(sample)

# src/lightly_studio/services/exceptions.py
"""Custom exceptions for the lightly_studio service layer."""
Expand Down
2 changes: 1 addition & 1 deletion lightly_studio/src/lightly_studio/api/db_tables.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
SampleMetadataTable, # noqa: F401, required for SQLModel to work properly
)
from lightly_studio.models.sample import (
SampleTable, # noqa: F401, required for SQLModel to work properly
ImageTable, # noqa: F401, required for SQLModel to work properly
)
from lightly_studio.models.sample_embedding import (
SampleEmbeddingTable, # noqa: F401, required for SQLModel to work properly
Expand Down
24 changes: 12 additions & 12 deletions lightly_studio/src/lightly_studio/api/routes/api/sample.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,10 @@
from lightly_studio.db_manager import SessionDep
from lightly_studio.models.dataset import DatasetTable
from lightly_studio.models.sample import (
SampleCreate,
SampleTable,
SampleView,
SampleViewsWithCount,
ImageCreate,
ImageTable,
ImageView,
ImageViewsWithCount,
)
from lightly_studio.resolvers import (
sample_resolver,
Expand All @@ -34,11 +34,11 @@
samples_router = APIRouter(prefix="/datasets/{dataset_id}", tags=["samples"])


@samples_router.post("/samples", response_model=SampleView)
@samples_router.post("/samples", response_model=ImageView)
def create_sample(
session: SessionDep,
input_sample: SampleCreate,
) -> SampleTable:
input_sample: ImageCreate,
) -> ImageTable:
"""Create a new sample in the database."""
return sample_resolver.create(session=session, sample=input_sample)

Expand All @@ -54,7 +54,7 @@ class ReadSamplesRequest(BaseModel):
)


@samples_router.post("/samples/list", response_model=SampleViewsWithCount)
@samples_router.post("/samples/list", response_model=ImageViewsWithCount)
def read_samples(
session: SessionDep,
dataset_id: Annotated[UUID, Path(title="Dataset Id")],
Expand Down Expand Up @@ -98,12 +98,12 @@ def get_sample_dimensions(
)


@samples_router.get("/samples/{sample_id}", response_model=SampleView)
@samples_router.get("/samples/{sample_id}", response_model=ImageView)
def read_sample(
session: SessionDep,
dataset_id: Annotated[UUID, Path(title="Dataset Id", description="The ID of the dataset")],
sample_id: Annotated[UUID, Path(title="Sample Id")],
) -> SampleTable:
) -> ImageTable:
"""Retrieve a single sample from the database."""
sample = sample_resolver.get_by_id(session=session, dataset_id=dataset_id, sample_id=sample_id)
if not sample:
Expand All @@ -115,8 +115,8 @@ def read_sample(
def update_sample(
session: SessionDep,
sample_id: Annotated[UUID, Path(title="Sample Id")],
sample_input: SampleCreate,
) -> SampleTable:
sample_input: ImageCreate,
) -> ImageTable:
"""Update an existing sample in the database."""
sample = sample_resolver.update(session=session, sample_id=sample_id, sample_data=sample_input)
if not sample:
Expand Down
2 changes: 1 addition & 1 deletion lightly_studio/src/lightly_studio/api/routes/images.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ async def serve_image_by_sample_id(
HTTPException: If the sample is not found or the file is not accessible.
"""
# Retrieve the sample from the database.
sample_record = session.get(sample.SampleTable, sample_id)
sample_record = session.get(sample.ImageTable, sample_id)
if not sample_record:
raise HTTPException(
status_code=status.HTTP_STATUS_NOT_FOUND,
Expand Down
24 changes: 12 additions & 12 deletions lightly_studio/src/lightly_studio/core/add_samples.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
from lightly_studio.models.annotation.annotation_base import AnnotationCreate
from lightly_studio.models.annotation_label import AnnotationLabelCreate
from lightly_studio.models.caption import CaptionCreate
from lightly_studio.models.sample import SampleCreate, SampleTable
from lightly_studio.models.sample import ImageCreate, ImageTable
from lightly_studio.resolvers import (
annotation_label_resolver,
annotation_resolver,
Expand Down Expand Up @@ -82,7 +82,7 @@ def load_into_dataset_from_paths(
Returns:
A list of UUIDs of the created samples.
"""
samples_to_create: list[SampleCreate] = []
samples_to_create: list[ImageCreate] = []
created_sample_ids: list[UUID] = []

logging_context = _LoadingLoggingContext(
Expand All @@ -105,7 +105,7 @@ def load_into_dataset_from_paths(
except (FileNotFoundError, PIL.UnidentifiedImageError, OSError):
continue

sample = SampleCreate(
sample = ImageCreate(
file_name=Path(image_path).name,
file_path_abs=image_path,
width=width,
Expand Down Expand Up @@ -163,15 +163,15 @@ def load_into_dataset_from_labelformat(
label_map = _create_label_map(session=session, input_labels=input_labels)

annotations_to_create: list[AnnotationCreate] = []
samples_to_create: list[SampleCreate] = []
samples_to_create: list[ImageCreate] = []
created_sample_ids: list[UUID] = []
image_path_to_anno_data: dict[str, ImageInstanceSegmentation | ImageObjectDetection] = {}

for image_data in tqdm(input_labels.get_labels(), desc="Processing images", unit=" images"):
image: Image = image_data.image # type: ignore[attr-defined]

typed_image_data: ImageInstanceSegmentation | ImageObjectDetection = image_data # type: ignore[assignment]
sample = SampleCreate(
sample = ImageCreate(
file_name=str(image.filename),
file_path_abs=str(images_path / image.filename),
width=image.width,
Expand Down Expand Up @@ -266,7 +266,7 @@ def load_into_dataset_from_coco_captions(
)

captions_to_create: list[CaptionCreate] = []
samples_to_create: list[SampleCreate] = []
samples_to_create: list[ImageCreate] = []
created_sample_ids: list[UUID] = []
image_path_to_captions: dict[str, list[str]] = {}

Expand All @@ -279,7 +279,7 @@ def load_into_dataset_from_coco_captions(

width = image_info["width"] if isinstance(image_info["width"], int) else 0
height = image_info["height"] if isinstance(image_info["height"], int) else 0
sample = SampleCreate(
sample = ImageCreate(
file_name=file_name_raw,
file_path_abs=str(images_path / file_name_raw),
width=width,
Expand Down Expand Up @@ -345,16 +345,16 @@ def _log_loading_results(


def _create_batch_samples(
session: Session, samples: list[SampleCreate]
) -> tuple[list[SampleTable], list[str]]:
session: Session, samples: list[ImageCreate]
) -> tuple[list[ImageTable], list[str]]:
"""Create the batch samples.

Args:
session: The database session.
samples: The samples to create.

Returns:
created_samples: A list of created SampleTable objects,
created_samples: A list of created ImageTable objects,
existing_file_paths: A list of file paths that already existed in the database,
"""
file_paths_abs_mapping = {sample.file_path_abs: sample for sample in samples}
Expand Down Expand Up @@ -449,7 +449,7 @@ def _process_instance_segmentation_annotations(

def _process_batch_annotations( # noqa: PLR0913
session: Session,
stored_samples: list[SampleTable],
stored_samples: list[ImageTable],
image_path_to_anno_data: dict[str, ImageInstanceSegmentation | ImageObjectDetection],
dataset_id: UUID,
label_map: dict[int, UUID],
Expand Down Expand Up @@ -486,7 +486,7 @@ def _process_batch_annotations( # noqa: PLR0913
def _process_batch_captions(
session: Session,
dataset_id: UUID,
stored_samples: list[SampleTable],
stored_samples: list[ImageTable],
image_path_to_captions: dict[str, list[str]],
captions_to_create: list[CaptionCreate],
) -> None:
Expand Down
6 changes: 3 additions & 3 deletions lightly_studio/src/lightly_studio/core/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@
AnnotationType,
)
from lightly_studio.models.dataset import DatasetCreate, DatasetTable
from lightly_studio.models.sample import SampleTable
from lightly_studio.models.sample import ImageTable
from lightly_studio.resolvers import (
dataset_resolver,
embedding_model_resolver,
Expand Down Expand Up @@ -143,7 +143,7 @@ def load_or_create(name: str | None = None) -> Dataset:
def __iter__(self) -> Iterator[Sample]:
"""Iterate over samples in the dataset."""
for sample in self.session.exec(
select(SampleTable).where(SampleTable.dataset_id == self.dataset_id)
select(ImageTable).where(ImageTable.dataset_id == self.dataset_id)
):
yield Sample(inner=sample)

Expand All @@ -154,7 +154,7 @@ def get_sample(self, sample_id: UUID) -> Sample:
sample_id: The UUID of the sample to retrieve.

Returns:
A single SampleTable object.
A single ImageTable object.

Raises:
IndexError: If no sample is found with the given sample_id.
Expand Down
Loading