3838 csr_unique ,
3939 fast_knn_indices ,
4040)
41- from umap .spectral import spectral_layout
41+ from umap .spectral import spectral_layout , tswspectral_layout
4242from umap .layouts import (
4343 optimize_layout_euclidean ,
4444 optimize_layout_generic ,
@@ -1115,6 +1115,18 @@ def simplicial_set_embedding(
11151115 embedding = noisy_scale_coords (
11161116 embedding , random_state , max_coord = 10 , noise = 0.0001
11171117 )
1118+ elif isinstance (init , str ) and init == "tswspectral" :
1119+ embedding = tswspectral_layout (
1120+ data ,
1121+ graph ,
1122+ n_components ,
1123+ random_state ,
1124+ metric = metric ,
1125+ metric_kwds = metric_kwds ,
1126+ )
1127+ embedding = noisy_scale_coords (
1128+ embedding , random_state , max_coord = 10 , noise = 0.0001
1129+ )
11181130 else :
11191131 init_data = np .array (init )
11201132 if len (init_data .shape ) == 2 :
@@ -1459,7 +1471,13 @@ class UMAP(BaseEstimator):
14591471
14601472 * 'spectral': use a spectral embedding of the fuzzy 1-skeleton
14611473 * 'random': assign initial embedding positions at random.
1462- * 'pca': use the first n_components from PCA applied to the input data.
1474+ * 'pca': use the first n_components from PCA applied to the
1475+ input data.
1476+ * 'tswspectral': use a spectral embedding of the fuzzy
1477+ 1-skeleton, using a truncated singular value decomposition to
1478+ "warm" up the eigensolver. This is intended as an alternative
1479+ to the 'spectral' method, if that takes an excessively long
1480+ time to complete initialization (or fails to complete).
14631481 * A numpy array of initial embedding positions.
14641482
14651483 min_dist: float (optional, default 0.1)
@@ -1738,8 +1756,12 @@ def _validate_parameters(self):
17381756 "pca" ,
17391757 "spectral" ,
17401758 "random" ,
1759+ "tswspectral" ,
17411760 ):
1742- raise ValueError ('string init values must be "pca", "spectral" or "random"' )
1761+ raise ValueError (
1762+ 'string init values must be one of: "pca", "tswspectral",'
1763+ ' "spectral" or "random"'
1764+ )
17431765 if (
17441766 isinstance (self .init , np .ndarray )
17451767 and self .init .shape [1 ] != self .n_components
@@ -1769,18 +1791,26 @@ def _validate_parameters(self):
17691791 if self .n_components < 1 :
17701792 raise ValueError ("n_components must be greater than 0" )
17711793 self .n_epochs_list = None
1772- if isinstance (self .n_epochs , list ) or isinstance (self .n_epochs , tuple ) or \
1773- isinstance (self .n_epochs , np .ndarray ):
1774- if not issubclass (np .array (self .n_epochs ).dtype .type , np .integer ) or \
1775- not np .all (np .array (self .n_epochs ) >= 0 ):
1776- raise ValueError ("n_epochs must be a nonnegative integer "
1777- "or a list of nonnegative integers" )
1794+ if (
1795+ isinstance (self .n_epochs , list )
1796+ or isinstance (self .n_epochs , tuple )
1797+ or isinstance (self .n_epochs , np .ndarray )
1798+ ):
1799+ if not issubclass (
1800+ np .array (self .n_epochs ).dtype .type , np .integer
1801+ ) or not np .all (np .array (self .n_epochs ) >= 0 ):
1802+ raise ValueError (
1803+ "n_epochs must be a nonnegative integer "
1804+ "or a list of nonnegative integers"
1805+ )
17781806 self .n_epochs_list = list (self .n_epochs )
17791807 elif self .n_epochs is not None and (
1780- self .n_epochs < 0 or not isinstance (self .n_epochs , int )
1808+ self .n_epochs < 0 or not isinstance (self .n_epochs , int )
17811809 ):
1782- raise ValueError ("n_epochs must be a nonnegative integer "
1783- "or a list of nonnegative integers" )
1810+ raise ValueError (
1811+ "n_epochs must be a nonnegative integer "
1812+ "or a list of nonnegative integers"
1813+ )
17841814 if self .metric_kwds is None :
17851815 self ._metric_kwds = {}
17861816 else :
@@ -2742,7 +2772,9 @@ def fit(self, X, y=None, force_all_finite=True):
27422772 print (ts (), "Construct embedding" )
27432773
27442774 if self .transform_mode == "embedding" :
2745- epochs = self .n_epochs_list if self .n_epochs_list is not None else self .n_epochs
2775+ epochs = (
2776+ self .n_epochs_list if self .n_epochs_list is not None else self .n_epochs
2777+ )
27462778 self .embedding_ , aux_data = self ._fit_embed_data (
27472779 self ._raw_data [index ],
27482780 epochs ,
@@ -2752,11 +2784,15 @@ def fit(self, X, y=None, force_all_finite=True):
27522784
27532785 if self .n_epochs_list is not None :
27542786 if "embedding_list" not in aux_data :
2755- raise KeyError ("No list of embedding were found in 'aux_data'. "
2756- "It is likely the layout optimization function "
2757- "doesn't support the list of int for 'n_epochs'." )
2787+ raise KeyError (
2788+ "No list of embedding were found in 'aux_data'. "
2789+ "It is likely the layout optimization function "
2790+ "doesn't support the list of int for 'n_epochs'."
2791+ )
27582792 else :
2759- self .embedding_list_ = [e [inverse ] for e in aux_data ["embedding_list" ]]
2793+ self .embedding_list_ = [
2794+ e [inverse ] for e in aux_data ["embedding_list" ]
2795+ ]
27602796
27612797 # Assign any points that are fully disconnected from our manifold(s) to have embedding
27622798 # coordinates of np.nan. These will be filtered by our plotting functions automatically.
0 commit comments