Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
29 commits
Select commit Hold shift + click to select a range
ad74ef3
init
selmanozleyen Jun 16, 2025
866ddc0
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 18, 2025
0965a61
Merge branch 'main' into feature/filter_cells_new
selmanozleyen Jun 18, 2025
d53e20c
add tests and fix some bugs
selmanozleyen Jun 30, 2025
5fe17bd
fix the tests
selmanozleyen Jun 30, 2025
9472b72
relax timeout
selmanozleyen Jun 30, 2025
c9e48f9
relax time contraint even more
selmanozleyen Jun 30, 2025
a9587a2
docstrings
selmanozleyen Jul 9, 2025
87e5c8f
update the filter_cells based on the new spatialdata implementation
selmanozleyen Jul 9, 2025
719b696
correct the docstring
selmanozleyen Jul 9, 2025
484d849
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jul 9, 2025
94cceaf
push the current state
selmanozleyen Aug 20, 2025
85612d4
update filter_cells function
selmanozleyen Aug 21, 2025
267e297
Re-implementating co_occurrence() (#975)
wenjie1991 Aug 6, 2025
385b70c
chore: fix messed-up yaml formatting (#1024)
flying-sheep Aug 14, 2025
b33fa5e
Update CI (#1025)
flying-sheep Aug 15, 2025
75750a2
Moving to uv + hatch (#1029)
selmanozleyen Sep 4, 2025
d796105
bump version (#1031)
selmanozleyen Sep 4, 2025
77a97db
Replacing fixed size tuples as return types (#1043)
selmanozleyen Oct 1, 2025
2f65047
zarr v3 support (#1040)
LucaMarconato Oct 7, 2025
f8f358f
Fix: sepal numba compilation options changes the results significantl…
selmanozleyen Oct 16, 2025
0ff2ccb
Run notebooks in CI (#1013)
selmanozleyen Oct 21, 2025
7458bbb
Change niche flavor to cellcharter_simple and default distance = 3 (#…
marcovarrone Oct 26, 2025
8cf8001
Method to detect specimens in H&E images (#1044)
timtreis Oct 27, 2025
51c410c
[pre-commit.ci] pre-commit autoupdate (#997)
pre-commit-ci[bot] Oct 27, 2025
4b4b9c1
[pre-commit.ci] pre-commit autoupdate (#1050)
pre-commit-ci[bot] Oct 29, 2025
3add1f9
init
selmanozleyen Jun 16, 2025
3e1d89e
Merge branch 'main' into feature/filter_cells_new
selmanozleyen Nov 3, 2025
5ccbf33
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 3, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions docs/api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,18 @@ Plotting
pl.extract
pl.var_by_distance

Preprocessing
~~~~~~~~~~~~~

.. module:: squidpy.pp
.. currentmodule:: squidpy

.. autosummary::
:toctree: api

pp.filter_cells


Reading
~~~~~~~

Expand Down
2 changes: 1 addition & 1 deletion docs/notebooks
Submodule notebooks updated 1 files
+1 −1 examples/index.rst
4 changes: 2 additions & 2 deletions src/squidpy/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from importlib import metadata
from importlib.metadata import PackageMetadata

from squidpy import datasets, experimental, gr, im, pl, read, tl
from squidpy import datasets, experimental, gr, im, pl, pp, read, tl

try:
md: PackageMetadata = metadata.metadata(__name__)
Expand All @@ -15,4 +15,4 @@

del metadata, md

__all__ = ["datasets", "experimental", "gr", "im", "pl", "read", "tl"]
__all__ = ["datasets", "experimental", "gr", "im", "pl", "pp", "read", "tl"]
5 changes: 5 additions & 0 deletions src/squidpy/pp/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
"""Basic pre-processing functions adapted from scanpy."""

from __future__ import annotations

from squidpy.pp._simple import filter_cells
197 changes: 197 additions & 0 deletions src/squidpy/pp/_simple.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,197 @@
from __future__ import annotations

import geopandas as gpd
import numpy as np
import scanpy as sc
import spatialdata as sd
from dask.dataframe import DataFrame as DaskDataFrame
from spatialdata import SpatialData, subset_sdata_by_table_mask
from spatialdata._logging import logger as logg
from spatialdata.models import (
get_table_keys,
points_dask_dataframe_to_geopandas,
points_geopandas_to_dask_dataframe,
)


def filter_cells(
data: ad.AnnData | sd.SpatialData,
tables: list[str] | str | None = None,
min_counts: int | None = None,
min_genes: int | None = None,
max_counts: int | None = None,
max_genes: int | None = None,
inplace: bool = True,
filter_labels: bool = True,
) -> sd.SpatialData | None:
"""\
Squidpy's implementation of :func:`scanpy.pp.filter_cells` for :class:`anndata.AnnData` and :class:`spatialdata.SpatialData` objects.
For :class:`spatialdata.SpatialData` objects, this function filters the following elements:


- labels: filtered based on the values of the images which are assumed to be the instance_id.
- points: filtered based on the index which is assumed to be the instance_id.
- shapes: filtered based on the instance_id column.


See :func:`scanpy.pp.filter_cells` for more details regarding the filtering
behavior.

Parameters
----------
data
:class:`spatialdata.SpatialData` object.
tables
If :class:`spatialdata.SpatialData` object, the tables to filter. If `None`, all tables are filtered.
min_counts
Minimum number of counts required for a cell to pass filtering.
min_genes
Minimum number of genes expressed required for a cell to pass filtering.
max_counts
Maximum number of counts required for a cell to pass filtering.
max_genes
Maximum number of genes expressed required for a cell to pass filtering.
inplace
Perform computation inplace or return result.
filter_labels
Whether to filter labels. If `True`, then labels are filtered based on the instance_id column.

Returns
-------
If `inplace` then returns `None`, otherwise returns the filtered :class:`spatialdata.SpatialData` object.
"""
if not isinstance(data, sd.SpatialData):
raise ValueError(
f"Expected `SpatialData`, found `{type(data)}` instead. Perhaps you want to use `scanpy.pp.filter_cells` instead."
)

return _filter_cells_spatialdata(data, tables, min_counts, min_genes, max_counts, max_genes, inplace, filter_labels)


def _get_only_annotated_shape(sdata: sd.SpatialData, table_name: str) -> str | None:
table = sdata.tables[table_name]

# only one shape needs to be annotated to filter points within it
# other annotations can't be points

regions, _, _ = get_table_keys(table)
if len(regions) == 0:
return None

if isinstance(regions, str):
regions = [regions]

res = None
for r in regions:
if r in sdata.points:
return None
if r in sdata.shapes:
if res is not None:
return None
res = r

return res


def _annotated_points_by_shape_membership(
sdata: SpatialData,
point_key: str,
shape_key: str,
) -> DaskDataFrame:
"""Annotate points by shape membership.

Parameters
----------
sdata
The SpatialData object to annotate.
point_key
The key of the points to annotate.
shape_key
The key of the shapes to annotate.

Returns
-------
The annotated points.
"""
points = sdata.points[point_key]
shapes = sdata.shapes[shape_key]
points_gdf = points_dask_dataframe_to_geopandas(points)
res = points_gdf.sjoin(shapes, how="left", predicate="within")
return points_geopandas_to_dask_dataframe(res)


def _filter_cells_spatialdata(
data: sd.SpatialData,
tables: list[str] | str | None = None,
min_counts: int | None = None,
min_genes: int | None = None,
max_counts: int | None = None,
max_genes: int | None = None,
inplace: bool = True,
filter_labels: bool = True,
) -> sd.SpatialData | None:
if isinstance(tables, str):
tables = [tables]
elif tables is None:
tables = list(data.tables.keys())

if len(tables) == 0:
raise ValueError("Expected at least one table to be filtered, found `0`")

if not all(t in data.tables for t in tables):
raise ValueError(f"Expected all tables to be in `{data.tables.keys()}`.")

for t in tables:
if "spatialdata_attrs" not in data.tables[t].uns:
raise ValueError(f"Table `{t}` does not have 'spatialdata_attrs' to indicate what it annotates.")

if not inplace:
logg.warning(
"Creating a deepcopy of the SpatialData object, depending on the size of the object this can take a while."
)
data_out = sd.deepcopy(data)
else:
data_out = data

for t in tables:
table_old = data_out.tables[t]
mask_filtered, _ = sc.pp.filter_cells(
table_old,
min_counts=min_counts,
min_genes=min_genes,
max_counts=max_counts,
max_genes=max_genes,
inplace=False,
)
if mask_filtered.sum() == 0:
raise ValueError(f"Filter results in empty table when filtering table `{t}`.")
sdata_filtered = subset_sdata_by_table_mask(sdata=data_out, table_name=t, mask=mask_filtered)
data_out.tables[t] = sdata_filtered.tables[t]
for k in list(sdata_filtered.points.keys()):
data_out.points[k] = sdata_filtered.points[k]
for k in list(sdata_filtered.shapes.keys()):
data_out.shapes[k] = sdata_filtered.shapes[k]
if filter_labels:
for k in list(sdata_filtered.labels.keys()):
data_out.labels[k] = sdata_filtered.labels[k]
shape_name = _get_only_annotated_shape(data_out, t)
if shape_name is not None:
for p in data_out.points:
_, _, instance_key = get_table_keys(table_old)
shape_index_name = data_out.shapes[shape_name].index.name
new_points = _annotated_points_by_shape_membership(
sdata=data_out,
shape_key=shape_name,
point_key=p,
)
shape_index_name += "_right"
removed_instance_ids = list(np.unique(table_old.obs[instance_key][~mask_filtered]))
# drop points that are not in any shape
new_points = new_points.dropna()
# drop points that are in the removed_instance_ids
new_points = new_points[~new_points[shape_index_name].isin(removed_instance_ids)]
data_out.points[p] = new_points

if inplace:
return None
return data_out
60 changes: 60 additions & 0 deletions tests/preprocessing/test_simple.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
from __future__ import annotations

import anndata as ad
import numpy as np
import pytest
import scanpy as sc
from spatialdata.datasets import blobs_annotating_element

import squidpy as sq


def _make_sdata(name: str, num_counts: int, count_value: int):
assert num_counts <= 5, "num_counts must be less than 5"
sdata_temp = blobs_annotating_element(name)
m, _ = sdata_temp.tables["table"].shape
n = m
X = np.zeros((m, n))
# random choice of row
row_indices = np.random.choice(m, num_counts, replace=False)
col_indices = np.random.choice(n, num_counts, replace=False)
X[row_indices, col_indices] = count_value

sdata_temp.tables["table"] = ad.AnnData(
X=X,
obs=sdata_temp.tables["table"].obs,
var={"gene": ["gene" for _ in range(n)]},
uns=sdata_temp.tables["table"].uns,
)
return sdata_temp


@pytest.mark.parametrize("name", ["blobs_labels", "blobs_circles", "blobs_points", "blobs_multiscale_labels"])
def test_filter_cells(name: str):
filtered_cells = 3
sdata = _make_sdata(name, num_counts=filtered_cells, count_value=100)
num_cells = sdata.tables["table"].shape[0]
adata_copy = sdata.tables["table"].copy()
sc.pp.filter_cells(adata_copy, max_counts=50, inplace=True)
sq.pp.filter_cells(sdata, max_counts=50, inplace=True, filter_labels=True)

assert np.all(sdata.tables["table"].X == adata_copy.X), "Filtered cells are not the same as scanpy"
assert np.all(sdata.tables["table"].obs["instance_id"] == adata_copy.obs["instance_id"]), (
"Filtered cells are not the same as scanpy"
)
assert sdata.tables["table"].shape[0] == (num_cells - filtered_cells), (
f"Expected {num_cells - filtered_cells} cells, got {sdata.tables['table'].shape[0]}"
)

if name == "blobs_labels":
unique_labels = np.unique(adata_copy.obs["instance_id"])
unique_labels_sdata = np.unique(sdata.labels["blobs_labels"].data.compute())
assert set(unique_labels) == set(unique_labels_sdata).difference([0]), (
f"Filtered labels {unique_labels} are not the same as scanpy {unique_labels_sdata}"
)


def test_filter_cells_empty_fail():
sdata = _make_sdata("blobs_labels", num_counts=5, count_value=200)
with pytest.raises(ValueError, match="Filter results in empty table when filtering table `table`."):
sq.pp.filter_cells(sdata, max_counts=100, inplace=True)
2 changes: 1 addition & 1 deletion tests/utils/test_parallelize.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ def func(request) -> Callable:
# in case of failure.


@pytest.mark.timeout(30)
@pytest.mark.timeout(50)
@pytest.mark.parametrize(
"backend",
[
Expand Down
Loading