Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions lightly_studio/src/lightly_studio/api/routes/api/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
DatasetView,
DatasetViewWithCount,
)
from lightly_studio.resolvers import dataset_resolver
from lightly_studio.resolvers import dataset_resolver, datasets_resolver
from lightly_studio.resolvers.dataset_resolver import (
ExportFilter,
)
Expand Down Expand Up @@ -75,7 +75,7 @@ def read_dataset(
],
) -> DatasetViewWithCount:
"""Retrieve a single dataset from the database."""
return dataset_resolver.get_dataset_details(session=session, dataset=dataset)
return datasets_resolver.get_dataset_details(session=session, dataset=dataset)


@dataset_router.put("/datasets/{dataset_id}")
Expand Down
19 changes: 1 addition & 18 deletions lightly_studio/src/lightly_studio/resolvers/dataset_resolver.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from sqlmodel.sql.expression import SelectOfScalar

from lightly_studio.models.annotation.annotation_base import AnnotationBaseTable
from lightly_studio.models.dataset import DatasetCreate, DatasetTable, DatasetViewWithCount
from lightly_studio.models.dataset import DatasetCreate, DatasetTable
from lightly_studio.models.sample import SampleTable
from lightly_studio.models.tag import TagTable

Expand Down Expand Up @@ -75,23 +75,6 @@ def get_by_name(session: Session, name: str) -> DatasetTable | None:
return session.exec(select(DatasetTable).where(DatasetTable.name == name)).one_or_none()


def get_dataset_details(session: Session, dataset: DatasetTable) -> DatasetViewWithCount:
"""Convert a DatasetTable to DatasetViewWithCount with computed sample count."""
sample_count = (
session.exec(
select(func.count("*")).where(SampleTable.dataset_id == dataset.dataset_id)
).one()
or 0
)
return DatasetViewWithCount(
dataset_id=dataset.dataset_id,
name=dataset.name,
created_at=dataset.created_at,
updated_at=dataset.updated_at,
total_sample_count=sample_count,
)


def update(session: Session, dataset_id: UUID, dataset_data: DatasetCreate) -> DatasetTable:
"""Update an existing dataset."""
dataset = get_by_id(session=session, dataset_id=dataset_id)
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
"""Resolvers for database operations."""

from lightly_studio.resolvers.datasets_resolver.get_dataset_details import (
get_dataset_details,
)

__all__ = [
"get_dataset_details",
]
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
"""Handler for database operations related to datasets."""

from __future__ import annotations

from sqlmodel import Session, func, select

from lightly_studio.models.dataset import DatasetTable, DatasetViewWithCount
from lightly_studio.models.sample import SampleTable


def get_dataset_details(session: Session, dataset: DatasetTable) -> DatasetViewWithCount:
"""Convert a DatasetTable to DatasetViewWithCount with computed sample count."""
sample_count = (
session.exec(
select(func.count("*")).where(SampleTable.dataset_id == dataset.dataset_id)
).one()
or 0
)
return DatasetViewWithCount(
dataset_id=dataset.dataset_id,
name=dataset.name,
created_at=dataset.created_at,
updated_at=dataset.updated_at,
total_sample_count=sample_count,
)
Empty file.
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
"""Tests for datasets_resolver - get_dataset_details functionality."""

from __future__ import annotations

from sqlmodel import Session

from lightly_studio.resolvers import datasets_resolver
from tests.helpers_resolvers import create_dataset, create_sample


def test_get_dataset_details(
db_session: Session,
) -> None:
"""Test that get_dataset_details returns correct sample count."""
dataset = create_dataset(session=db_session, dataset_name="test_dataset")

create_sample(
session=db_session,
dataset_id=dataset.dataset_id,
file_path_abs="/path/to/image1.jpg",
)
create_sample(
session=db_session,
dataset_id=dataset.dataset_id,
file_path_abs="/path/to/image2.jpg",
)
create_sample(
session=db_session,
dataset_id=dataset.dataset_id,
file_path_abs="/path/to/image3.jpg",
)

result = datasets_resolver.get_dataset_details(session=db_session, dataset=dataset)

assert result.dataset_id == dataset.dataset_id
assert result.name == dataset.name
assert result.created_at == dataset.created_at
assert result.updated_at == dataset.updated_at
assert result.total_sample_count == 3


def test_get_dataset_details__empty_dataset(
db_session: Session,
) -> None:
"""Test that get_dataset_details returns zero for empty dataset."""
dataset = create_dataset(session=db_session, dataset_name="empty_dataset")

result = datasets_resolver.get_dataset_details(session=db_session, dataset=dataset)

assert result.total_sample_count == 0
assert result.dataset_id == dataset.dataset_id
assert result.name == dataset.name