11import os
2+ import warnings
23from dataclasses import asdict , dataclass
34from enum import Enum
4- from typing import Callable , List , Optional , Union
5+ from functools import partial
6+ from typing import Any , Dict , List , Optional , Union
57
68import matplotlib
79import matplotlib .pyplot as plt
1820
1921import scib_metrics
2022
# Type aliases for metric configuration: each metric slot is either a bool
# (True = run with default settings, False = skip) or a dict of keyword
# arguments forwarded to the corresponding metric function.
Kwargs = Dict[str, Any]
MetricType = Union[bool, Kwargs]

# Internal key/column names shared across the module.
_LABELS = "labels"
_BATCH = "batch"
_X_PRE = "X_pre"
_METRIC_TYPE = "Metric Type"
_AGGREGATE_SCORE = "Aggregate score"
2631
32+ # Mapping of metric fn names to clean DataFrame column names
2733metric_name_cleaner = {
2834 "silhouette_label" : "Silhouette label" ,
2935 "silhouette_batch" : "Silhouette batch" ,
4046}
4147
4248
@dataclass(frozen=True)
class BioConservation:
    """Specification of bio conservation metrics to run in the pipeline.

    Metrics can be included using a boolean flag. Custom keyword args can be
    used by passing a dictionary here. Keyword args should not set data-related
    parameters, such as `X` or `labels`.
    """

    # Field names must match the corresponding function names in the
    # ``scib_metrics`` package — they are resolved with ``getattr`` at run
    # time. A dict value is applied to the metric via ``functools.partial``.
    isolated_labels: MetricType = True
    nmi_ari_cluster_labels_leiden: MetricType = True
    # Disabled by default; enable explicitly if k-means-based clustering
    # metrics are desired.
    nmi_ari_cluster_labels_kmeans: MetricType = False
    silhouette_label: MetricType = True
    clisi_knn: MetricType = True
5663
5764
@dataclass(frozen=True)
class BatchCorrection:
    """Specification of which batch correction metrics to run in the pipeline.

    Metrics can be included using a boolean flag. Custom keyword args can be
    used by passing a dictionary here. Keyword args should not set data-related
    parameters, such as `X` or `labels`.
    """

    # Field names must match the corresponding function names in the
    # ``scib_metrics`` package — they are resolved with ``getattr`` at run
    # time. A dict value is applied to the metric via ``functools.partial``.
    silhouette_batch: MetricType = True
    ilisi_knn: MetricType = True
    kbet_per_label: MetricType = True
    graph_connectivity: MetricType = True
    pcr_comparison: MetricType = True
7179
7280
7381class MetricAnnDataAPI (Enum ):
@@ -138,6 +146,7 @@ def __init__(
138146 self ._emb_adatas = {}
139147 self ._neighbor_values = (15 , 50 , 90 )
140148 self ._prepared = False
149+ self ._benchmarked = False
141150 self ._batch_key = batch_key
142151 self ._label_key = label_key
143152 self ._n_jobs = n_jobs
@@ -183,6 +192,12 @@ def prepare(self) -> None:
183192
184193 def benchmark (self ) -> None :
185194 """Run the pipeline."""
195+ if self ._benchmarked :
196+ warnings .warn (
197+ "The benchmark has already been run. Running it again will overwrite the previous results." ,
198+ UserWarning ,
199+ )
200+
186201 if not self ._prepared :
187202 self .prepare ()
188203
@@ -193,13 +208,12 @@ def benchmark(self) -> None:
193208 for emb_key , ad in tqdm (self ._emb_adatas .items (), desc = "Embeddings" , position = 0 , colour = "green" ):
194209 pbar = tqdm (total = num_metrics , desc = "Metrics" , position = 1 , leave = False , colour = "blue" )
195210 for metric_type , metric_collection in self ._metric_collection_dict .items ():
196- for metric_name , use_metric in asdict (metric_collection ).items ():
197- if use_metric :
198- if isinstance (metric_name , str ):
199- metric_fn = getattr (scib_metrics , metric_name )
200- else :
201- # Callable in this case
202- metric_fn = use_metric
211+ for metric_name , use_metric_or_kwargs in asdict (metric_collection ).items ():
212+ if use_metric_or_kwargs :
213+ metric_fn = getattr (scib_metrics , metric_name )
214+ if isinstance (use_metric_or_kwargs , dict ):
215+ # Kwargs in this case
216+ metric_fn = partial (metric_fn , ** use_metric_or_kwargs )
203217 metric_value = getattr (MetricAnnDataAPI , metric_name )(ad , metric_fn )
204218 # nmi/ari metrics return a dict
205219 if isinstance (metric_value , dict ):
@@ -211,6 +225,8 @@ def benchmark(self) -> None:
211225 self ._results .loc [metric_name , _METRIC_TYPE ] = metric_type
212226 pbar .update (1 )
213227
228+ self ._benchmarked = True
229+
214230 def get_results (self , min_max_scale : bool = True , clean_names : bool = True ) -> pd .DataFrame :
215231 """Return the benchmarking results.
216232
@@ -242,6 +258,7 @@ def get_results(self, min_max_scale: bool = True, clean_names: bool = True) -> p
242258
243259 # Compute scores
244260 per_class_score = df .groupby (_METRIC_TYPE ).mean ().transpose ()
261+ # This is the default scIB weighting from the manuscript
245262 per_class_score ["Total" ] = 0.4 * per_class_score ["Batch correction" ] + 0.6 * per_class_score ["Bio conservation" ]
246263 df = pd .concat ([df .transpose (), per_class_score ], axis = 1 )
247264 df .loc [_METRIC_TYPE , per_class_score .columns ] = _AGGREGATE_SCORE
0 commit comments