Skip to content

Commit 43f8d3d

Browse files
authored
Merge pull request #327 from andrewwarrington/main
Non-determinism in SKLearn KMeans initialisation
2 parents fa4723f + d4fe221 commit 43f8d3d

File tree

6 files changed

+33
-11
lines changed

6 files changed

+33
-11
lines changed

dynamax/hidden_markov_model/models/arhmm.py

Lines changed: 3 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -42,7 +42,9 @@ def initialize(self,
4242
if method.lower() == "kmeans":
4343
assert emissions is not None, "Need emissions to initialize the model with K-Means!"
4444
from sklearn.cluster import KMeans
45-
km = KMeans(self.num_states).fit(emissions.reshape(-1, self.emission_dim))
45+
key, subkey = jr.split(key) # Create a random seed for SKLearn.
46+
sklearn_key = jr.randint(subkey, shape=(), minval=0, maxval=2147483647) # Max int32 value.
47+
km = KMeans(self.num_states, random_state=int(sklearn_key)).fit(emissions.reshape(-1, self.emission_dim))
4648
_emission_weights = jnp.zeros((self.num_states, self.emission_dim, self.emission_dim * self.num_lags))
4749
_emission_biases = jnp.array(km.cluster_centers_)
4850
_emission_covs = jnp.tile(jnp.eye(self.emission_dim)[None, :, :], (self.num_states, 1, 1))

dynamax/hidden_markov_model/models/gamma_hmm.py

Lines changed: 3 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -39,7 +39,9 @@ def initialize(self,
3939
if method.lower() == "kmeans":
4040
assert emissions is not None, "Need emissions to initialize the model with K-Means!"
4141
from sklearn.cluster import KMeans
42-
km = KMeans(self.num_states).fit(emissions.reshape(-1, 1))
42+
key, subkey = jr.split(key) # Create a random seed for SKLearn.
43+
sklearn_key = jr.randint(subkey, shape=(), minval=0, maxval=2147483647) # Max int32 value.
44+
km = KMeans(self.num_states, random_state=int(sklearn_key)).fit(emissions.reshape(-1, 1))
4345

4446
_emission_concentrations = jnp.ones((self.num_states,))
4547
_emission_rates = jnp.ravel(1.0 / km.cluster_centers_)

dynamax/hidden_markov_model/models/gaussian_hmm.py

Lines changed: 15 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -71,7 +71,9 @@ def initialize(self, key=jr.PRNGKey(0),
7171
if method.lower() == "kmeans":
7272
assert emissions is not None, "Need emissions to initialize the model with K-Means!"
7373
from sklearn.cluster import KMeans
74-
km = KMeans(self.num_states).fit(emissions.reshape(-1, self.emission_dim))
74+
key, subkey = jr.split(key) # Create a random seed for SKLearn.
75+
sklearn_key = jr.randint(subkey, shape=(), minval=0, maxval=2147483647) # Max int32 value.
76+
km = KMeans(self.num_states, random_state=int(sklearn_key)).fit(emissions.reshape(-1, self.emission_dim))
7577

7678
_emission_means = jnp.array(km.cluster_centers_)
7779
_emission_covs = jnp.tile(jnp.eye(self.emission_dim)[None, :, :], (self.num_states, 1, 1))
@@ -167,7 +169,9 @@ def initialize(self, key=jr.PRNGKey(0),
167169
if method.lower() == "kmeans":
168170
assert emissions is not None, "Need emissions to initialize the model with K-Means!"
169171
from sklearn.cluster import KMeans
170-
km = KMeans(self.num_states).fit(emissions.reshape(-1, self.emission_dim))
172+
key, subkey = jr.split(key) # Create a random seed for SKLearn.
173+
sklearn_key = jr.randint(subkey, shape=(), minval=0, maxval=2147483647) # Max int32 value.
174+
km = KMeans(self.num_states, random_state=int(sklearn_key)).fit(emissions.reshape(-1, self.emission_dim))
171175
_emission_means = jnp.array(km.cluster_centers_)
172176
_emission_scale_diags = jnp.ones((self.num_states, self.emission_dim))
173177

@@ -286,7 +290,9 @@ def initialize(self, key=jr.PRNGKey(0),
286290
if method.lower() == "kmeans":
287291
assert emissions is not None, "Need emissions to initialize the model with K-Means!"
288292
from sklearn.cluster import KMeans
289-
km = KMeans(self.num_states).fit(emissions.reshape(-1, self.emission_dim))
293+
key, subkey = jr.split(key) # Create a random seed for SKLearn.
294+
sklearn_key = jr.randint(subkey, shape=(), minval=0, maxval=2147483647) # Max int32 value.
295+
km = KMeans(self.num_states, random_state=int(sklearn_key)).fit(emissions.reshape(-1, self.emission_dim))
290296
_emission_means = jnp.array(km.cluster_centers_)
291297
_emission_scales = jnp.ones((self.num_states,))
292298

@@ -386,7 +392,9 @@ def initialize(self, key=jr.PRNGKey(0),
386392
if method.lower() == "kmeans":
387393
assert emissions is not None, "Need emissions to initialize the model with K-Means!"
388394
from sklearn.cluster import KMeans
389-
km = KMeans(self.num_states).fit(emissions.reshape(-1, self.emission_dim))
395+
key, subkey = jr.split(key) # Create a random seed for SKLearn.
396+
sklearn_key = jr.randint(subkey, shape=(), minval=0, maxval=2147483647) # Max int32 value.
397+
km = KMeans(self.num_states, random_state=int(sklearn_key)).fit(emissions.reshape(-1, self.emission_dim))
390398
_emission_means = jnp.array(km.cluster_centers_)
391399
_emission_cov = jnp.eye(self.emission_dim)
392400

@@ -506,7 +514,9 @@ def initialize(self, key=jr.PRNGKey(0),
506514
if method.lower() == "kmeans":
507515
assert emissions is not None, "Need emissions to initialize the model with K-Means!"
508516
from sklearn.cluster import KMeans
509-
km = KMeans(self.num_states).fit(emissions.reshape(-1, self.emission_dim))
517+
key, subkey = jr.split(key) # Create a random seed for SKLearn.
518+
sklearn_key = jr.randint(subkey, shape=(), minval=0, maxval=2147483647) # Max int32 value.
519+
km = KMeans(self.num_states, random_state=int(sklearn_key)).fit(emissions.reshape(-1, self.emission_dim))
510520
_emission_means = jnp.array(km.cluster_centers_)
511521
_emission_cov_diag_factors = jnp.ones((self.num_states, self.emission_dim))
512522
_emission_cov_low_rank_factors = jnp.zeros((self.num_states, self.emission_dim, self.emission_rank))

dynamax/hidden_markov_model/models/gmm_hmm.py

Lines changed: 6 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -78,7 +78,9 @@ def initialize(self, key=jr.PRNGKey(0),
7878
if method.lower() == "kmeans":
7979
assert emissions is not None, "Need emissions to initialize the model with K-Means!"
8080
from sklearn.cluster import KMeans
81-
km = KMeans(self.num_states).fit(emissions.reshape(-1, self.emission_dim))
81+
key, subkey = jr.split(key) # Create a random seed for SKLearn.
82+
sklearn_key = jr.randint(subkey, shape=(), minval=0, maxval=2147483647) # Max int32 value.
83+
km = KMeans(self.num_states, random_state=int(sklearn_key)).fit(emissions.reshape(-1, self.emission_dim))
8284
_emission_weights = jnp.ones((self.num_states, self.num_components)) / self.num_components
8385
_emission_means = jnp.tile(jnp.array(km.cluster_centers_)[:, None, :], (1, self.num_components, 1))
8486
_emission_covs = jnp.tile(jnp.eye(self.emission_dim), (self.num_states, self.num_components, 1, 1))
@@ -298,7 +300,9 @@ def initialize(self, key=jr.PRNGKey(0),
298300
if method.lower() == "kmeans":
299301
assert emissions is not None, "Need emissions to initialize the model with K-Means!"
300302
from sklearn.cluster import KMeans
301-
km = KMeans(self.num_states).fit(emissions.reshape(-1, self.emission_dim))
303+
key, subkey = jr.split(key) # Create a random seed for SKLearn.
304+
sklearn_key = jr.randint(subkey, shape=(), minval=0, maxval=2147483647) # Max int32 value.
305+
km = KMeans(self.num_states, random_state=int(sklearn_key)).fit(emissions.reshape(-1, self.emission_dim))
302306
_emission_weights = jnp.ones((self.num_states, self.num_components)) / self.num_components
303307
_emission_means = jnp.tile(jnp.array(km.cluster_centers_)[:, None, :], (1, self.num_components, 1))
304308
_emission_scale_diags = jnp.ones((self.num_states, self.num_components, self.emission_dim))

dynamax/hidden_markov_model/models/linreg_hmm.py

Lines changed: 3 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -59,7 +59,9 @@ def initialize(self,
5959
if method.lower() == "kmeans":
6060
assert emissions is not None, "Need emissions to initialize the model with K-Means!"
6161
from sklearn.cluster import KMeans
62-
km = KMeans(self.num_states).fit(emissions.reshape(-1, self.emission_dim))
62+
key, subkey = jr.split(key) # Create a random seed for SKLearn.
63+
sklearn_key = jr.randint(subkey, shape=(), minval=0, maxval=2147483647) # Max int32 value.
64+
km = KMeans(self.num_states, random_state=int(sklearn_key)).fit(emissions.reshape(-1, self.emission_dim))
6365
_emission_weights = jnp.zeros((self.num_states, self.emission_dim, self.input_dim))
6466
_emission_biases = jnp.array(km.cluster_centers_)
6567
_emission_covs = jnp.tile(jnp.eye(self.emission_dim)[None, :, :], (self.num_states, 1, 1))

dynamax/hidden_markov_model/models/logreg_hmm.py

Lines changed: 3 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -52,7 +52,9 @@ def initialize(self,
5252

5353
flat_emissions = emissions.reshape(-1,)
5454
flat_inputs = inputs.reshape(-1, self.input_dim)
55-
km = KMeans(self.num_states).fit(flat_inputs)
55+
key, subkey = jr.split(key) # Create a random seed for SKLearn.
56+
sklearn_key = jr.randint(subkey, shape=(), minval=0, maxval=2147483647) # Max int32 value.
57+
km = KMeans(self.num_states, random_state=int(sklearn_key)).fit(flat_inputs)
5658
_emission_weights = jnp.zeros((self.num_states, self.input_dim))
5759
_emission_biases = jnp.array([tfb.Sigmoid().inverse(flat_emissions[km.labels_ == k].mean())
5860
for k in range(self.num_states)])

0 commit comments

Comments
 (0)