From deff25e8d25ca085dc9d521945c59164fb7eb805 Mon Sep 17 00:00:00 2001
From: Abhinav Adduri
Date: Fri, 18 Jul 2025 09:22:48 -0700
Subject: [PATCH 1/2] added the groupby logic to inference code. should change
 to sample with replacement and then truncate

---
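Note: the core of this change is grouping cells by perturbation, batching within
each group, and restoring the original cell order before writing predictions
back to the AnnData. A minimal standalone sketch of that pattern follows; the
`predict` callable is a hypothetical stand-in for `model.predict_step`, mapping
an (n, d) array to an (n, d) array.

    import numpy as np
    import pandas as pd

    def grouped_inference(X, pert_names, batch_size, predict):
        # Tag every cell with its original row, then batch within each group
        df = pd.DataFrame({"pert_name": pert_names, "index": range(len(pert_names))})
        preds_with_idx = []
        for _, group in df.groupby("pert_name"):
            indices = group["index"].values
            for i in range(0, len(indices), batch_size):
                batch_idx = indices[i : i + batch_size]
                preds_with_idx.extend(zip(batch_idx, predict(X[batch_idx])))
        # Re-sort by original row index so the output aligns with the input
        preds_with_idx.sort(key=lambda t: t[0])
        return np.stack([p for _, p in preds_with_idx])

    # Identity "model": output should equal input, row for row
    X = np.arange(10, dtype=float).reshape(5, 2)
    names = np.array(["drugA", "ctrl", "drugA", "ctrl", "drugB"])
    assert np.allclose(grouped_inference(X, names, 2, lambda x: x), X)

Carrying the original row index through the loop is what lets the predictions
be re-sorted to match adata.obs even though groupby iteration reorders cells.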
{unique_batch_names[:10]}") + + # Check overlap for batches + batch_overlap = set(unique_batch_names) & set(batch_onehot_map.keys()) + logger.info(f"Overlap between AnnData and data module batches: {len(batch_overlap)} batches") + if len(batch_overlap) < len(unique_batch_names): + missing_batches = set(unique_batch_names) - set(batch_onehot_map.keys()) + logger.warning(f"Missing batches: {list(missing_batches)[:10]}") + + # Fill batch indices + batch_matched_count = 0 + default_batch_idx = 0 # Use first batch as default + for idx, name in enumerate(batch_names): + if name in batch_name_to_idx: + batch_indices_tensor[idx] = batch_name_to_idx[name] + batch_matched_count += 1 + else: + # Use first available batch as fallback + batch_indices_tensor[idx] = default_batch_idx + + logger.info(f"Matched {batch_matched_count} out of {len(batch_names)} batches") + # Check overlap overlap = set(unique_pert_names) & set(pert_onehot_map.keys()) logger.info(f"Overlap between AnnData and data module: {len(overlap)} perturbations") @@ -151,80 +199,81 @@ def load_config(cfg_path: str) -> dict: logger.info(f"Matched {matched_count} out of {len(pert_names)} perturbations") - # Process in batches with progress bar - # Use cell_sentence_len as batch size since model expects this + # Group cells by perturbation to ensure batches contain cells from same perturbation + import pandas as pd + + # Create a DataFrame for easier grouping + df = pd.DataFrame({"pert_name": pert_names, "index": range(len(pert_names))}) + grouped = df.groupby("pert_name") + n_samples = len(pert_names) - batch_size = cell_sentence_len # Model requires this exact batch size - n_batches = (n_samples + batch_size - 1) // batch_size # Ceiling division - + batch_size = args.batch_size # Use user-specified batch size + logger.info( - f"Running inference on {n_samples} samples in {n_batches} batches of size {batch_size} (model's cell_sentence_len)..." + f"Running inference on {n_samples} samples grouped by perturbation with batch size {batch_size}..." 
     )
 
     all_preds = []
+    processed_samples = 0
+    batch_idx = 0
 
     with torch.no_grad():
         progress_bar = tqdm(total=n_samples, desc="Processing samples", unit="samples")
 
-        for batch_idx in range(n_batches):
-            start_idx = batch_idx * batch_size
-            end_idx = min(start_idx + batch_size, n_samples)
-            current_batch_size = end_idx - start_idx
-
-            # Get batch data
-            X_batch = torch.tensor(X[start_idx:end_idx], dtype=torch.float32).to(device)
-            pert_batch = pert_tensor[start_idx:end_idx].to(device)
-            pert_names_batch = pert_names[start_idx:end_idx].tolist()
-
-            # Pad the batch to cell_sentence_len if it's the last incomplete batch
-            if current_batch_size < cell_sentence_len:
-                # Pad with zeros for embeddings
-                padding_size = cell_sentence_len - current_batch_size
-                X_pad = torch.zeros((padding_size, X_batch.shape[1]), device=device)
-                X_batch = torch.cat([X_batch, X_pad], dim=0)
-
-                # Pad perturbation tensor with control perturbation
-                pert_pad = torch.zeros((padding_size, pert_batch.shape[1]), device=device)
-                if control_pert in pert_onehot_map:
-                    pert_pad[:] = pert_onehot_map[control_pert].to(device)
+        # Process each perturbation group
+        for pert_name, group in grouped:
+            indices = group["index"].values
+
+            # Randomize the order of cells within this perturbation group
+            if args.seed is not None:
+                np.random.shuffle(indices)
+
+            group_size = len(indices)
+
+            # Process this perturbation group in batches
+            for i in range(0, group_size, batch_size):
+                batch_indices = indices[i:i + batch_size]
+                current_batch_size = len(batch_indices)
+
+                # Get batch data
+                X_batch = torch.tensor(X[batch_indices], dtype=torch.float32).to(device)
+                pert_batch = pert_tensor[batch_indices].to(device)
+                batch_idx_batch = batch_indices_tensor[batch_indices].to(device)
+                pert_names_batch = [pert_names[idx] for idx in batch_indices]
+
+                # Prepare batch
+                batch = {
+                    "ctrl_cell_emb": X_batch,
+                    "pert_emb": pert_batch,
+                    "pert_name": pert_names_batch,
+                    "batch": batch_idx_batch.unsqueeze(0),  # Shape: (1, current_batch_size)
+                }
+
+                # Run inference on batch
+                batch_preds = model.predict_step(batch, batch_idx=batch_idx, padded=False)
+
+                # Extract predictions from the dictionary returned by predict_step
+                if "pert_cell_counts_preds" in batch_preds and batch_preds["pert_cell_counts_preds"] is not None:
+                    # Use gene space predictions (from decoder)
+                    pred_tensor = batch_preds["pert_cell_counts_preds"]
                 else:
-                    pert_pad[:, 0] = 1  # Default to first perturbation
-                pert_batch = torch.cat([pert_batch, pert_pad], dim=0)
-
-                # Extend perturbation names
-                pert_names_batch.extend([control_pert] * padding_size)
-
-            # Prepare batch - use same format as working code
-            batch = {
-                "ctrl_cell_emb": X_batch,
-                "pert_emb": pert_batch,  # Keep as 2D tensor
-                "pert_name": pert_names_batch,
-                "batch": torch.zeros((1, cell_sentence_len), device=device),  # Use (1, cell_sentence_len)
-            }
-
-            # Run inference on batch using padded=False like in working code
-            batch_preds = model.predict_step(batch, batch_idx=batch_idx, padded=False)
-
-            # Extract predictions from the dictionary returned by predict_step
-            # Use gene decoder output if available, otherwise use latent predictions
-            if "pert_cell_counts_preds" in batch_preds and batch_preds["pert_cell_counts_preds"] is not None:
-                # Use gene space predictions (from decoder)
-                pred_tensor = batch_preds["pert_cell_counts_preds"]
-            else:
-                # Use latent space predictions
-                pred_tensor = batch_preds["preds"]
+                    # Use latent space predictions
+                    pred_tensor = batch_preds["preds"]
 
-            # Only keep predictions for the actual samples (not padding)
-            actual_preds = pred_tensor[:current_batch_size]
-            all_preds.append(actual_preds.cpu().numpy())
+                # Store predictions with their original indices to maintain order
+                batch_preds_with_indices = [(batch_indices[j], pred_tensor[j].cpu().numpy()) for j in range(current_batch_size)]
+                all_preds.extend(batch_preds_with_indices)
 
-            # Update progress bar
-            progress_bar.update(current_batch_size)
+                # Update progress bar
+                progress_bar.update(current_batch_size)
+                processed_samples += current_batch_size
+                batch_idx += 1
 
         progress_bar.close()
 
-    # Concatenate all predictions
-    preds_np = np.concatenate(all_preds, axis=0)
+    # Sort predictions by original index to maintain input order
+    all_preds.sort(key=lambda x: x[0])
+    preds_np = np.array([pred for _, pred in all_preds])
 
     # Save predictions to AnnData
     adata.X = preds_np

From 8e2ce9f6bc835d97fceaba312763f85872b59c34 Mon Sep 17 00:00:00 2001
From: Abhinav Adduri
Date: Tue, 22 Jul 2025 08:10:31 -0700
Subject: [PATCH 2/2] updated inference script to properly pad the cell set,
 but then remove fake cells

---
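Note: rather than zero-padding an undersized final batch (which pushes fake
all-zero cells through the model), this version fills the shortfall by sampling
real cells from the same group with replacement, then truncates the duplicate
predictions. A minimal sketch of just that padding step, again with a
hypothetical `predict` stand-in for `model.predict_step`:

    import numpy as np

    def pad_with_replacement(batch_indices, batch_size, rng=None):
        # Fill the shortfall by resampling the batch's own indices
        rng = rng if rng is not None else np.random.default_rng()
        n_missing = batch_size - len(batch_indices)
        if n_missing <= 0:
            return np.asarray(batch_indices)
        extra = rng.choice(batch_indices, size=n_missing, replace=True)
        return np.concatenate([batch_indices, extra])

    # Toy usage: 3 real cells padded to a batch of 8; only the first 3
    # predictions are kept, mirroring pred_tensor[:current_batch_size]
    batch_indices = np.array([4, 7, 9])
    extended = pad_with_replacement(batch_indices, batch_size=8)
    predict = lambda idx: idx * 10.0  # hypothetical model stand-in
    real_preds = predict(extended)[: len(batch_indices)]
    assert len(extended) == 8 and real_preds.shape == (3,)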
 pyproject.toml                           |  4 +-
 src/state/_cli/_tx/_infer.py             | 99 +++++++++++++++++-------
 src/state/configs/data/perturbation.yaml |  1 +
 3 files changed, 74 insertions(+), 30 deletions(-)

diff --git a/pyproject.toml b/pyproject.toml
index 3d5d0de8..cd08a131 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,6 +1,6 @@
 [project]
 name = "arc-state"
-version = "0.9.9"
+version = "0.9.10"
 description = "State is a machine learning model that predicts cellular perturbation response across diverse contexts."
 readme = "README.md"
 authors = [
@@ -13,7 +13,7 @@ authors = [
 requires-python = ">=3.10,<3.13"
 dependencies = [
     "anndata>=0.11.4",
-    "cell-load>=0.7.5",
+    "cell-load>=0.7.6",
     "numpy>=2.2.6",
     "pandas>=2.2.3",
     "pyyaml>=6.0.2",
diff --git a/src/state/_cli/_tx/_infer.py b/src/state/_cli/_tx/_infer.py
index f82c88cb..7cfb5c85 100644
--- a/src/state/_cli/_tx/_infer.py
+++ b/src/state/_cli/_tx/_infer.py
@@ -234,35 +234,78 @@ def load_config(cfg_path: str) -> dict:
             for i in range(0, group_size, batch_size):
                 batch_indices = indices[i:i + batch_size]
                 current_batch_size = len(batch_indices)
-
-                # Get batch data
-                X_batch = torch.tensor(X[batch_indices], dtype=torch.float32).to(device)
-                pert_batch = pert_tensor[batch_indices].to(device)
-                batch_idx_batch = batch_indices_tensor[batch_indices].to(device)
-                pert_names_batch = [pert_names[idx] for idx in batch_indices]
-
-                # Prepare batch
-                batch = {
-                    "ctrl_cell_emb": X_batch,
-                    "pert_emb": pert_batch,
-                    "pert_name": pert_names_batch,
-                    "batch": batch_idx_batch.unsqueeze(0),  # Shape: (1, current_batch_size)
-                }
-
-                # Run inference on batch
-                batch_preds = model.predict_step(batch, batch_idx=batch_idx, padded=False)
-
-                # Extract predictions from the dictionary returned by predict_step
-                if "pert_cell_counts_preds" in batch_preds and batch_preds["pert_cell_counts_preds"] is not None:
-                    # Use gene space predictions (from decoder)
-                    pred_tensor = batch_preds["pert_cell_counts_preds"]
+
+                # Check if this is an incomplete batch that needs sampling with replacement
+                if current_batch_size < batch_size:
+                    # Sample with replacement to fill out the batch
+                    additional_samples_needed = batch_size - current_batch_size
+                    replacement_indices = np.random.choice(batch_indices, size=additional_samples_needed, replace=True)
+
+                    # Combine original indices with replacement indices
+                    extended_batch_indices = np.concatenate([batch_indices, replacement_indices])
+
+                    # Get batch data for the extended batch
+                    X_batch = torch.tensor(X[extended_batch_indices], dtype=torch.float32).to(device)
+                    pert_batch = pert_tensor[extended_batch_indices].to(device)
+                    batch_idx_batch = batch_indices_tensor[extended_batch_indices].to(device)
+                    pert_names_batch = [pert_names[idx] for idx in extended_batch_indices]
+
+                    # Prepare batch
+                    batch = {
+                        "ctrl_cell_emb": X_batch,
+                        "pert_emb": pert_batch,
+                        "pert_name": pert_names_batch,
+                        "batch": batch_idx_batch.unsqueeze(0),  # Shape: (1, batch_size)
+                    }
+
+                    # Run inference on batch
+                    batch_preds = model.predict_step(batch, batch_idx=batch_idx, padded=False)
+
+                    # Extract predictions from the dictionary returned by predict_step
+                    if "pert_cell_counts_preds" in batch_preds and batch_preds["pert_cell_counts_preds"] is not None:
+                        # Use gene space predictions (from decoder)
+                        pred_tensor = batch_preds["pert_cell_counts_preds"]
+                    else:
+                        # Use latent space predictions
+                        pred_tensor = batch_preds["preds"]
+
+                    # Only keep predictions for the original samples (not the replacement samples)
+                    original_preds = pred_tensor[:current_batch_size]
+
+                    # Store predictions with their original indices to maintain order
+                    batch_preds_with_indices = [(batch_indices[j], original_preds[j].cpu().numpy()) for j in range(current_batch_size)]
+                    all_preds.extend(batch_preds_with_indices)
+
                 else:
-                    # Use latent space predictions
-                    pred_tensor = batch_preds["preds"]
-
-                # Store predictions with their original indices to maintain order
-                batch_preds_with_indices = [(batch_indices[j], pred_tensor[j].cpu().numpy()) for j in range(current_batch_size)]
-                all_preds.extend(batch_preds_with_indices)
+                    # Full batch - process normally
+                    # Get batch data
+                    X_batch = torch.tensor(X[batch_indices], dtype=torch.float32).to(device)
+                    pert_batch = pert_tensor[batch_indices].to(device)
+                    batch_idx_batch = batch_indices_tensor[batch_indices].to(device)
+                    pert_names_batch = [pert_names[idx] for idx in batch_indices]
+
+                    # Prepare batch
+                    batch = {
+                        "ctrl_cell_emb": X_batch,
+                        "pert_emb": pert_batch,
+                        "pert_name": pert_names_batch,
+                        "batch": batch_idx_batch.unsqueeze(0),  # Shape: (1, current_batch_size)
+                    }
+
+                    # Run inference on batch
+                    batch_preds = model.predict_step(batch, batch_idx=batch_idx, padded=False)
+
+                    # Extract predictions from the dictionary returned by predict_step
+                    if "pert_cell_counts_preds" in batch_preds and batch_preds["pert_cell_counts_preds"] is not None:
+                        # Use gene space predictions (from decoder)
+                        pred_tensor = batch_preds["pert_cell_counts_preds"]
+                    else:
+                        # Use latent space predictions
+                        pred_tensor = batch_preds["preds"]
+
+                    # Store predictions with their original indices to maintain order
+                    batch_preds_with_indices = [(batch_indices[j], pred_tensor[j].cpu().numpy()) for j in range(current_batch_size)]
+                    all_preds.extend(batch_preds_with_indices)
 
                 # Update progress bar
                 progress_bar.update(current_batch_size)
diff --git a/src/state/configs/data/perturbation.yaml b/src/state/configs/data/perturbation.yaml
index 0a2f32ed..1bc253ff 100644
--- a/src/state/configs/data/perturbation.yaml
+++ b/src/state/configs/data/perturbation.yaml
@@ -19,5 +19,6 @@ kwargs:
   store_raw_basal: false
   int_counts: false
   barcode: true
+  zeroshot_controls: false
 output_dir: null
 debug: true
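The --batch_col plumbing added in PATCH 1 maps each cell's batch label to the
index recorded in `batch_onehot_map` at training time, falling back to index 0
for labels the model never saw. A standalone sketch of that mapping, with
illustrative names (`map_batch_labels` is not part of the codebase):

    import torch

    def map_batch_labels(batch_names, trained_batches):
        # Unseen labels silently fall back to index 0, as in the patch
        name_to_idx = {name: i for i, name in enumerate(sorted(trained_batches))}
        idx = torch.zeros(len(batch_names), dtype=torch.long)
        matched = 0
        for row, name in enumerate(batch_names):
            if name in name_to_idx:
                idx[row] = name_to_idx[name]
                matched += 1
        return idx, matched

    # "plateC" was never seen during training, so it maps to index 0
    idx, matched = map_batch_labels(["plateA", "plateC", "plateB"], {"plateA", "plateB"})
    assert idx.tolist() == [0, 0, 1] and matched == 2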