
Commit c7fe35e

Implement masking to control how embedded points are updated
1 parent 42b3f1f commit c7fe35e

File tree

3 files changed: +395 -24 lines

umap/layouts.py

Lines changed: 307 additions & 0 deletions
@@ -181,6 +181,135 @@ def _optimize_layout_euclidean_single_epoch(
            )


def _optimize_layout_euclidean_masked_single_epoch(
    head_embedding,
    tail_embedding,
    head,
    tail,
    mask,
    n_vertices,
    epochs_per_sample,
    a,
    b,
    rng_state,
    gamma,
    dim,
    move_other,
    alpha,
    epochs_per_negative_sample,
    epoch_of_next_negative_sample,
    epoch_of_next_sample,
    n,
    densmap_flag,
    dens_phi_sum,
    dens_re_sum,
    dens_re_cov,
    dens_re_std,
    dens_re_mean,
    dens_lambda,
    dens_R,
    dens_mu,
    dens_mu_tot,
):
    # Same as _optimize_layout_euclidean_single_epoch, except that every
    # update to a point is scaled by that point's mask value: 0 pins the
    # point in place, 1 applies the full SGD step.
    for i in numba.prange(epochs_per_sample.shape[0]):
        if epoch_of_next_sample[i] <= n:
            j = head[i]
            k = tail[i]

            current = head_embedding[j]
            other = tail_embedding[k]

            # Per-point update weights in [0, 1].
            current_mask = mask[j]
            other_mask = mask[k]

            dist_squared = rdist(current, other)

            if densmap_flag:
                phi = 1.0 / (1.0 + a * pow(dist_squared, b))
                dphi_term = (
                    a * b * pow(dist_squared, b - 1) / (1.0 + a * pow(dist_squared, b))
                )

                q_jk = phi / dens_phi_sum[k]
                q_kj = phi / dens_phi_sum[j]

                drk = q_jk * (
                    (1.0 - b * (1 - phi)) / np.exp(dens_re_sum[k]) + dphi_term
                )
                drj = q_kj * (
                    (1.0 - b * (1 - phi)) / np.exp(dens_re_sum[j]) + dphi_term
                )

                re_std_sq = dens_re_std * dens_re_std
                weight_k = (
                    dens_R[k]
                    - dens_re_cov * (dens_re_sum[k] - dens_re_mean) / re_std_sq
                )
                weight_j = (
                    dens_R[j]
                    - dens_re_cov * (dens_re_sum[j] - dens_re_mean) / re_std_sq
                )

                grad_cor_coeff = (
                    dens_lambda
                    * dens_mu_tot
                    * (weight_k * drk + weight_j * drj)
                    / (dens_mu[i] * dens_re_std)
                    / n_vertices
                )

            if dist_squared > 0.0:
                grad_coeff = -2.0 * a * b * pow(dist_squared, b - 1.0)
                grad_coeff /= a * pow(dist_squared, b) + 1.0
            else:
                grad_coeff = 0.0

            for d in range(dim):
                grad_d = clip(grad_coeff * (current[d] - other[d]))

                if densmap_flag:
                    grad_d += clip(2 * grad_cor_coeff * (current[d] - other[d]))

                # Scale each point's step by its mask weight.
                current[d] += current_mask * grad_d * alpha
                if move_other:
                    other[d] += -other_mask * grad_d * alpha

            epoch_of_next_sample[i] += epochs_per_sample[i]

            n_neg_samples = int(
                (n - epoch_of_next_negative_sample[i]) / epochs_per_negative_sample[i]
            )

            for p in range(n_neg_samples):
                k = tau_rand_int(rng_state) % n_vertices

                other = tail_embedding[k]

                dist_squared = rdist(current, other)

                if dist_squared > 0.0:
                    grad_coeff = 2.0 * gamma * b
                    grad_coeff /= (0.001 + dist_squared) * (
                        a * pow(dist_squared, b) + 1
                    )
                elif j == k:
                    continue
                else:
                    grad_coeff = 0.0

                for d in range(dim):
                    if grad_coeff > 0.0:
                        grad_d = clip(grad_coeff * (current[d] - other[d]))
                    else:
                        grad_d = 4.0
                    # Negative samples also respect the mask on the moving point.
                    current[d] += current_mask * grad_d * alpha

            epoch_of_next_negative_sample[i] += (
                n_neg_samples * epochs_per_negative_sample[i]
            )


def _optimize_layout_euclidean_densmap_epoch_init(
    head_embedding, tail_embedding, head, tail, a, b, re_sum, phi_sum,
):
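The key mechanic in the new single-epoch kernel is that every SGD step for a point is multiplied by that point's mask value. A minimal numeric sketch (not part of the commit; the gradient and learning-rate values are made up) of how that scaling behaves:

import numpy as np

# Per-point mask weights: 0 pins a point, 1 applies the full step,
# and intermediate values damp the movement proportionally.
mask = np.array([0.0, 0.5, 1.0], dtype=np.float32)
grad_d = 0.8  # hypothetical clipped gradient component
alpha = 1.0   # learning rate
print(mask * grad_d * alpha)  # -> [0.  0.4 0.8]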
@@ -379,6 +508,184 @@ def optimize_layout_euclidean(
    return head_embedding

def optimize_layout_euclidean_masked(
    head_embedding,
    tail_embedding,
    head,
    tail,
    mask,
    n_epochs,
    n_vertices,
    epochs_per_sample,
    a,
    b,
    rng_state,
    gamma=1.0,
    initial_alpha=1.0,
    negative_sample_rate=5.0,
    parallel=False,
    verbose=False,
    densmap=False,
    densmap_kwds={},
):
    """Improve an embedding using stochastic gradient descent to minimize the
    fuzzy set cross entropy between the 1-skeletons of the high dimensional
    and low dimensional fuzzy simplicial sets. In practice this is done by
    sampling edges based on their membership strength (with the (1-p) terms
    coming from negative sampling similar to word2vec).

    Parameters
    ----------
    head_embedding: array of shape (n_samples, n_components)
        The initial embedding to be improved by SGD.
    tail_embedding: array of shape (source_samples, n_components)
        The reference embedding of embedded points. If not embedding new
        previously unseen points with respect to an existing embedding this
        is simply the head_embedding (again); otherwise it provides the
        existing embedding to embed with respect to.
    head: array of shape (n_1_simplices)
        The indices of the heads of 1-simplices with non-zero membership.
    tail: array of shape (n_1_simplices)
        The indices of the tails of 1-simplices with non-zero membership.
    mask: array of shape (n_samples)
        The weights (in [0, 1]) assigned to each sample, defining how much
        each point is updated: 0 means the point does not move at all, 1
        means it is updated normally, and in-between values allow for
        fine-tuning.
    n_epochs: int
        The number of training epochs to use in optimization.
    n_vertices: int
        The number of vertices (0-simplices) in the dataset.
    epochs_per_sample: array of shape (n_1_simplices)
        A float value of the number of epochs per 1-simplex. 1-simplices with
        weaker membership strength will have more epochs between being sampled.
    a: float
        Parameter of differentiable approximation of right adjoint functor.
    b: float
        Parameter of differentiable approximation of right adjoint functor.
    rng_state: array of int64, shape (3,)
        The internal state of the rng.
    gamma: float (optional, default 1.0)
        Weight to apply to negative samples.
    initial_alpha: float (optional, default 1.0)
        Initial learning rate for the SGD.
    negative_sample_rate: int (optional, default 5)
        Number of negative samples to use per positive sample.
    parallel: bool (optional, default False)
        Whether to run the computation using numba parallel.
        Running in parallel is non-deterministic, and is not used
        if a random seed has been set, to ensure reproducibility.
    verbose: bool (optional, default False)
        Whether to report information on the current progress of the algorithm.
    densmap: bool (optional, default False)
        Whether to use the density-augmented densMAP objective.
    densmap_kwds: dict (optional, default {})
        Auxiliary data for densMAP.

    Returns
    -------
    embedding: array of shape (n_samples, n_components)
        The optimized embedding.
    """

    dim = head_embedding.shape[1]
    move_other = head_embedding.shape[0] == tail_embedding.shape[0]
    alpha = initial_alpha

    epochs_per_negative_sample = epochs_per_sample / negative_sample_rate
    epoch_of_next_negative_sample = epochs_per_negative_sample.copy()
    epoch_of_next_sample = epochs_per_sample.copy()

    optimize_fn = numba.njit(
        _optimize_layout_euclidean_masked_single_epoch, fastmath=True, parallel=parallel
    )

    if densmap:
        dens_init_fn = numba.njit(
            _optimize_layout_euclidean_densmap_epoch_init,
            fastmath=True,
            parallel=parallel,
        )

        dens_mu_tot = np.sum(densmap_kwds["mu_sum"]) / 2
        dens_lambda = densmap_kwds["lambda"]
        dens_R = densmap_kwds["R"]
        dens_mu = densmap_kwds["mu"]
        dens_phi_sum = np.zeros(n_vertices, dtype=np.float32)
        dens_re_sum = np.zeros(n_vertices, dtype=np.float32)
        dens_var_shift = densmap_kwds["var_shift"]
    else:
        dens_mu_tot = 0
        dens_lambda = 0
        dens_R = np.zeros(1, dtype=np.float32)
        dens_mu = np.zeros(1, dtype=np.float32)
        dens_phi_sum = np.zeros(1, dtype=np.float32)
        dens_re_sum = np.zeros(1, dtype=np.float32)

    for n in range(n_epochs):

        densmap_flag = (
            densmap
            and (densmap_kwds["lambda"] > 0)
            and (((n + 1) / float(n_epochs)) > (1 - densmap_kwds["frac"]))
        )

        if densmap_flag:
            dens_init_fn(
                head_embedding,
                tail_embedding,
                head,
                tail,
                a,
                b,
                dens_re_sum,
                dens_phi_sum,
            )

            dens_re_std = np.sqrt(np.var(dens_re_sum) + dens_var_shift)
            dens_re_mean = np.mean(dens_re_sum)
            dens_re_cov = np.dot(dens_re_sum, dens_R) / (n_vertices - 1)
        else:
            dens_re_std = 0
            dens_re_mean = 0
            dens_re_cov = 0

        optimize_fn(
            head_embedding,
            tail_embedding,
            head,
            tail,
            mask,
            n_vertices,
            epochs_per_sample,
            a,
            b,
            rng_state,
            gamma,
            dim,
            move_other,
            alpha,
            epochs_per_negative_sample,
            epoch_of_next_negative_sample,
            epoch_of_next_sample,
            n,
            densmap_flag,
            dens_phi_sum,
            dens_re_sum,
            dens_re_cov,
            dens_re_std,
            dens_re_mean,
            dens_lambda,
            dens_R,
            dens_mu,
            dens_mu_tot,
        )

        alpha = initial_alpha * (1.0 - (float(n) / float(n_epochs)))

        if verbose and n % int(n_epochs / 10) == 0:
            print("\tcompleted ", n, " / ", n_epochs, "epochs")

    return head_embedding
@numba.njit(fastmath=True)
def optimize_layout_generic(
    head_embedding,
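The commit adds no usage documentation for the new optimizer, so here is a hedged sketch of calling it directly to refine an existing embedding while pinning some points. The inputs `graph` (the fuzzy simplicial set as a scipy.sparse COO matrix) and `embedding` (a float32 array from a prior UMAP fit) are assumptions; `make_epochs_per_sample` and `find_ab_params` are existing helpers in `umap.umap_`.

import numpy as np
from umap.umap_ import find_ab_params, make_epochs_per_sample
from umap.layouts import optimize_layout_euclidean_masked

# Assumed inputs from a previous fit: `graph` is the symmetrized fuzzy
# simplicial set (scipy.sparse COO), `embedding` is float32 (n_samples, 2).
n_epochs = 200
epochs_per_sample = make_epochs_per_sample(graph.data, n_epochs)
a, b = find_ab_params(1.0, 0.1)  # spread=1.0, min_dist=0.1

mask = np.ones(embedding.shape[0], dtype=np.float32)
mask[:100] = 0.0  # pin the first 100 points; they will not move

embedding = optimize_layout_euclidean_masked(
    embedding,
    embedding,  # head and tail coincide when refining an embedding in place
    graph.row,
    graph.col,
    mask,
    n_epochs,
    graph.shape[0],
    epochs_per_sample,
    a,
    b,
    np.random.randint(np.iinfo(np.int32).min, np.iinfo(np.int32).max, 3).astype(np.int64),
)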

umap/parametric_umap.py

Lines changed: 7 additions & 1 deletion
@@ -358,7 +358,7 @@ def _compile_model(self):
             run_eagerly=self.run_eagerly,
         )
 
-    def _fit_embed_data(self, X, n_epochs, init, random_state):
+    def _fit_embed_data(self, X, n_epochs, init, random_state, pin_mask):
 
         if self.metric == "precomputed":
             X = self._X

@@ -371,6 +371,12 @@ def _fit_embed_data(self, X, n_epochs, init, random_state):
         if len(self.dims) > 1:
             X = np.reshape(X, [len(X)] + list(self.dims))
 
+        if pin_mask is not None:
+            warn(
+                "Pinning is not yet supported by Parametric UMAP. "
+                "Ignoring the pin_mask."
+            )
+
         if self.parametric_reconstruction and (np.max(X) > 1.0 or np.min(X) < 0.0):
             warn(
                 "Data should be scaled to the range 0-1 for cross-entropy reconstruction loss."
