diff --git a/lightly_studio/src/lightly_studio/api/routes/api/dataset.py b/lightly_studio/src/lightly_studio/api/routes/api/dataset.py index 7c7eb782..922e4d4d 100644 --- a/lightly_studio/src/lightly_studio/api/routes/api/dataset.py +++ b/lightly_studio/src/lightly_studio/api/routes/api/dataset.py @@ -21,7 +21,7 @@ DatasetView, DatasetViewWithCount, ) -from lightly_studio.resolvers import dataset_resolver +from lightly_studio.resolvers import dataset_resolver, datasets_resolver from lightly_studio.resolvers.dataset_resolver import ( ExportFilter, ) @@ -75,7 +75,7 @@ def read_dataset( ], ) -> DatasetViewWithCount: """Retrieve a single dataset from the database.""" - return dataset_resolver.get_dataset_details(session=session, dataset=dataset) + return datasets_resolver.get_dataset_details(session=session, dataset=dataset) @dataset_router.put("/datasets/{dataset_id}") diff --git a/lightly_studio/src/lightly_studio/resolvers/dataset_resolver.py b/lightly_studio/src/lightly_studio/resolvers/dataset_resolver.py index 1bf9bd18..8dd5d096 100644 --- a/lightly_studio/src/lightly_studio/resolvers/dataset_resolver.py +++ b/lightly_studio/src/lightly_studio/resolvers/dataset_resolver.py @@ -10,7 +10,7 @@ from sqlmodel.sql.expression import SelectOfScalar from lightly_studio.models.annotation.annotation_base import AnnotationBaseTable -from lightly_studio.models.dataset import DatasetCreate, DatasetTable, DatasetViewWithCount +from lightly_studio.models.dataset import DatasetCreate, DatasetTable from lightly_studio.models.sample import SampleTable from lightly_studio.models.tag import TagTable @@ -75,23 +75,6 @@ def get_by_name(session: Session, name: str) -> DatasetTable | None: return session.exec(select(DatasetTable).where(DatasetTable.name == name)).one_or_none() -def get_dataset_details(session: Session, dataset: DatasetTable) -> DatasetViewWithCount: - """Convert a DatasetTable to DatasetViewWithCount with computed sample count.""" - sample_count = ( - session.exec( - select(func.count("*")).where(SampleTable.dataset_id == dataset.dataset_id) - ).one() - or 0 - ) - return DatasetViewWithCount( - dataset_id=dataset.dataset_id, - name=dataset.name, - created_at=dataset.created_at, - updated_at=dataset.updated_at, - total_sample_count=sample_count, - ) - - def update(session: Session, dataset_id: UUID, dataset_data: DatasetCreate) -> DatasetTable: """Update an existing dataset.""" dataset = get_by_id(session=session, dataset_id=dataset_id) diff --git a/lightly_studio/src/lightly_studio/resolvers/datasets_resolver/__init__.py b/lightly_studio/src/lightly_studio/resolvers/datasets_resolver/__init__.py new file mode 100644 index 00000000..acfb204d --- /dev/null +++ b/lightly_studio/src/lightly_studio/resolvers/datasets_resolver/__init__.py @@ -0,0 +1,9 @@ +"""Resolvers for database operations.""" + +from lightly_studio.resolvers.datasets_resolver.get_dataset_details import ( + get_dataset_details, +) + +__all__ = [ + "get_dataset_details", +] diff --git a/lightly_studio/src/lightly_studio/resolvers/datasets_resolver/get_dataset_details.py b/lightly_studio/src/lightly_studio/resolvers/datasets_resolver/get_dataset_details.py new file mode 100644 index 00000000..2177e7d1 --- /dev/null +++ b/lightly_studio/src/lightly_studio/resolvers/datasets_resolver/get_dataset_details.py @@ -0,0 +1,25 @@ +"""Handler for database operations related to datasets.""" + +from __future__ import annotations + +from sqlmodel import Session, func, select + +from lightly_studio.models.dataset import DatasetTable, DatasetViewWithCount +from lightly_studio.models.sample import SampleTable + + +def get_dataset_details(session: Session, dataset: DatasetTable) -> DatasetViewWithCount: + """Convert a DatasetTable to DatasetViewWithCount with computed sample count.""" + sample_count = ( + session.exec( + select(func.count("*")).where(SampleTable.dataset_id == dataset.dataset_id) + ).one() + or 0 + ) + return DatasetViewWithCount( + dataset_id=dataset.dataset_id, + name=dataset.name, + created_at=dataset.created_at, + updated_at=dataset.updated_at, + total_sample_count=sample_count, + ) diff --git a/lightly_studio/tests/resolvers/datasets_resolver/__init__.py b/lightly_studio/tests/resolvers/datasets_resolver/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/lightly_studio/tests/resolvers/datasets_resolver/test_get_dataset_details.py b/lightly_studio/tests/resolvers/datasets_resolver/test_get_dataset_details.py new file mode 100644 index 00000000..d6e09113 --- /dev/null +++ b/lightly_studio/tests/resolvers/datasets_resolver/test_get_dataset_details.py @@ -0,0 +1,52 @@ +"""Tests for datasets_resolver - get_dataset_details functionality.""" + +from __future__ import annotations + +from sqlmodel import Session + +from lightly_studio.resolvers import datasets_resolver +from tests.helpers_resolvers import create_dataset, create_sample + + +def test_get_dataset_details( + db_session: Session, +) -> None: + """Test that get_dataset_details returns correct sample count.""" + dataset = create_dataset(session=db_session, dataset_name="test_dataset") + + create_sample( + session=db_session, + dataset_id=dataset.dataset_id, + file_path_abs="/path/to/image1.jpg", + ) + create_sample( + session=db_session, + dataset_id=dataset.dataset_id, + file_path_abs="/path/to/image2.jpg", + ) + create_sample( + session=db_session, + dataset_id=dataset.dataset_id, + file_path_abs="/path/to/image3.jpg", + ) + + result = datasets_resolver.get_dataset_details(session=db_session, dataset=dataset) + + assert result.dataset_id == dataset.dataset_id + assert result.name == dataset.name + assert result.created_at == dataset.created_at + assert result.updated_at == dataset.updated_at + assert result.total_sample_count == 3 + + +def test_get_dataset_details__empty_dataset( + db_session: Session, +) -> None: + """Test that get_dataset_details returns zero for empty dataset.""" + dataset = create_dataset(session=db_session, dataset_name="empty_dataset") + + result = datasets_resolver.get_dataset_details(session=db_session, dataset=dataset) + + assert result.total_sample_count == 0 + assert result.dataset_id == dataset.dataset_id + assert result.name == dataset.name