Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 10 additions & 10 deletions src/abaco/ABaCo.py
Original file line number Diff line number Diff line change
Expand Up @@ -769,7 +769,7 @@ def sample(self, sample_shape=torch.Size()):
dm_sample:
Sample(s) from the Dirichlet Multinomial distribution.
"""
shape = self._extended_shape(sample_shape)
# shape = self._extended_shape(sample_shape)
p = td.Dirichlet(self.concentration).sample(sample_shape)

batch_dims = p.shape[:-1]
Expand Down Expand Up @@ -1207,7 +1207,7 @@ def kl_div_loss(self, x):
KL-divergence loss
"""
q = self.encoder(x)
z = q.rsample()
# z = q.rsample()
kl_loss = torch.mean(
self.beta * td.kl_divergence(q, self.prior()),
dim=0,
Expand Down Expand Up @@ -1359,7 +1359,7 @@ def elbo(self, x):

def kl_div_loss(self, x):
q = self.encoder(x)
z = q.rsample()
# z = q.rsample()
kl_loss = torch.mean(
self.beta * td.kl_divergence(q, self.prior()),
dim=0,
Expand Down Expand Up @@ -2035,7 +2035,7 @@ def train_abaco(
for loader_data in data_iter:
x = loader_data[0].to(device)
y = loader_data[1].to(device).float() # Batch label
z = loader_data[2].to(device).float() # Bio type label
# z = loader_data[2].to(device).float() # Bio type label

# VAE ELBO computation with masked batch label
vae_optim_post.zero_grad()
Expand All @@ -2050,8 +2050,8 @@ def train_abaco(
p_xz = vae.decoder(torch.cat([latent_points, alpha * y], dim=1))

# Log probabilities of prior and posterior
log_q_zx = q_zx.log_prob(latent_points)
log_p_z = vae.log_prob(latent_points)
# log_q_zx = q_zx.log_prob(latent_points)
# log_p_z = vae.log_prob(latent_points)

# Compute ELBO
recon_term = p_xz.log_prob(x).mean()
Expand Down Expand Up @@ -2829,7 +2829,7 @@ def train_abaco_ensemble(
for loader_data in data_iter:
x = loader_data[0].to(device)
y = loader_data[1].to(device).float() # Batch label
z = loader_data[2].to(device).float() # Bio type label
# z = loader_data[2].to(device).float() # Bio type label

# VAE ELBO computation with masked batch label
vae_optim_post.zero_grad()
Expand All @@ -2849,8 +2849,8 @@ def train_abaco_ensemble(
p_xzs.append(p_xz)

# Log probabilities of prior and posterior
log_q_zx = q_zx.log_prob(latent_points)
log_p_z = vae.log_prob(latent_points)
# log_q_zx = q_zx.log_prob(latent_points)
# log_p_z = vae.log_prob(latent_points)

# Compute ELBO

Expand Down Expand Up @@ -4529,7 +4529,7 @@ def correct(
for loader_data in iter(self.dataloader):
x = loader_data[0].to(self.device)
ohe_batch = loader_data[1].to(self.device).float() # Batch label
ohe_bio = loader_data[2].to(self.device).float() # Bio type label
# ohe_bio = loader_data[2].to(self.device).float() # Bio type label

# Encode and decode the input data along with the one-hot encoded batch label
q_zx = self.vae.encoder(torch.cat([x, ohe_batch], dim=1)) # td.Distribution
Expand Down
2 changes: 1 addition & 1 deletion src/abaco/batch_correction.py
Original file line number Diff line number Diff line change
Expand Up @@ -547,7 +547,7 @@ def fit(self, df, y=None):
cc = [self._col_order.index(c) for c in self.covariate_cols]
design_idx = bc + cc
Xd = X_full[:, design_idx]
Xd_ref = X_ref[:, design_idx]
# Xd_ref = X_ref[:, design_idx]

n, p = X_full.shape
feat_idx = list(range(len(design_idx), p))
Expand Down