Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/tracksdata/array/_graph_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ def _validate_shape(
"""Helper function to validate the shape argument."""
if shape is None:
try:
shape = graph.metadata()["shape"]
shape = graph.metadata["shape"]
except KeyError as e:
raise KeyError(
f"`shape` is required to `{func_name}`. "
Expand Down
2 changes: 1 addition & 1 deletion src/tracksdata/functional/_test/test_napari.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ def test_napari_conversion(metadata_shape: bool) -> None:

shape = (2, 10, 22, 32)
if metadata_shape:
graph.update_metadata(shape=shape)
graph.metadata.update(shape=shape)
arg_shape = None
else:
arg_shape = shape
Expand Down
4 changes: 2 additions & 2 deletions src/tracksdata/graph/__init__.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
"""Graph backends for representing tracking data as directed graphs in memory or on disk."""

from tracksdata.graph._base_graph import BaseGraph
from tracksdata.graph._base_graph import BaseGraph, MetadataView
from tracksdata.graph._graph_view import GraphView
from tracksdata.graph._rustworkx_graph import IndexedRXGraph, RustWorkXGraph
from tracksdata.graph._sql_graph import SQLGraph

InMemoryGraph = RustWorkXGraph

__all__ = ["BaseGraph", "GraphView", "InMemoryGraph", "IndexedRXGraph", "RustWorkXGraph", "SQLGraph"]
__all__ = ["BaseGraph", "GraphView", "InMemoryGraph", "IndexedRXGraph", "MetadataView", "RustWorkXGraph", "SQLGraph"]
159 changes: 128 additions & 31 deletions src/tracksdata/graph/_base_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,11 +42,75 @@
T = TypeVar("T", bound="BaseGraph")


class MetadataView(dict[str, Any]):
"""Dictionary-like metadata view that syncs mutations back to the graph."""

_MISSING = object()

def __init__(
self,
graph: "BaseGraph",
data: dict[str, Any],
*,
is_public: bool = True,
) -> None:
super().__init__(data)
self._graph = graph
self._is_public = is_public

def __setitem__(self, key: str, value: Any) -> None:
self._graph._set_metadata_with_validation(is_public=self._is_public, **{key: value})
super().__setitem__(key, value)

def __delitem__(self, key: str) -> None:
self._graph._remove_metadata_with_validation(key, is_public=self._is_public)
super().__delitem__(key)

def pop(self, key: str, default: Any = _MISSING) -> Any:
self._graph._validate_metadata_key(key, is_public=self._is_public)

if key not in self:
if default is self._MISSING:
raise KeyError(key)
return default

value = super().__getitem__(key)
self._graph._remove_metadata_with_validation(key, is_public=self._is_public)
super().pop(key, None)
return value

def popitem(self) -> tuple[str, Any]:
key, value = super().popitem()
self._graph._remove_metadata_with_validation(key, is_public=self._is_public)
return key, value

def clear(self) -> None:
keys = list(self.keys())
for key in keys:
self._graph._remove_metadata_with_validation(key, is_public=self._is_public)
super().clear()

def setdefault(self, key: str, default: Any = None) -> Any:
if key in self:
return super().__getitem__(key)
self._graph._set_metadata_with_validation(is_public=self._is_public, **{key: default})
super().__setitem__(key, default)
return default

def update(self, *args, **kwargs) -> None:
updates = dict(*args, **kwargs)
if updates:
self._graph._set_metadata_with_validation(is_public=self._is_public, **updates)
super().update(updates)


class BaseGraph(abc.ABC):
"""
Base class for a graph backend.
"""

_PRIVATE_METADATA_PREFIX = "__private_"

node_added = Signal(int)
node_removed = Signal(int)

Expand Down Expand Up @@ -1186,7 +1250,8 @@ def from_other(cls: type[T], other: "BaseGraph", **kwargs) -> T:
node_attrs = node_attrs.drop(DEFAULT_ATTR_KEYS.NODE_ID)

graph = cls(**kwargs)
graph.update_metadata(**other.metadata())
graph.metadata.update(other.metadata)
graph._private_metadata.update(other._private_metadata_for_copy())

current_node_attr_schemas = graph._node_attr_schemas()
for k, v in other._node_attr_schemas().items():
Expand Down Expand Up @@ -1786,7 +1851,8 @@ def to_geff(
for k, v in edge_attrs.to_dict().items()
}

td_metadata = self.metadata().copy()
td_metadata = self.metadata.copy()
td_metadata.update(self._private_metadata_for_copy())
td_metadata.pop("geff", None) # avoid geff being written multiple times
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Private metadata is lost during conversion to geff. Is this correct?
Wouldn't it make sense to keep it?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh yeah, thanks for notifying this. It should survive.

Copy link
Contributor Author

@yfukai yfukai Feb 20, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done! I used _private_metadata_for_copy so that the backends can decide which metadata to send, since SQLGraph will use the private metadata to persistently store the AttrSchema, which should not be used when creating the graph by from_other etc., and thus should be excluded to avoid confusion. I'd be happy to hear your opinion.


geff_metadata = geff.GeffMetadata(
Expand Down Expand Up @@ -1824,57 +1890,88 @@ def to_geff(
zarr_format=zarr_format,
)

@abc.abstractmethod
def metadata(self) -> dict[str, Any]:
@property
def metadata(self) -> MetadataView:
"""
Return the metadata of the graph.

Returns
-------
dict[str, Any]
MetadataView
The metadata of the graph as a dictionary.

Examples
--------
```python
metadata = graph.metadata()
metadata = graph.metadata
print(metadata["shape"])
```
"""
return MetadataView(
graph=self,
data={k: v for k, v in self._metadata().items() if not self._is_private_metadata_key(k)},
is_public=True,
)

@abc.abstractmethod
def update_metadata(self, **kwargs) -> None:
@property
def _private_metadata(self) -> MetadataView:
return MetadataView(
graph=self,
data={k: v for k, v in self._metadata().items() if self._is_private_metadata_key(k)},
is_public=False,
)

def _private_metadata_for_copy(self) -> dict[str, Any]:
"""
Set or update metadata for the graph.
Return private metadata entries that should be propagated by `from_other` or `to_geff`.
Backends can override this to exclude backend-specific private metadata.
"""
return dict(self._private_metadata)

Parameters
----------
**kwargs : Any
The metadata items to set by key. Values will be stored as JSON.
@classmethod
def _is_private_metadata_key(cls, key: str) -> bool:
return key.startswith(cls._PRIVATE_METADATA_PREFIX)

def _validate_metadata_key(self, key: str, *, is_public: bool) -> None:
if not isinstance(key, str):
raise TypeError(f"Metadata key must be a string. Got {type(key)}.")
is_private_key = self._is_private_metadata_key(key)
if is_public and is_private_key:
raise ValueError(f"Metadata key '{key}' is reserved for internal use.")
if not is_public and not is_private_key:
raise ValueError(
f"Metadata key '{key}' is not private. Private metadata keys must start with "
f"'{self._PRIVATE_METADATA_PREFIX}'."
)

Examples
--------
```python
graph.update_metadata(shape=[1, 25, 25], path="path/to/image.ome.zarr")
graph.update_metadata(description="Tracking data from experiment 1")
```
"""
def _validate_metadata_keys(self, keys: Sequence[str], *, is_public: bool) -> None:
for key in keys:
self._validate_metadata_key(key, is_public=is_public)

def _set_metadata_with_validation(self, is_public: bool = True, **kwargs) -> None:
self._validate_metadata_keys(kwargs.keys(), is_public=is_public)
self._update_metadata(**kwargs)

def _remove_metadata_with_validation(self, key: str, *, is_public: bool = True) -> None:
self._validate_metadata_key(key, is_public=is_public)
self._remove_metadata(key)

@abc.abstractmethod
def remove_metadata(self, key: str) -> None:
def _metadata(self) -> dict[str, Any]:
"""
Return the full metadata including private keys.
"""
Remove a metadata key from the graph.

Parameters
----------
key : str
The key of the metadata to remove.
@abc.abstractmethod
def _update_metadata(self, **kwargs) -> None:
"""
Backend-specific metadata update implementation without public key validation.
"""

Examples
--------
```python
graph.remove_metadata("shape")
```
@abc.abstractmethod
def _remove_metadata(self, key: str) -> None:
"""
Backend-specific metadata removal implementation without public key validation.
"""

def to_traccuracy_graph(self, array_view_kwargs: dict[str, Any] | None = None) -> "TrackingGraph":
Expand Down
12 changes: 6 additions & 6 deletions src/tracksdata/graph/_graph_view.py
Original file line number Diff line number Diff line change
Expand Up @@ -847,11 +847,11 @@ def copy(self, **kwargs) -> "GraphView":
"Use `detach` to create a new reference-less graph with the same nodes and edges."
)

def metadata(self) -> dict[str, Any]:
return self._root.metadata()
def _metadata(self) -> dict[str, Any]:
return self._root._metadata()

def update_metadata(self, **kwargs) -> None:
self._root.update_metadata(**kwargs)
def _update_metadata(self, **kwargs) -> None:
self._root._update_metadata(**kwargs)

def remove_metadata(self, key: str) -> None:
self._root.remove_metadata(key)
def _remove_metadata(self, key: str) -> None:
self._root._remove_metadata(key)
8 changes: 4 additions & 4 deletions src/tracksdata/graph/_rustworkx_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -371,7 +371,7 @@ def __init__(self, rx_graph: rx.PyDiGraph | None = None) -> None:

elif not isinstance(self._graph.attrs, dict):
LOG.warning(
"previous attribute %s will be added to key 'old_attrs' of `graph.metadata()`",
"previous attribute %s will be added to key 'old_attrs' of `graph.metadata`",
self._graph.attrs,
)
self._graph.attrs = {
Expand Down Expand Up @@ -1499,13 +1499,13 @@ def edge_id(self, source_id: int, target_id: int) -> int:
"""
return self.rx_graph.get_edge_data(source_id, target_id)[DEFAULT_ATTR_KEYS.EDGE_ID]

def metadata(self) -> dict[str, Any]:
def _metadata(self) -> dict[str, Any]:
return self._graph.attrs

def update_metadata(self, **kwargs) -> None:
def _update_metadata(self, **kwargs) -> None:
self._graph.attrs.update(kwargs)

def remove_metadata(self, key: str) -> None:
def _remove_metadata(self, key: str) -> None:
self._graph.attrs.pop(key, None)

def edge_list(self) -> list[list[int, int]]:
Expand Down
6 changes: 3 additions & 3 deletions src/tracksdata/graph/_sql_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -1992,19 +1992,19 @@ def remove_edge(
raise ValueError(f"Edge {edge_id} does not exist in the graph.")
session.commit()

def metadata(self) -> dict[str, Any]:
def _metadata(self) -> dict[str, Any]:
with Session(self._engine) as session:
result = session.query(self.Metadata).all()
return {row.key: row.value for row in result}

def update_metadata(self, **kwargs) -> None:
def _update_metadata(self, **kwargs) -> None:
with Session(self._engine) as session:
for key, value in kwargs.items():
metadata_entry = self.Metadata(key=key, value=value)
session.merge(metadata_entry)
session.commit()

def remove_metadata(self, key: str) -> None:
def _remove_metadata(self, key: str) -> None:
with Session(self._engine) as session:
session.query(self.Metadata).filter(self.Metadata.key == key).delete()
session.commit()
Expand Down
Loading
Loading