33from dataclasses import asdict , dataclass
44from enum import Enum
55from functools import partial
6- from typing import Any , Dict , List , Optional , Union
6+ from typing import Any , Callable , Dict , List , Optional , Union
77
88import matplotlib
99import matplotlib .pyplot as plt
1414from plottable import ColumnDefinition , Table
1515from plottable .cmap import normed_cmap
1616from plottable .plots import bar
17- from pynndescent import NNDescent
1817from sklearn .preprocessing import MinMaxScaler
1918from tqdm import tqdm
2019
2120import scib_metrics
21+ from scib_metrics .nearest_neighbors import NeighborsOutput , pynndescent
2222
2323Kwargs = Dict [str , Any ]
2424MetricType = Union [bool , Kwargs ]
@@ -156,8 +156,17 @@ def __init__(
156156 "Batch correction" : self ._batch_correction_metrics ,
157157 }
158158
159- def prepare (self ) -> None :
160- """Prepare the data for benchmarking."""
159+ def prepare (self , neighbor_computer : Optional [Callable [[np .ndarray , int ], NeighborsOutput ]] = None ) -> None :
160+ """Prepare the data for benchmarking.
161+
162+ Parameters
163+ ----------
164+ neighbor_computer
165+ Function that computes the neighbors of the data. If `None`, the neighbors will be computed
166+ with :func:`~scib_metrics.utils.nearest_neighbors.pynndescent`. The function should take as input
167+ the data and the number of neighbors to compute and return a :class:`~scib_metrics.utils.nearest_neighbors.NeighborsOutput`
168+ object.
169+ """
161170 # Compute PCA
162171 if self ._pre_integrated_embedding_obsm_key is None :
163172 # This is how scib does it
@@ -173,24 +182,13 @@ def prepare(self) -> None:
173182
174183 # Compute neighbors
175184 for ad in tqdm (self ._emb_adatas .values (), desc = "Computing neighbors" ):
176- # Variables from umap (https://github.com/lmcinnes/umap/blob/3f19ce19584de4cf99e3d0ae779ba13a57472cd9/umap/umap_.py#LL326-L327)
177- # which is used by scanpy under the hood
178- n_trees = min (64 , 5 + int (round ((ad .X .shape [0 ]) ** 0.5 / 20.0 )))
179- n_iters = max (5 , int (round (np .log2 (ad .X .shape [0 ]))))
180- max_candidates = 60
181-
182- knn_search_index = NNDescent (
183- ad .X ,
184- n_neighbors = max (self ._neighbor_values ),
185- random_state = 0 ,
186- low_memory = True ,
187- n_jobs = self ._n_jobs ,
188- compressed = False ,
189- n_trees = n_trees ,
190- n_iters = n_iters ,
191- max_candidates = max_candidates ,
192- )
193- indices , distances = knn_search_index .neighbor_graph
185+ if neighbor_computer is not None :
186+ neigh_output = neighbor_computer (ad .X , max (self ._neighbor_values ))
187+ else :
188+ neigh_output = pynndescent (
189+ ad .X , n_neighbors = max (self ._neighbor_values ), random_state = 0 , n_jobs = self ._n_jobs
190+ )
191+ indices , distances = neigh_output .indices , neigh_output .distances
194192 for n in self ._neighbor_values :
195193 sp_distances , sp_conns = sc .neighbors ._compute_connectivities_umap (
196194 indices [:, :n ], distances [:, :n ], ad .n_obs , n_neighbors = n
0 commit comments