Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -56,11 +56,11 @@ class AnnotationBaseTable(SQLModel, table=True):

annotation_id: UUID = Field(default_factory=uuid4, primary_key=True)
annotation_type: AnnotationType
annotation_label_id: UUID = Field(foreign_key="annotation_labels.annotation_label_id")
annotation_label_id: UUID = Field(foreign_key="annotation_label.annotation_label_id")

confidence: Optional[float] = None
dataset_id: UUID = Field(foreign_key="datasets.dataset_id")
sample_id: UUID = Field(foreign_key="samples.sample_id")
dataset_id: UUID = Field(foreign_key="dataset.dataset_id")
sample_id: UUID = Field(foreign_key="sample.sample_id")

annotation_label: Mapped["AnnotationLabelTable"] = Relationship(
sa_relationship_kwargs={"lazy": "select"},
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
class InstanceSegmentationAnnotationTable(SQLModel, table=True):
"""Database table model for instance segmentation annotations."""

__tablename__ = "instance_segmentation_annotations"
__tablename__ = "instance_segmentation_annotation"

annotation_id: UUID = Field(
default_factory=uuid4,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,4 +14,4 @@ class AnnotationTagLinkTable(SQLModel, table=True):
foreign_key="annotation_base.annotation_id",
primary_key=True,
)
tag_id: Optional[UUID] = Field(default=None, foreign_key="tags.tag_id", primary_key=True)
tag_id: Optional[UUID] = Field(default=None, foreign_key="tag.tag_id", primary_key=True)
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
class ObjectDetectionAnnotationTable(SQLModel, table=True):
"""Database table model for object detection annotations."""

__tablename__ = "object_detection_annotations"
__tablename__ = "object_detection_annotation"

annotation_id: UUID = Field(
default_factory=uuid4,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
class SemanticSegmentationAnnotationTable(SQLModel, table=True):
"""Model used to define semantic segmentation annotation table."""

__tablename__ = "semantic_segmentation_annotations"
__tablename__ = "semantic_segmentation_annotation"

annotation_id: UUID = Field(
default_factory=uuid4,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ class AnnotationLabelView(AnnotationLabelBase):
class AnnotationLabelTable(AnnotationLabelBase, table=True):
"""This class defines the AnnotationLabel model."""

__tablename__ = "annotation_labels"
__tablename__ = "annotation_label"

annotation_label_id: UUID = Field(default_factory=uuid4, primary_key=True)
created_at: str = Field(
Expand Down
4 changes: 2 additions & 2 deletions lightly_studio/src/lightly_studio/models/caption.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,8 @@ class CaptionTable(SQLModel, table=True):
created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc), index=True)

caption_id: UUID = Field(default_factory=uuid4, primary_key=True)
dataset_id: UUID = Field(foreign_key="datasets.dataset_id")
sample_id: UUID = Field(foreign_key="samples.sample_id")
dataset_id: UUID = Field(foreign_key="dataset.dataset_id")
sample_id: UUID = Field(foreign_key="sample.sample_id")

sample: Mapped[Optional["SampleTable"]] = Relationship(
back_populates="captions",
Expand Down
2 changes: 1 addition & 1 deletion lightly_studio/src/lightly_studio/models/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ class DatasetView(DatasetBase):
class DatasetTable(DatasetBase, table=True):
"""This class defines the Dataset model."""

__tablename__ = "datasets"
__tablename__ = "dataset"
dataset_id: UUID = Field(default_factory=uuid4, primary_key=True)
created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc), index=True)
updated_at: datetime = Field(
Expand Down
4 changes: 2 additions & 2 deletions lightly_studio/src/lightly_studio/models/embedding_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ class EmbeddingModelBase(SQLModel):
parameter_count_in_mb: int | None = None
embedding_model_hash: str = Field(default="", sa_column=Column(CHAR(128)))
embedding_dimension: int
dataset_id: UUID = Field(default=None, foreign_key="datasets.dataset_id")
dataset_id: UUID = Field(default=None, foreign_key="dataset.dataset_id")


class EmbeddingModelCreate(EmbeddingModelBase):
Expand All @@ -25,6 +25,6 @@ class EmbeddingModelCreate(EmbeddingModelBase):
class EmbeddingModelTable(EmbeddingModelBase, table=True):
"""This class defines the EmbeddingModel model."""

__tablename__ = "embedding_models"
__tablename__ = "embedding_model"
embedding_model_id: UUID = Field(default_factory=uuid4, primary_key=True)
created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc), index=True)
2 changes: 1 addition & 1 deletion lightly_studio/src/lightly_studio/models/metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,7 +188,7 @@ class SampleMetadataTable(MetadataBase, table=True):
"""This class defines the SampleMetadataTable model."""

__tablename__ = "metadata"
sample_id: UUID = Field(foreign_key="samples.sample_id", unique=True)
sample_id: UUID = Field(foreign_key="sample.sample_id", unique=True)

sample: SampleTable = Relationship(back_populates="metadata_dict")

Expand Down
8 changes: 4 additions & 4 deletions lightly_studio/src/lightly_studio/models/sample.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ class SampleBase(SQLModel):
height: int

"""The dataset ID to which the sample belongs."""
dataset_id: UUID = Field(default=None, foreign_key="datasets.dataset_id")
dataset_id: UUID = Field(default=None, foreign_key="dataset.dataset_id")

"""The dataset image path."""
file_path_abs: str = Field(default=None, unique=True)
Expand Down Expand Up @@ -77,15 +77,15 @@ class SampleTagLinkTable(SQLModel, table=True):
"""Model to define links between Sample and Tag Many-to-Many."""

sample_id: Optional[UUID] = Field(
default=None, foreign_key="samples.sample_id", primary_key=True
default=None, foreign_key="sample.sample_id", primary_key=True
)
tag_id: Optional[UUID] = Field(default=None, foreign_key="tags.tag_id", primary_key=True)
tag_id: Optional[UUID] = Field(default=None, foreign_key="tag.tag_id", primary_key=True)


class SampleTable(SampleBase, table=True):
"""This class defines the Sample model."""

__tablename__ = "samples"
__tablename__ = "sample"
sample_id: UUID = Field(default_factory=uuid4, primary_key=True)
created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc), index=True)
updated_at: datetime = Field(
Expand Down
6 changes: 3 additions & 3 deletions lightly_studio/src/lightly_studio/models/sample_embedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,9 @@ class SampleEmbeddingBase(SQLModel):

sample_embedding_id: UUID = Field(default_factory=uuid4, primary_key=True)

sample_id: UUID = Field(foreign_key="samples.sample_id")
sample_id: UUID = Field(foreign_key="sample.sample_id")

embedding_model_id: UUID = Field(foreign_key="embedding_models.embedding_model_id")
embedding_model_id: UUID = Field(foreign_key="embedding_model.embedding_model_id")
embedding: list[float] = Field(sa_column=Column(ARRAY(Float)))


Expand All @@ -33,5 +33,5 @@ class SampleEmbeddingCreate(SampleEmbeddingBase):
class SampleEmbeddingTable(SampleEmbeddingBase, table=True):
"""This class defines the SampleEmbedding model."""

__tablename__ = "sample_embeddings"
__tablename__ = "sample_embedding"
sample: SampleTable = Relationship(back_populates="embeddings")
2 changes: 1 addition & 1 deletion lightly_studio/src/lightly_studio/models/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ class SettingView(SettingBase):
class SettingTable(SettingBase, table=True):
"""This class defines the Setting model."""

__tablename__ = "settings"
__tablename__ = "setting"
setting_id: UUID = Field(default_factory=uuid4, primary_key=True)
created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc), index=True)
updated_at: datetime = Field(
Expand Down
2 changes: 1 addition & 1 deletion lightly_studio/src/lightly_studio/models/tag.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ class TagView(TagBase):
class TagTable(TagBase, table=True):
"""This class defines the Tag model."""

__tablename__ = "tags"
__tablename__ = "tag"
# ensure there can only be one tag named "lightly_studio" per dataset
__table_args__ = (
UniqueConstraint("dataset_id", "kind", "name", name="unique_name_constraint"),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,14 +29,14 @@ def test_get(self) -> None:
expr = AND(a, b)
exprs = expr.get()
sql = str(exprs.compile(compile_kwargs={"literal_binds": True})).lower()
assert sql == "samples.height < 10 and samples.height > 20"
assert sql == "sample.height < 10 and sample.height > 20"

def test_get__single(self) -> None:
a = NumericalFieldExpression(field=SampleField.height, operator="<", value=10)
expr = AND(a)
exprs = expr.get()
sql = str(exprs.compile(compile_kwargs={"literal_binds": True})).lower()
assert sql == "samples.height < 10"
assert sql == "sample.height < 10"

def test_get__empty(self) -> None:
expr = AND()
Expand Down Expand Up @@ -67,14 +67,14 @@ def test_get(self) -> None:
expr = OR(a, b)
exprs = expr.get()
sql = str(exprs.compile(compile_kwargs={"literal_binds": True})).lower()
assert sql == "samples.height < 10 or samples.height > 20"
assert sql == "sample.height < 10 or sample.height > 20"

def test_get__single(self) -> None:
a = NumericalFieldExpression(field=SampleField.height, operator="<", value=10)
expr = OR(a)
exprs = expr.get()
sql = str(exprs.compile(compile_kwargs={"literal_binds": True})).lower()
assert sql == "samples.height < 10"
assert sql == "sample.height < 10"

def test_get__empty(self) -> None:
expr = OR()
Expand All @@ -94,4 +94,4 @@ def test_get(self) -> None:
expr = NOT(a)
exprs = expr.get()
sql = str(exprs.compile(compile_kwargs={"literal_binds": True})).lower()
assert sql == "samples.height >= 10"
assert sql == "sample.height >= 10"
10 changes: 5 additions & 5 deletions lightly_studio/tests/core/dataset_query/test_field_expression.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ def test_apply__less(self) -> None:
returned_query = query.where(expr.get())

sql = str(returned_query.compile(compile_kwargs={"literal_binds": True})).lower()
assert "where samples.height < 10" in sql
assert "where sample.height < 10" in sql

def test_apply__greater_equal(self) -> None:
query = select(SampleTable)
Expand All @@ -37,7 +37,7 @@ def test_apply__greater_equal(self) -> None:
returned_query = query.where(expr.get())

sql = str(returned_query.compile(compile_kwargs={"literal_binds": True})).lower()
assert "where samples.height >= 100" in sql
assert "where sample.height >= 100" in sql

@pytest.mark.parametrize(
("operator", "test_value", "expected_match"),
Expand Down Expand Up @@ -105,7 +105,7 @@ def test_apply__greater_than(self) -> None:
returned_query = query.where(expr.get())

sql = str(returned_query.compile(compile_kwargs={"literal_binds": True})).lower()
assert "where samples.created_at > '2023-01-01 12:00:00+00:00'" in sql
assert "where sample.created_at > '2023-01-01 12:00:00+00:00'" in sql

def test_apply__less_than_or_equal(self) -> None:
query = select(SampleTable)
Expand All @@ -118,7 +118,7 @@ def test_apply__less_than_or_equal(self) -> None:
returned_query = query.where(expr.get())

sql = str(returned_query.compile(compile_kwargs={"literal_binds": True})).lower()
assert "where samples.created_at <= '2024-06-15 10:30:00+00:00'" in sql
assert "where sample.created_at <= '2024-06-15 10:30:00+00:00'" in sql


class TestStringFieldExpression:
Expand All @@ -130,7 +130,7 @@ def test_apply__equal(self) -> None:
returned_query = query.where(expr.get())

sql = str(returned_query.compile(compile_kwargs={"literal_binds": True})).lower()
assert "where samples.file_name = 'test.jpg'" in sql
assert "where sample.file_name = 'test.jpg'" in sql

@pytest.mark.parametrize(
("operator", "test_value", "expected_match"),
Expand Down
6 changes: 3 additions & 3 deletions lightly_studio/tests/core/dataset_query/test_order_by.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ def test_apply__default_ascending(self) -> None:
returned_query = order_by.apply(query)

sql = str(returned_query.compile(compile_kwargs={"literal_binds": True})).lower()
assert "order by samples.file_name asc" in sql
assert "order by sample.file_name asc" in sql

def test_apply__descending(self) -> None:
"""Test descending ordering via desc() method."""
Expand All @@ -26,7 +26,7 @@ def test_apply__descending(self) -> None:
returned_query = order_by.apply(query)

sql = str(returned_query.compile(compile_kwargs={"literal_binds": True})).lower()
assert "order by samples.file_name desc" in sql
assert "order by sample.file_name desc" in sql

def test_apply__desc_then_asc(self) -> None:
"""Test that desc().asc() returns to ascending order."""
Expand All @@ -36,4 +36,4 @@ def test_apply__desc_then_asc(self) -> None:
returned_query = order_by.apply(query)

sql = str(returned_query.compile(compile_kwargs={"literal_binds": True})).lower()
assert "order by samples.file_name asc" in sql
assert "order by sample.file_name asc" in sql
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,8 @@ def test_apply__sql(self) -> None:

# The current approach makes a subquery for the tags relationship.
assert "EXISTS (SELECT 1" in sql
assert "FROM tags, sampletaglinktable" in sql
assert "tags.name = 'car'" in sql
assert "FROM tag, sampletaglinktable" in sql
assert "tag.name = 'car'" in sql

def test_apply__can_be_chained(self, test_db: Session) -> None:
"""Test that multiple TagsContainsExpression can be applied to a query."""
Expand Down