Skip to content

Commit 84d3a77

Browse files
authored
fix kmeans++ initialization, rename class to Kmeans (#81)
* fix kmeans++ * changelog and bump version * finalize and update tutorial * changelog * update large scale tutorial
1 parent 4a04662 commit 84d3a77

File tree

10 files changed

+393
-186
lines changed

10 files changed

+393
-186
lines changed

CHANGELOG.md

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,10 +8,12 @@ and this project adheres to [Semantic Versioning][].
88
[keep a changelog]: https://keepachangelog.com/en/1.0.0/
99
[semantic versioning]: https://semver.org/spec/v2.0.0.html
1010

11-
## 0.2.1 (2022-02-16)
11+
## 0.3.0 (2022-02-16)
1212

13+
- Rename `KMeansJax` to `KMeans` and fix ++ initialization, use KMeans as default in benchmarker instead of Leiden ([#81][])
1314
- Warn about joblib, add progress bar postfix str ([#80][])
1415

16+
[#81]: https://github.com/YosefLab/scib-metrics/pull/81
1517
[#80]: https://github.com/YosefLab/scib-metrics/pull/80
1618

1719
## 0.2.0 (2022-02-02)

docs/api.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@ scib_metrics.ilisi_knn(...)
6161
utils.cdist
6262
utils.pdist_squareform
6363
utils.silhouette_samples
64-
utils.KMeansJax
64+
utils.KMeans
6565
utils.pca
6666
utils.principal_component_regression
6767
utils.one_hot

docs/notebooks/large_scale.ipynb

Lines changed: 57 additions & 59 deletions
Large diffs are not rendered by default.

docs/notebooks/lung_example.ipynb

Lines changed: 287 additions & 79 deletions
Large diffs are not rendered by default.

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ requires = ["hatchling"]
55

66
[project]
77
name = "scib-metrics"
8-
version = "0.2.1"
8+
version = "0.3.0"
99
description = "Accelerated and Python-only scIB metrics"
1010
readme = "README.md"
1111
requires-python = ">=3.8"

src/scib_metrics/_nmi_ari.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,13 +8,13 @@
88
from sklearn.metrics.cluster import adjusted_rand_score, normalized_mutual_info_score
99
from sklearn.utils import check_array
1010

11-
from .utils import KMeansJax, check_square
11+
from .utils import KMeans, check_square
1212

1313
logger = logging.getLogger(__name__)
1414

1515

1616
def _compute_clustering_kmeans(X: np.ndarray, n_clusters: int) -> np.ndarray:
17-
kmeans = KMeansJax(n_clusters)
17+
kmeans = KMeans(n_clusters)
1818
kmeans.fit(X)
1919
return kmeans.labels_
2020

src/scib_metrics/benchmark/_core.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -56,8 +56,8 @@ class BioConservation:
5656
"""
5757

5858
isolated_labels: MetricType = True
59-
nmi_ari_cluster_labels_leiden: MetricType = True
60-
nmi_ari_cluster_labels_kmeans: MetricType = False
59+
nmi_ari_cluster_labels_leiden: MetricType = False
60+
nmi_ari_cluster_labels_kmeans: MetricType = True
6161
silhouette_label: MetricType = True
6262
clisi_knn: MetricType = True
6363

src/scib_metrics/utils/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from ._diffusion_nn import diffusion_nn
22
from ._dist import cdist, pdist_squareform
3-
from ._kmeans import KMeansJax
3+
from ._kmeans import KMeans
44
from ._lisi import compute_simpson_index
55
from ._pca import pca
66
from ._pcr import principal_component_regression
@@ -12,7 +12,7 @@
1212
"cdist",
1313
"pdist_squareform",
1414
"get_ndarray",
15-
"KMeansJax",
15+
"KMeans",
1616
"pca",
1717
"principal_component_regression",
1818
"one_hot",

src/scib_metrics/utils/_kmeans.py

Lines changed: 37 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -7,11 +7,11 @@
77
from sklearn.utils import check_array
88

99
from .._types import IntOrKey
10-
from ._dist import cdist, pdist_squareform
10+
from ._dist import cdist
1111
from ._utils import get_ndarray, validate_seed
1212

1313

14-
def _initialize_random(X: jnp.ndarray, n_clusters: int, pdists: jnp.ndarray, key: jax.random.KeyArray) -> jnp.ndarray:
14+
def _initialize_random(X: jnp.ndarray, n_clusters: int, key: jax.random.KeyArray) -> jnp.ndarray:
1515
"""Initialize cluster centroids randomly."""
1616
n_obs = X.shape[0]
1717
indices = jax.random.choice(key, n_obs, (n_clusters,), replace=False)
@@ -20,38 +20,39 @@ def _initialize_random(X: jnp.ndarray, n_clusters: int, pdists: jnp.ndarray, key
2020

2121

2222
@partial(jax.jit, static_argnums=1)
23-
def _initialize_plus_plus(
24-
X: jnp.ndarray, n_clusters: int, pdists: jnp.ndarray, key: jax.random.KeyArray
25-
) -> jnp.ndarray:
23+
def _initialize_plus_plus(X: jnp.ndarray, n_clusters: int, key: jax.random.KeyArray) -> jnp.ndarray:
2624
"""Initialize cluster centroids with k-means++ algorithm."""
27-
28-
def _init(key, pdists):
29-
key, subkey = jax.random.split(key)
30-
n_obs = pdists.shape[0]
31-
# sample first centroid uniformly at random
32-
idx = jax.random.choice(subkey, n_obs)
33-
centroids = jnp.full((n_clusters,), -1, dtype=jnp.int32).at[0].set(idx)
34-
mask = jnp.zeros((n_obs,), dtype=jnp.bool_).at[idx].set(True)
35-
return centroids, mask, pdists, key
36-
37-
def _step(state):
38-
centroids, mask, pdists, key = state
39-
key, subkey = jax.random.split(key)
40-
n_obs = pdists.shape[0]
41-
# d(x) = min_{mu in centroids} ||x - mu||^2, d(x) = 0 if x in centroids
42-
probs = jnp.where(mask, 0, jnp.min(jnp.where(mask, pdists, jnp.inf), axis=1) ** 2)
43-
# sample with probability ~ d(x)
44-
idx = jax.random.choice(subkey, n_obs, p=probs / jnp.sum(probs))
45-
centroids = centroids.at[jnp.sum(mask)].set(idx)
46-
mask = mask.at[idx].set(True)
47-
return centroids, mask, pdists, key
48-
49-
def _convergence(state):
50-
_, mask, _, _ = state
51-
return jnp.sum(mask) < n_clusters
52-
53-
centroids, _, _, _ = jax.lax.while_loop(_convergence, _step, _init(key, pdists))
54-
return X[centroids]
25+
n_obs = X.shape[0]
26+
key, subkey = jax.random.split(key)
27+
initial_centroid_idx = jax.random.choice(subkey, n_obs, (1,), replace=False)
28+
initial_centroid = X[initial_centroid_idx].ravel()
29+
dist_sq = jnp.square(cdist(initial_centroid[jnp.newaxis, :], X)).ravel()
30+
initial_state = {"min_dist_sq": dist_sq, "centroid": initial_centroid, "key": key}
31+
n_local_trials = 2 + int(np.log(n_clusters))
32+
33+
def _step(state, _):
34+
prob = state["min_dist_sq"] / jnp.sum(state["min_dist_sq"])
35+
# note that observations already chosen as centers will have 0 probability
36+
# and will not be chosen again
37+
state["key"], subkey = jax.random.split(state["key"])
38+
next_centroid_idx_candidates = jax.random.choice(subkey, n_obs, (n_local_trials,), replace=False, p=prob)
39+
next_centroid_candidates = X[next_centroid_idx_candidates]
40+
# candidates by observations
41+
dist_sq_candidates = jnp.square(cdist(next_centroid_candidates, X))
42+
dist_sq_candidates = jnp.minimum(state["min_dist_sq"][jnp.newaxis, :], dist_sq_candidates)
43+
candidates_pot = dist_sq_candidates.sum(axis=1)
44+
45+
# Decide which candidate is the best
46+
best_candidate = jnp.argmin(candidates_pot)
47+
min_dist_sq = dist_sq_candidates[best_candidate]
48+
best_candidate = next_centroid_idx_candidates[best_candidate]
49+
50+
state["min_dist_sq"] = min_dist_sq.ravel()
51+
state["centroid"] = X[best_candidate].ravel()
52+
return state, state["centroid"]
53+
54+
_, centroids = jax.lax.scan(_step, initial_state, jnp.arange(n_clusters - 1))
55+
return centroids
5556

5657

5758
@jax.jit
@@ -62,7 +63,7 @@ def _get_dist_labels(X: jnp.ndarray, centroids: jnp.ndarray) -> jnp.ndarray:
6263
return dist, labels
6364

6465

65-
class KMeansJax:
66+
class KMeans:
6667
"""Jax implementation of :class:`sklearn.cluster.KMeans`.
6768
6869
This implementation is limited to Euclidean distance.
@@ -91,7 +92,7 @@ class KMeansJax:
9192
def __init__(
9293
self,
9394
n_clusters: int = 8,
94-
init: Literal["k-means++", "random"] = "random",
95+
init: Literal["k-means++", "random"] = "k-means++",
9596
n_init: int = 10,
9697
max_iter: int = 300,
9798
tol: float = 1e-4,
@@ -122,7 +123,6 @@ def fit(self, X: np.ndarray):
122123
return self
123124

124125
def _fit(self, X: np.ndarray):
125-
self._pdists = pdist_squareform(X)
126126
all_centroids, all_inertias = jax.lax.map(
127127
lambda key: self._kmeans_full_run(X, key), jax.random.split(self.seed, self.n_init)
128128
)
@@ -131,7 +131,6 @@ def _fit(self, X: np.ndarray):
131131
self.inertia_ = get_ndarray(all_inertias[i])
132132
_, labels = _get_dist_labels(X, self.cluster_centroids_)
133133
self.labels_ = get_ndarray(labels)
134-
del self._pdists
135134

136135
@partial(jax.jit, static_argnums=(0,))
137136
def _kmeans_full_run(self, X: jnp.ndarray, key: jnp.ndarray) -> jnp.ndarray:
@@ -169,7 +168,7 @@ def _kmeans_convergence(state):
169168
cond2 = n_iter > self.max_iter
170169
return jnp.logical_or(cond1, cond2)[0]
171170

172-
centroids = self._initialize(X, self.n_clusters, self._pdists, key)
171+
centroids = self._initialize(X, self.n_clusters, key)
173172
# centroids, new_inertia, old_inertia, n_iter
174173
state = (centroids, jnp.inf, jnp.inf, jnp.array([0.0]))
175174
state = _kmeans_step(state)

tests/test_metrics.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -108,7 +108,7 @@ def test_isolated_labels():
108108

109109
def test_kmeans():
110110
X, _ = dummy_x_labels()
111-
kmeans = scib_metrics.utils.KMeansJax(2)
111+
kmeans = scib_metrics.utils.KMeans(2)
112112
kmeans.fit(X)
113113
assert kmeans.labels_.shape == (X.shape[0],)
114114

0 commit comments

Comments
 (0)