 from sklearn.utils import check_array

 from .._types import IntOrKey
-from ._dist import cdist, pdist_squareform
+from ._dist import cdist
 from ._utils import get_ndarray, validate_seed


-def _initialize_random(X: jnp.ndarray, n_clusters: int, pdists: jnp.ndarray, key: jax.random.KeyArray) -> jnp.ndarray:
+def _initialize_random(X: jnp.ndarray, n_clusters: int, key: jax.random.KeyArray) -> jnp.ndarray:
     """Initialize cluster centroids randomly."""
     n_obs = X.shape[0]
     indices = jax.random.choice(key, n_obs, (n_clusters,), replace=False)
@@ -20,38 +20,39 @@ def _initialize_random(X: jnp.ndarray, n_clusters: int, pdists: jnp.ndarray, key


 @partial(jax.jit, static_argnums=1)
-def _initialize_plus_plus(
-    X: jnp.ndarray, n_clusters: int, pdists: jnp.ndarray, key: jax.random.KeyArray
-) -> jnp.ndarray:
+def _initialize_plus_plus(X: jnp.ndarray, n_clusters: int, key: jax.random.KeyArray) -> jnp.ndarray:
     """Initialize cluster centroids with k-means++ algorithm."""
-
-    def _init(key, pdists):
-        key, subkey = jax.random.split(key)
-        n_obs = pdists.shape[0]
-        # sample first centroid uniformly at random
-        idx = jax.random.choice(subkey, n_obs)
-        centroids = jnp.full((n_clusters,), -1, dtype=jnp.int32).at[0].set(idx)
-        mask = jnp.zeros((n_obs,), dtype=jnp.bool_).at[idx].set(True)
-        return centroids, mask, pdists, key
-
-    def _step(state):
-        centroids, mask, pdists, key = state
-        key, subkey = jax.random.split(key)
-        n_obs = pdists.shape[0]
-        # d(x) = min_{mu in centroids} ||x - mu||^2, d(x) = 0 if x in centroids
-        probs = jnp.where(mask, 0, jnp.min(jnp.where(mask, pdists, jnp.inf), axis=1) ** 2)
-        # sample with probability ~ d(x)
-        idx = jax.random.choice(subkey, n_obs, p=probs / jnp.sum(probs))
-        centroids = centroids.at[jnp.sum(mask)].set(idx)
-        mask = mask.at[idx].set(True)
-        return centroids, mask, pdists, key
-
-    def _convergence(state):
-        _, mask, _, _ = state
-        return jnp.sum(mask) < n_clusters
-
-    centroids, _, _, _ = jax.lax.while_loop(_convergence, _step, _init(key, pdists))
-    return X[centroids]
+    n_obs = X.shape[0]
+    key, subkey = jax.random.split(key)
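+    # the first centroid is drawn uniformly at random from the observations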
+    initial_centroid_idx = jax.random.choice(subkey, n_obs, (1,), replace=False)
+    initial_centroid = X[initial_centroid_idx].ravel()
+    dist_sq = jnp.square(cdist(initial_centroid[jnp.newaxis, :], X)).ravel()
+    initial_state = {"min_dist_sq": dist_sq, "centroid": initial_centroid, "key": key}
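+    # greedy k-means++: each round draws several candidate centroids and keeps
+    # the one that most reduces the total potential; the trial count follows
+    # sklearn's 2 + log(k) heuristic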
+    n_local_trials = 2 + int(np.log(n_clusters))
+
+    def _step(state, _):
+        prob = state["min_dist_sq"] / jnp.sum(state["min_dist_sq"])
+        # observations already chosen as centroids have zero probability and
+        # cannot be chosen again
+        state["key"], subkey = jax.random.split(state["key"])
+        next_centroid_idx_candidates = jax.random.choice(subkey, n_obs, (n_local_trials,), replace=False, p=prob)
+        next_centroid_candidates = X[next_centroid_idx_candidates]
+        # squared distances, shape (n_local_trials, n_obs)
+        dist_sq_candidates = jnp.square(cdist(next_centroid_candidates, X))
+        dist_sq_candidates = jnp.minimum(state["min_dist_sq"][jnp.newaxis, :], dist_sq_candidates)
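+        # each row now holds, per observation, the squared distance to its
+        # nearest centroid if that candidate were accepted; summing over
+        # observations gives the potential that k-means++ greedily minimizes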
+        candidates_pot = dist_sq_candidates.sum(axis=1)
+
+        # decide which candidate is best
+        best_candidate = jnp.argmin(candidates_pot)
+        min_dist_sq = dist_sq_candidates[best_candidate]
+        best_candidate = next_centroid_idx_candidates[best_candidate]
+
+        state["min_dist_sq"] = min_dist_sq.ravel()
+        state["centroid"] = X[best_candidate].ravel()
+        return state, state["centroid"]
+
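+    # one scan step per remaining centroid; each step emits that round's winner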
+    _, centroids = jax.lax.scan(_step, initial_state, jnp.arange(n_clusters - 1))
+    # stack the uniformly drawn centroid on top of the scan outputs so that
+    # exactly n_clusters centroids are returned
+    return jnp.vstack([initial_centroid, centroids])


 @jax.jit
@@ -62,7 +63,7 @@ def _get_dist_labels(X: jnp.ndarray, centroids: jnp.ndarray) -> jnp.ndarray:
     return dist, labels


-class KMeansJax:
+class KMeans:
     """Jax implementation of :class:`sklearn.cluster.KMeans`.

     This implementation is limited to Euclidean distance.
@@ -91,7 +92,7 @@ class KMeansJax:
     def __init__(
         self,
         n_clusters: int = 8,
-        init: Literal["k-means++", "random"] = "random",
+        init: Literal["k-means++", "random"] = "k-means++",
         n_init: int = 10,
         max_iter: int = 300,
         tol: float = 1e-4,
@@ -122,7 +123,6 @@ def fit(self, X: np.ndarray):
         return self

     def _fit(self, X: np.ndarray):
-        self._pdists = pdist_squareform(X)
         all_centroids, all_inertias = jax.lax.map(
             lambda key: self._kmeans_full_run(X, key), jax.random.split(self.seed, self.n_init)
         )
@@ -131,7 +131,6 @@ def _fit(self, X: np.ndarray):
         self.inertia_ = get_ndarray(all_inertias[i])
         _, labels = _get_dist_labels(X, self.cluster_centroids_)
         self.labels_ = get_ndarray(labels)
-        del self._pdists

     @partial(jax.jit, static_argnums=(0,))
     def _kmeans_full_run(self, X: jnp.ndarray, key: jnp.ndarray) -> jnp.ndarray:
@@ -169,7 +168,7 @@ def _kmeans_convergence(state):
             cond2 = n_iter > self.max_iter
             return jnp.logical_or(cond1, cond2)[0]

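+        # self._initialize is bound to one of the two initializers above,
+        # selected via the `init` option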
-        centroids = self._initialize(X, self.n_clusters, self._pdists, key)
+        centroids = self._initialize(X, self.n_clusters, key)
         # centroids, new_inertia, old_inertia, n_iter
         state = (centroids, jnp.inf, jnp.inf, jnp.array([0.0]))
         state = _kmeans_step(state)