@@ -181,6 +181,135 @@ def _optimize_layout_euclidean_single_epoch(
181181 )
182182
183183
def _optimize_layout_euclidean_masked_single_epoch(
    head_embedding,
    tail_embedding,
    head,
    tail,
    mask,
    n_vertices,
    epochs_per_sample,
    a,
    b,
    rng_state,
    gamma,
    dim,
    move_other,
    alpha,
    epochs_per_negative_sample,
    epoch_of_next_negative_sample,
    epoch_of_next_sample,
    n,
    densmap_flag,
    dens_phi_sum,
    dens_re_sum,
    dens_re_cov,
    dens_re_std,
    dens_re_mean,
    dens_lambda,
    dens_R,
    dens_mu,
    dens_mu_tot,
):
    """Run one epoch of mask-weighted SGD over the edges of the graph.

    For every edge ``i`` whose ``epoch_of_next_sample[i]`` is due
    (``<= n``), the head point ``head_embedding[head[i]]`` is pulled towards
    ``tail_embedding[tail[i]]`` and then pushed away from a number of randomly
    drawn vertices (negative samples). Every displacement applied to a point
    is multiplied by that point's entry in ``mask``.

    The function is written to be compiled with ``numba.njit`` by the caller.
    It updates ``head_embedding``, ``tail_embedding`` (only when
    ``move_other`` is true), ``epoch_of_next_sample`` and
    ``epoch_of_next_negative_sample`` in place and returns nothing.

    Parameters
    ----------
    head_embedding, tail_embedding: 2-d arrays
        Coordinates indexed by the entries of ``head`` and ``tail``.
    head, tail: 1-d integer arrays
        Endpoints of each edge.
    mask: 1-d array
        Per-point multiplier for updates, indexed by ``head`` entries and,
        when ``move_other`` is true, also by ``tail`` entries.
    n_vertices: int
        Upper bound (exclusive) for negative-sample indices; also used as a
        normaliser in the densMAP gradient.
    epochs_per_sample, epochs_per_negative_sample: 1-d float arrays
        Sampling periods of each edge for positive and negative samples.
    a, b: float
        Curve parameters of the low-dimensional similarity.
    rng_state: array
        State passed to ``tau_rand_int``.
    gamma: float
        Weight of the repulsive (negative-sample) gradient.
    dim: int
        Number of embedding dimensions to update.
    move_other: bool
        Whether the tail endpoint of an edge is moved as well.
    alpha: float
        Learning rate for this epoch.
    epoch_of_next_negative_sample, epoch_of_next_sample: 1-d float arrays
        Bookkeeping of when each edge is next due; advanced in place.
    n: int
        Index of the current epoch.
    densmap_flag: bool
        Whether to add the densMAP correlation gradient.
    dens_phi_sum, dens_re_sum, dens_re_cov, dens_re_std, dens_re_mean,
    dens_lambda, dens_R, dens_mu, dens_mu_tot:
        densMAP auxiliary quantities; only read when ``densmap_flag`` is true.
    """
    for i in numba.prange(epochs_per_sample.shape[0]):
        if epoch_of_next_sample[i] <= n:
            j = head[i]
            k = tail[i]

            # Row views: writing to current/other updates the embeddings.
            current = head_embedding[j]
            other = tail_embedding[k]

            current_mask = mask[j]

            # rdist is defined elsewhere in this module; presumably the
            # squared Euclidean distance, as the variable name suggests.
            dist_squared = rdist(current, other)

            if densmap_flag:
                phi = 1.0 / (1.0 + a * pow(dist_squared, b))
                dphi_term = (
                    a * b * pow(dist_squared, b - 1) / (1.0 + a * pow(dist_squared, b))
                )

                q_jk = phi / dens_phi_sum[k]
                q_kj = phi / dens_phi_sum[j]

                drk = q_jk * (
                    (1.0 - b * (1 - phi)) / np.exp(dens_re_sum[k]) + dphi_term
                )
                drj = q_kj * (
                    (1.0 - b * (1 - phi)) / np.exp(dens_re_sum[j]) + dphi_term
                )

                re_std_sq = dens_re_std * dens_re_std
                weight_k = (
                    dens_R[k]
                    - dens_re_cov * (dens_re_sum[k] - dens_re_mean) / re_std_sq
                )
                weight_j = (
                    dens_R[j]
                    - dens_re_cov * (dens_re_sum[j] - dens_re_mean) / re_std_sq
                )

                grad_cor_coeff = (
                    dens_lambda
                    * dens_mu_tot
                    * (weight_k * drk + weight_j * drj)
                    / (dens_mu[i] * dens_re_std)
                    / n_vertices
                )

            # Attractive coefficient; skipped for coincident points to avoid
            # evaluating pow(0, b - 1).
            if dist_squared > 0.0:
                grad_coeff = -2.0 * a * b * pow(dist_squared, b - 1.0)
                grad_coeff /= a * pow(dist_squared, b) + 1.0
            else:
                grad_coeff = 0.0

            for d in range(dim):
                grad_d = clip(grad_coeff * (current[d] - other[d]))

                if densmap_flag:
                    grad_d += clip(2 * grad_cor_coeff * (current[d] - other[d]))

                current[d] += current_mask * grad_d * alpha
                if move_other:
                    # mask[k] is read only here: when the tail embedding is a
                    # separate array, k indexes that array and may lie outside
                    # mask, so it must not be touched unless the tail moves.
                    other[d] += -mask[k] * grad_d * alpha

            epoch_of_next_sample[i] += epochs_per_sample[i]

            # Number of negative samples that have become due since the last
            # time this edge was processed.
            n_neg_samples = int(
                (n - epoch_of_next_negative_sample[i]) / epochs_per_negative_sample[i]
            )

            for p in range(n_neg_samples):
                k = tau_rand_int(rng_state) % n_vertices

                other = tail_embedding[k]

                dist_squared = rdist(current, other)

                if dist_squared > 0.0:
                    grad_coeff = 2.0 * gamma * b
                    grad_coeff /= (0.001 + dist_squared) * (
                        a * pow(dist_squared, b) + 1
                    )
                elif j == k:
                    # Sampled the point itself: nothing to repel from.
                    continue
                else:
                    grad_coeff = 0.0

                for d in range(dim):
                    if grad_coeff > 0.0:
                        grad_d = clip(grad_coeff * (current[d] - other[d]))
                    else:
                        # NOTE(review): a non-positive coefficient (e.g. a
                        # distinct vertex at zero distance) yields a constant
                        # push of 4.0 in every dimension; kept as-is, confirm
                        # this is intended.
                        grad_d = 4.0
                    # Negative samples only ever move the head point.
                    current[d] += current_mask * grad_d * alpha

            epoch_of_next_negative_sample[i] += (
                n_neg_samples * epochs_per_negative_sample[i]
            )
311+
312+
184313def _optimize_layout_euclidean_densmap_epoch_init (
185314 head_embedding , tail_embedding , head , tail , a , b , re_sum , phi_sum ,
186315):
@@ -379,6 +508,184 @@ def optimize_layout_euclidean(
379508 return head_embedding
380509
381510
def optimize_layout_euclidean_masked(
    head_embedding,
    tail_embedding,
    head,
    tail,
    mask,
    n_epochs,
    n_vertices,
    epochs_per_sample,
    a,
    b,
    rng_state,
    gamma=1.0,
    initial_alpha=1.0,
    negative_sample_rate=5.0,
    parallel=False,
    verbose=False,
    densmap=False,
    densmap_kwds=None,
):
    """Improve an embedding using stochastic gradient descent to minimize the
    fuzzy set cross entropy between the 1-skeletons of the high dimensional
    and low dimensional fuzzy simplicial sets. In practice this is done by
    sampling edges based on their membership strength (with the (1-p) terms
    coming from negative sampling similar to word2vec).

    Each point's update is additionally scaled by its entry in ``mask``, so
    individual points can be frozen or damped during optimization.

    Parameters
    ----------
    head_embedding: array of shape (n_samples, n_components)
        The initial embedding to be improved by SGD. It is modified in place.
    tail_embedding: array of shape (source_samples, n_components)
        The reference embedding of embedded points. If not embedding new
        previously unseen points with respect to an existing embedding this
        is simply the head_embedding (again); otherwise it provides the
        existing embedding to embed with respect to.
    head: array of shape (n_1_simplices)
        The indices of the heads of 1-simplices with non-zero membership.
    tail: array of shape (n_1_simplices)
        The indices of the tails of 1-simplices with non-zero membership.
    mask: array of shape (n_samples)
        The weights (in [0,1]) assigned to each sample, defining how much they
        should be updated. 0 means the point will not move at all, 1 means
        they are updated normally. In-between values allow for fine-tuning.
    n_epochs: int
        The number of training epochs to use in optimization.
    n_vertices: int
        The number of vertices (0-simplices) in the dataset.
    epochs_per_sample: array of shape (n_1_simplices)
        A float value of the number of epochs per 1-simplex. 1-simplices with
        weaker membership strength will have more epochs between being sampled.
    a: float
        Parameter of differentiable approximation of right adjoint functor
    b: float
        Parameter of differentiable approximation of right adjoint functor
    rng_state: array of int64, shape (3,)
        The internal state of the rng
    gamma: float (optional, default 1.0)
        Weight to apply to negative samples.
    initial_alpha: float (optional, default 1.0)
        Initial learning rate for the SGD.
    negative_sample_rate: int (optional, default 5)
        Number of negative samples to use per positive sample.
    parallel: bool (optional, default False)
        Whether to run the computation using numba parallel.
        Running in parallel is non-deterministic, and is not used
        if a random seed has been set, to ensure reproducibility.
    verbose: bool (optional, default False)
        Whether to report information on the current progress of the algorithm.
    densmap: bool (optional, default False)
        Whether to use the density-augmented densMAP objective
    densmap_kwds: dict (optional, default None)
        Auxiliary data for densMAP. ``None`` is treated as an empty dict.
        When ``densmap`` is true the keys "mu_sum", "lambda", "R", "mu",
        "var_shift" and "frac" are read.

    Returns
    -------
    embedding: array of shape (n_samples, n_components)
        The optimized embedding (the same object as ``head_embedding``).
    """
    # Avoid a shared mutable default argument.
    if densmap_kwds is None:
        densmap_kwds = {}

    dim = head_embedding.shape[1]
    # The tail endpoint is moved only when both embeddings have the same
    # number of rows.
    move_other = head_embedding.shape[0] == tail_embedding.shape[0]
    alpha = initial_alpha

    epochs_per_negative_sample = epochs_per_sample / negative_sample_rate
    epoch_of_next_negative_sample = epochs_per_negative_sample.copy()
    epoch_of_next_sample = epochs_per_sample.copy()

    optimize_fn = numba.njit(
        _optimize_layout_euclidean_masked_single_epoch,
        fastmath=True,
        parallel=parallel,
    )

    if densmap:
        dens_init_fn = numba.njit(
            _optimize_layout_euclidean_densmap_epoch_init,
            fastmath=True,
            parallel=parallel,
        )

        dens_mu_tot = np.sum(densmap_kwds["mu_sum"]) / 2
        dens_lambda = densmap_kwds["lambda"]
        dens_R = densmap_kwds["R"]
        dens_mu = densmap_kwds["mu"]
        dens_phi_sum = np.zeros(n_vertices, dtype=np.float32)
        dens_re_sum = np.zeros(n_vertices, dtype=np.float32)
        dens_var_shift = densmap_kwds["var_shift"]
    else:
        # Placeholders so the compiled epoch function always receives
        # arguments of a consistent kind; they are not read when the densMAP
        # flag is off.
        dens_mu_tot = 0
        dens_lambda = 0
        dens_R = np.zeros(1, dtype=np.float32)
        dens_mu = np.zeros(1, dtype=np.float32)
        dens_phi_sum = np.zeros(1, dtype=np.float32)
        dens_re_sum = np.zeros(1, dtype=np.float32)

    # Progress is reported roughly every tenth of the run. Clamp to 1 so that
    # runs with fewer than 10 epochs do not divide by zero in the modulo.
    report_every = max(1, int(n_epochs / 10))

    for n in range(n_epochs):

        # The densMAP term is only active during the final "frac" fraction of
        # the epochs, and only when its weight is positive.
        densmap_flag = (
            densmap
            and (densmap_kwds["lambda"] > 0)
            and (((n + 1) / float(n_epochs)) > (1 - densmap_kwds["frac"]))
        )

        if densmap_flag:
            dens_init_fn(
                head_embedding,
                tail_embedding,
                head,
                tail,
                a,
                b,
                dens_re_sum,
                dens_phi_sum,
            )

            dens_re_std = np.sqrt(np.var(dens_re_sum) + dens_var_shift)
            dens_re_mean = np.mean(dens_re_sum)
            dens_re_cov = np.dot(dens_re_sum, dens_R) / (n_vertices - 1)
        else:
            dens_re_std = 0
            dens_re_mean = 0
            dens_re_cov = 0

        optimize_fn(
            head_embedding,
            tail_embedding,
            head,
            tail,
            mask,
            n_vertices,
            epochs_per_sample,
            a,
            b,
            rng_state,
            gamma,
            dim,
            move_other,
            alpha,
            epochs_per_negative_sample,
            epoch_of_next_negative_sample,
            epoch_of_next_sample,
            n,
            densmap_flag,
            dens_phi_sum,
            dens_re_sum,
            dens_re_cov,
            dens_re_std,
            dens_re_mean,
            dens_lambda,
            dens_R,
            dens_mu,
            dens_mu_tot,
        )

        # Linearly decaying learning rate, used by the next epoch.
        alpha = initial_alpha * (1.0 - (float(n) / float(n_epochs)))

        if verbose and n % report_every == 0:
            print("\t completed ", n, " / ", n_epochs, "epochs")

    return head_embedding
687+
688+
382689@numba .njit (fastmath = True )
383690def optimize_layout_generic (
384691 head_embedding ,
0 commit comments