diff --git a/README.md b/README.md index e4d9d77f..4377ca2b 100644 --- a/README.md +++ b/README.md @@ -60,6 +60,12 @@ An example evaluation command for a sets model: state-sets sets predict --output_dir /home/aadduri/state-sets/test/ --checkpoint last.ckpt ``` +An example inference command for a sets model: + +```bash +state-sets sets infer --output /home/dhruvgautam/state-sets/test/preds.h5ad --output_dir /path/to/model/ --checkpoint /path/to/model/checkpoints/last.ckpt --adata /path/to/anndata/processed.h5ad --pert_col gene --embed_key X_hvg +``` + The toml files should be setup to define perturbation splits, if running fewshot experiments. Here are some examples: ```toml diff --git a/pyproject.toml b/pyproject.toml index a88d814d..f234664a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "state-sets" -version = "0.6.5" +version = "0.8.0" description = "Add your description here" readme = "README.md" authors = [] @@ -28,8 +28,8 @@ dependencies = [ dev = ["ruff>=0.11.11", "vulture>=2.14"] [tool.uv.sources] -cell-load = { git = "ssh://github.com/arcinstitute/cell-load.git" } -cell-eval = { git = "ssh://github.com/arcinstitute/cell-eval.git", branch = "pin_with_pr_curves" } +cell-load = { git = "https://github.com/arcinstitute/cell-load.git" } +cell-eval = { git = "https://github.com/arcinstitute/cell-eval.git" } zclip = { git = "https://github.com/bluorion-com/ZClip.git" } [build-system] diff --git a/src/state_sets/__main__.py b/src/state_sets/__main__.py index ea4a7cbb..f0d27803 100644 --- a/src/state_sets/__main__.py +++ b/src/state_sets/__main__.py @@ -6,6 +6,7 @@ add_arguments_state, run_sets_predict, run_sets_train, + run_sets_infer, run_state_embed, run_state_train, ) @@ -60,6 +61,9 @@ def main(): case "predict": # For now, predict uses argparse and not hydra run_sets_predict(args) + case "infer": + # Run inference using argparse, similar to predict + run_sets_infer(args) if __name__ == "__main__": diff --git 
a/src/state_sets/_cli/__init__.py b/src/state_sets/_cli/__init__.py index 947c4960..5e057c70 100644 --- a/src/state_sets/_cli/__init__.py +++ b/src/state_sets/_cli/__init__.py @@ -1,4 +1,4 @@ -from ._sets import add_arguments_sets, run_sets_predict, run_sets_train +from ._sets import add_arguments_sets, run_sets_predict, run_sets_train, run_sets_infer from ._state import add_arguments_state, run_state_embed, run_state_train __all__ = [ @@ -7,5 +7,6 @@ "run_sets_train", "run_state_embed", "run_sets_predict", + "run_sets_infer", "run_state_train", ] diff --git a/src/state_sets/_cli/_sets/__init__.py b/src/state_sets/_cli/_sets/__init__.py index aede1b93..d64471bb 100644 --- a/src/state_sets/_cli/_sets/__init__.py +++ b/src/state_sets/_cli/_sets/__init__.py @@ -2,8 +2,9 @@ from ._predict import add_arguments_predict, run_sets_predict from ._train import add_arguments_train, run_sets_train +from ._infer import add_arguments_infer, run_sets_infer -__all__ = ["run_sets_train", "run_sets_predict", "add_arguments_sets"] +__all__ = ["run_sets_train", "run_sets_predict", "run_sets_infer", "add_arguments_sets"] def add_arguments_sets(parser: ap.ArgumentParser): @@ -11,3 +12,4 @@ def add_arguments_sets(parser: ap.ArgumentParser): subparsers = parser.add_subparsers(required=True, dest="subcommand") add_arguments_train(subparsers.add_parser("train")) add_arguments_predict(subparsers.add_parser("predict")) + add_arguments_infer(subparsers.add_parser("infer")) diff --git a/src/state_sets/_cli/_sets/_infer.py b/src/state_sets/_cli/_sets/_infer.py new file mode 100644 index 00000000..b79faf80 --- /dev/null +++ b/src/state_sets/_cli/_sets/_infer.py @@ -0,0 +1,234 @@ +import argparse +import scanpy as sc +import torch +import numpy as np +import os +import pandas as pd +from tqdm import tqdm +import yaml + +from ...sets.models.pert_sets import PertSetsPerturbationModel +from cell_load.data_modules import PerturbationDataModule + + +def add_arguments_infer(parser: 
argparse.ArgumentParser): + parser.add_argument("--checkpoint", type=str, required=True, help="Path to model checkpoint (.ckpt)") + parser.add_argument("--adata", type=str, required=True, help="Path to input AnnData file (.h5ad)") + parser.add_argument("--embed_key", type=str, default="X_hvg", help="Key in adata.obsm for input features") + parser.add_argument( + "--pert_col", type=str, default="drugname_drugconc", help="Column in adata.obs for perturbation labels" + ) + parser.add_argument("--output", type=str, default=None, help="Path to output AnnData file (.h5ad)") + parser.add_argument( + "--output_dir", + type=str, + required=True, + help="Path to the output_dir containing the config.yaml file that was saved during training.", + ) + parser.add_argument( + "--celltype_col", type=str, default=None, help="Column in adata.obs for cell type labels (optional)" + ) + parser.add_argument( + "--celltypes", type=str, default=None, help="Comma-separated list of cell types to include (optional)" + ) + parser.add_argument("--batch_size", type=int, default=1000, help="Batch size for inference (default: 1000)") + + +def run_sets_infer(args): + import logging + + logging.basicConfig(level=logging.INFO) + logger = logging.getLogger(__name__) + + def load_config(cfg_path: str) -> dict: + """Load config from the YAML file that was dumped during training.""" + if not os.path.exists(cfg_path): + raise FileNotFoundError(f"Could not find config file: {cfg_path}") + with open(cfg_path, "r") as f: + cfg = yaml.safe_load(f) + return cfg + + # Load the config + config_path = os.path.join(args.output_dir, "config.yaml") + cfg = load_config(config_path) + logger.info(f"Loaded config from {config_path}") + + # Find run output directory & load data module + run_output_dir = os.path.join(cfg["output_dir"], cfg["name"]) + data_module_path = os.path.join(run_output_dir, "data_module.torch") + if not os.path.exists(data_module_path): + raise FileNotFoundError(f"Could not find data module at 
{data_module_path}") + data_module = PerturbationDataModule.load_state(data_module_path) + data_module.setup(stage="test") + logger.info(f"Loaded data module from {data_module_path}") + + # Get perturbation dimensions and mapping from data module + var_dims = data_module.get_var_dims() + pert_dim = var_dims["pert_dim"] + + # Load model + logger.info(f"Loading model from checkpoint: {args.checkpoint}") + model = PertSetsPerturbationModel.load_from_checkpoint(args.checkpoint) + model.eval() + cell_sentence_len = model.cell_sentence_len + device = next(model.parameters()).device + + # Load AnnData + logger.info(f"Loading AnnData from: {args.adata}") + adata = sc.read_h5ad(args.adata) + + # Optionally filter by cell type + if args.celltype_col is not None and args.celltypes is not None: + celltypes = [ct.strip() for ct in args.celltypes.split(",")] + if args.celltype_col not in adata.obs: + raise ValueError(f"Column '{args.celltype_col}' not found in adata.obs.") + initial_n = adata.n_obs + adata = adata[adata.obs[args.celltype_col].isin(celltypes)].copy() + logger.info(f"Filtered AnnData to {adata.n_obs} cells of types {celltypes} (from {initial_n} cells)") + elif args.celltype_col is not None: + if args.celltype_col not in adata.obs: + raise ValueError(f"Column '{args.celltype_col}' not found in adata.obs.") + logger.info(f"No cell type filtering applied, but cell type column '{args.celltype_col}' is available.") + + # Get input features + if args.embed_key in adata.obsm: + X = adata.obsm[args.embed_key] + logger.info(f"Using adata.obsm['{args.embed_key}'] as input features: shape {X.shape}") + else: + X = adata.X + logger.info(f"Using adata.X as input features: shape {X.shape}") + + # Prepare perturbation tensor using the data module's mapping + pert_names = adata.obs[args.pert_col].values + pert_tensor = torch.zeros((len(pert_names), pert_dim), device="cpu") # Keep on CPU initially + logger.info(f"Perturbation tensor shape: {pert_tensor.shape}") + + # Use data 
module's perturbation mapping + pert_onehot_map = data_module.pert_onehot_map + + logger.info(f"Data module has {len(pert_onehot_map)} perturbations in mapping") + logger.info(f"First 10 perturbations in data module: {list(pert_onehot_map.keys())[:10]}") + + unique_pert_names = sorted(set(pert_names)) + logger.info(f"AnnData has {len(unique_pert_names)} unique perturbations") + logger.info(f"First 10 perturbations in AnnData: {unique_pert_names[:10]}") + + # Check overlap + overlap = set(unique_pert_names) & set(pert_onehot_map.keys()) + logger.info(f"Overlap between AnnData and data module: {len(overlap)} perturbations") + if len(overlap) < len(unique_pert_names): + missing = set(unique_pert_names) - set(pert_onehot_map.keys()) + logger.warning(f"Missing perturbations: {list(missing)[:10]}") + + # Check if there's a control perturbation that might match + control_pert = data_module.get_control_pert() + logger.info(f"Control perturbation in data module: '{control_pert}'") + + matched_count = 0 + for idx, name in enumerate(pert_names): + if name in pert_onehot_map: + pert_tensor[idx] = pert_onehot_map[name] + matched_count += 1 + else: + # For now, use control perturbation as fallback + if control_pert in pert_onehot_map: + pert_tensor[idx] = pert_onehot_map[control_pert] + else: + # Use first available perturbation as fallback + first_pert = list(pert_onehot_map.keys())[0] + pert_tensor[idx] = pert_onehot_map[first_pert] + + logger.info(f"Matched {matched_count} out of {len(pert_names)} perturbations") + + # Process in batches with progress bar + # Use cell_sentence_len as batch size since model expects this + n_samples = len(pert_names) + batch_size = cell_sentence_len # Model requires this exact batch size + n_batches = (n_samples + batch_size - 1) // batch_size # Ceiling division + + logger.info( + f"Running inference on {n_samples} samples in {n_batches} batches of size {batch_size} (model's cell_sentence_len)..." 
+ ) + + all_preds = [] + + with torch.no_grad(): + progress_bar = tqdm(total=n_samples, desc="Processing samples", unit="samples") + + for batch_idx in range(n_batches): + start_idx = batch_idx * batch_size + end_idx = min(start_idx + batch_size, n_samples) + current_batch_size = end_idx - start_idx + + # Get batch data + X_batch = torch.tensor(X[start_idx:end_idx], dtype=torch.float32).to(device) + pert_batch = pert_tensor[start_idx:end_idx].to(device) + pert_names_batch = pert_names[start_idx:end_idx].tolist() + + # Pad the batch to cell_sentence_len if it's the last incomplete batch + if current_batch_size < cell_sentence_len: + # Pad with zeros for embeddings + padding_size = cell_sentence_len - current_batch_size + X_pad = torch.zeros((padding_size, X_batch.shape[1]), device=device) + X_batch = torch.cat([X_batch, X_pad], dim=0) + + # Pad perturbation tensor with control perturbation + pert_pad = torch.zeros((padding_size, pert_batch.shape[1]), device=device) + if control_pert in pert_onehot_map: + pert_pad[:] = pert_onehot_map[control_pert].to(device) + else: + pert_pad[:, 0] = 1 # Default to first perturbation + pert_batch = torch.cat([pert_batch, pert_pad], dim=0) + + # Extend perturbation names + pert_names_batch.extend([control_pert] * padding_size) + + # Prepare batch - use same format as working code + batch = { + "ctrl_cell_emb": X_batch, + "pert_emb": pert_batch, # Keep as 2D tensor + "pert_name": pert_names_batch, + "batch": torch.zeros((1, cell_sentence_len), device=device), # Use (1, cell_sentence_len) + } + + # Run inference on batch using padded=False like in working code + batch_preds = model.predict_step(batch, batch_idx=batch_idx, padded=False) + + # Extract predictions from the dictionary returned by predict_step + # Use gene decoder output if available, otherwise use latent predictions + if "pert_cell_counts_preds" in batch_preds and batch_preds["pert_cell_counts_preds"] is not None: + # Use gene space predictions (from decoder) + 
pred_tensor = batch_preds["pert_cell_counts_preds"] + else: + # Use latent space predictions + pred_tensor = batch_preds["preds"] + + # Only keep predictions for the actual samples (not padding) + actual_preds = pred_tensor[:current_batch_size] + all_preds.append(actual_preds.cpu().numpy()) + + # Update progress bar + progress_bar.update(current_batch_size) + + progress_bar.close() + + # Concatenate all predictions + preds_np = np.concatenate(all_preds, axis=0) + + # Save predictions to AnnData + pred_key = "model_preds" + adata.obsm[pred_key] = preds_np + output_path = args.output or args.adata.replace(".h5ad", "_with_preds.h5ad") + adata.write_h5ad(output_path) + logger.info(f"Saved predictions to {output_path} (in adata.obsm['{pred_key}'])") + + +def main(): + parser = argparse.ArgumentParser(description="Run inference on AnnData with a trained model checkpoint.") + add_arguments_infer(parser) + args = parser.parse_args() + run_sets_infer(args) + + +if __name__ == "__main__": + main() diff --git a/src/state_sets/_cli/_sets/_predict.py b/src/state_sets/_cli/_sets/_predict.py index 08982062..8164dcba 100644 --- a/src/state_sets/_cli/_sets/_predict.py +++ b/src/state_sets/_cli/_sets/_predict.py @@ -366,6 +366,7 @@ def load_config(cfg_path: str) -> dict: outdir=results_dir, prefix=ct, pdex_kwargs=pdex_kwargs, + batch_size=2048, ) evaluator.compute( diff --git a/src/state_sets/_cli/_sets/_train.py b/src/state_sets/_cli/_sets/_train.py index 987b8228..1a7cf50e 100644 --- a/src/state_sets/_cli/_sets/_train.py +++ b/src/state_sets/_cli/_sets/_train.py @@ -115,6 +115,28 @@ def run_sets_train(cfg: DictConfig): data_module.save_state(f) data_module.setup(stage="fit") + dl = data_module.train_dataloader() + print("num_workers:", dl.num_workers) + print("batch size:", dl.batch_size) + + var_dims = data_module.get_var_dims() # {"gene_dim": …, "hvg_dim": …} + if cfg["data"]["kwargs"]["output_space"] == "gene": + gene_dim = var_dims.get("hvg_dim", 2000) # fallback if key 
missing + else: + gene_dim = var_dims.get("gene_dim", 2000) # fallback if key missing + latent_dim = var_dims["output_dim"] # same as model.output_dim + hidden_dims = cfg["model"]["kwargs"].get("decoder_hidden_dims", [1024, 1024, 512]) + + decoder_cfg = dict( + latent_dim=latent_dim, + gene_dim=gene_dim, + hidden_dims=hidden_dims, + dropout=cfg["model"]["kwargs"].get("decoder_dropout", 0.1), + residual_decoder=cfg["model"]["kwargs"].get("residual_decoder", False), + ) + + # tuck it into the kwargs that will reach the LightningModule + cfg["model"]["kwargs"]["decoder_cfg"] = decoder_cfg if cfg["model"]["name"].lower() in ["cpa", "scvi"] or cfg["model"]["name"].lower().startswith("scgpt"): cfg["model"]["kwargs"]["n_cell_types"] = len(data_module.celltype_onehot_map) @@ -130,7 +152,9 @@ def run_sets_train(cfg: DictConfig): data_module.get_var_dims(), ) - # Set up logging + print( + f"Model created. Estimated params size: {sum(p.numel() * p.element_size() for p in model.parameters()) / 1024**3:.2f} GB" + ) loggers = get_loggers( output_dir=cfg["output_dir"], name=cfg["name"], @@ -198,7 +222,9 @@ def run_sets_train(cfg: DictConfig): del trainer_kwargs["max_steps"] # Build trainer + print(f"Building trainer with kwargs: {trainer_kwargs}") trainer = pl.Trainer(**trainer_kwargs) + print("Trainer built successfully") # Load checkpoint if exists checkpoint_path = join(ckpt_callbacks[0].dirpath, "last.ckpt") @@ -207,12 +233,17 @@ def run_sets_train(cfg: DictConfig): else: logging.info(f"!! 
Resuming training from {checkpoint_path} !!") + print(f"Model device: {next(model.parameters()).device}") + print(f"CUDA memory allocated: {torch.cuda.memory_allocated() / 1024**3:.2f} GB") + print(f"CUDA memory reserved: {torch.cuda.memory_reserved() / 1024**3:.2f} GB") + logger.info("Starting trainer fit.") # if a checkpoint does not exist, start with the provided checkpoint # this is mainly used for pretrain -> finetune workflows manual_init = cfg["model"]["kwargs"].get("init_from", None) if checkpoint_path is None and manual_init is not None: + print(f"Loading manual checkpoint from {manual_init}") checkpoint_path = manual_init device = torch.device("cuda" if torch.cuda.is_available() else "cpu") checkpoint = torch.load(checkpoint_path, map_location=device) @@ -253,6 +284,7 @@ def run_sets_train(cfg: DictConfig): # Load the filtered state dict model.load_state_dict(filtered_state, strict=False) + print("About to call trainer.fit() with manual checkpoint...") # Train - for clarity we pass None trainer.fit( @@ -260,13 +292,18 @@ def run_sets_train(cfg: DictConfig): datamodule=data_module, ckpt_path=None, ) + print("trainer.fit() completed with manual checkpoint") else: + print(f"About to call trainer.fit() with checkpoint_path={checkpoint_path}") # Train trainer.fit( model, datamodule=data_module, ckpt_path=checkpoint_path, ) + print("trainer.fit() completed") + + print("Training completed, saving final checkpoint...") # at this point if checkpoint_path does not exist, manually create one checkpoint_path = join(ckpt_callbacks[0].dirpath, "final.ckpt") diff --git a/src/state_sets/_cli/_state/_embed.py b/src/state_sets/_cli/_state/_embed.py index 0e2a3df1..396999ce 100644 --- a/src/state_sets/_cli/_state/_embed.py +++ b/src/state_sets/_cli/_state/_embed.py @@ -1,16 +1,13 @@ import argparse as ap +import torch def add_arguments_embed(parser: ap.ArgumentParser): """Add arguments for state embedding CLI.""" - parser.add_argument("--checkpoint", required=True, 
help="Path to the model checkpoint file") - parser.add_argument("--config", required=True, help="Path to the model training config") + parser.add_argument("--model-folder", required=True, help="Path to the model checkpoint folder") parser.add_argument("--input", required=True, help="Path to input anndata file (h5ad)") parser.add_argument("--output", required=True, help="Path to output embedded anndata file (h5ad)") - parser.add_argument("--dataset-name", default="perturbation", help="Dataset name to be used in dataloader creation") - parser.add_argument("--gpu", action="store_true", help="Use GPU if available") - parser.add_argument("--filter", action="store_true", help="Filter gene set to our esm embeddings only.") - parser.add_argument("--embed-key", help="Name of key to store embeddings") + parser.add_argument("--embed-key", default="X_state", help="Name of key to store embeddings") def run_state_embed(args: ap.ArgumentParser): @@ -18,6 +15,7 @@ def run_state_embed(args: ap.ArgumentParser): Compute embeddings for an input anndata file using a pre-trained VCI model checkpoint. 
""" import os + import glob import logging from omegaconf import OmegaConf @@ -27,13 +25,23 @@ def run_state_embed(args: ap.ArgumentParser): from ...state.inference import Inference - # Load configuration - logger.info(f"Loading config from {args.config}") - conf = OmegaConf.load(args.config) + # look in the model folder with glob for *.ckpt, get the first one, and print it + model_files = glob.glob(os.path.join(args.model_folder, "*.ckpt")) + if not model_files: + logger.error(f"No model checkpoint found in {args.model_folder}") + raise FileNotFoundError(f"No model checkpoint found in {args.model_folder}") + args.checkpoint = model_files[0] + logger.info(f"Using model checkpoint: {args.checkpoint}") # Create inference object logger.info("Creating inference object") - inferer = Inference(conf) + embedding_file = os.path.join(args.model_folder, "protein_embeddings.pt") + protein_embeds = torch.load(embedding_file, weights_only=False, map_location="cpu") + + config_file = os.path.join(args.model_folder, "config.yaml") + conf = OmegaConf.load(config_file) + + inferer = Inference(cfg=conf, protein_embeds=protein_embeds) # Load model from checkpoint logger.info(f"Loading model from checkpoint: {args.checkpoint}") @@ -53,7 +61,6 @@ def run_state_embed(args: ap.ArgumentParser): input_adata_path=args.input, output_adata_path=args.output, emb_key=args.embed_key, - dataset_name=args.dataset_name, ) logger.info("Embedding computation completed successfully!") diff --git a/src/state_sets/configs/model/tahoe_decoder_test.yaml b/src/state_sets/configs/model/tahoe_decoder_test.yaml new file mode 100644 index 00000000..680e1812 --- /dev/null +++ b/src/state_sets/configs/model/tahoe_decoder_test.yaml @@ -0,0 +1,44 @@ +name: PertSets +checkpoint: null +device: cuda + +kwargs: + cell_set_len: 512 + decoder_hidden_dims: [2048, 2048, 2048] + blur: 0.05 + hidden_dim: 696 # hidden dimension going into the transformer backbone + loss: energy + confidence_head: False + n_encoder_layers: 4 
+ n_decoder_layers: 4 + predict_residual: True + softplus: True + freeze_pert: False + transformer_decoder: False + finetune_vci_decoder: False + residual_decoder: False + decoder_loss_weight: 1.0 + batch_encoder: False + nb_decoder: False + mask_attn: False + distributional_loss: energy + init_from: null + transformer_backbone_key: llama + transformer_backbone_kwargs: + max_position_embeddings: ${model.kwargs.cell_set_len} + hidden_size: ${model.kwargs.hidden_dim} + intermediate_size: 2784 + num_hidden_layers: 8 + num_attention_heads: 12 + num_key_value_heads: 12 + head_dim: 58 + use_cache: false + attention_dropout: 0.0 + hidden_dropout: 0.0 + layer_norm_eps: 1e-6 + pad_token_id: 0 + bos_token_id: 1 + eos_token_id: 2 + tie_word_embeddings: false + rotary_dim: 0 + use_rotary_embeddings: false diff --git a/src/state_sets/configs/model/tahoe_llama_93133848.yaml b/src/state_sets/configs/model/tahoe_llama_93133848.yaml index f5152789..17b183d7 100644 --- a/src/state_sets/configs/model/tahoe_llama_93133848.yaml +++ b/src/state_sets/configs/model/tahoe_llama_93133848.yaml @@ -4,6 +4,7 @@ device: cuda kwargs: cell_set_len: 512 + decoder_hidden_dims: [1024, 1024, 512] blur: 0.05 hidden_dim: 696 # hidden dimension going into the transformer backbone loss: energy diff --git a/src/state_sets/sets/models/base.py b/src/state_sets/sets/models/base.py index 6c837541..50ab1d3f 100644 --- a/src/state_sets/sets/models/base.py +++ b/src/state_sets/sets/models/base.py @@ -5,6 +5,7 @@ import torch import torch.nn as nn from lightning.pytorch import LightningModule +import typing as tp from .utils import get_loss_fn @@ -147,9 +148,11 @@ def __init__( batch_size: int = 64, gene_dim: int = 5000, hvg_dim: int = 2001, + decoder_cfg: dict | None = None, **kwargs, ): super().__init__() + self.decoder_cfg = decoder_cfg self.save_hyperparameters() # Core architecture settings @@ -158,6 +161,8 @@ def __init__( self.output_dim = output_dim self.pert_dim = pert_dim self.batch_dim = batch_dim + 
self.gene_dim = gene_dim + self.hvg_dim = hvg_dim if kwargs.get("batch_encoder", False): self.batch_dim = batch_dim @@ -176,36 +181,7 @@ def __init__( self.dropout = dropout self.lr = lr self.loss_fn = get_loss_fn(loss_fn) - - # this will either decode to hvg space if output space is a gene, - # or to transcriptome space if output space is all. done this way to maintain - # backwards compatibility with the old models - self.gene_decoder = None - gene_dim = hvg_dim if output_space == "gene" else gene_dim - if (embed_key and embed_key != "X_hvg" and output_space == "gene") or ( - embed_key and output_space == "all" - ): # we should be able to decode from hvg to all - if gene_dim > 10000: - hidden_dims = [1024, 512, 256] - else: - if "DMSO_TF" in self.control_pert: - if self.residual_decoder: - hidden_dims = [2058, 2058, 2058, 2058, 2058] - else: - hidden_dims = [4096, 2048, 2048] - elif "PBS" in self.control_pert: - hidden_dims = [2048, 1024, 1024] - else: - hidden_dims = [1024, 1024, 512] # make this config - - self.gene_decoder = LatentToGeneDecoder( - latent_dim=self.output_dim, - gene_dim=gene_dim, - hidden_dims=hidden_dims, - dropout=dropout, - residual_decoder=self.residual_decoder, - ) - logger.info(f"Initialized gene decoder for embedding {embed_key} to gene space") + self._build_decoder() def transfer_batch_to_device(self, batch, device, dataloader_idx: int): return {k: (v.to(device) if isinstance(v, torch.Tensor) else v) for k, v in batch.items()} @@ -215,6 +191,55 @@ def _build_networks(self): """Build the core neural network components.""" pass + def _build_decoder(self): + """Create self.gene_decoder from self.decoder_cfg (or leave None).""" + if self.decoder_cfg is None: + self.gene_decoder = None + return + self.gene_decoder = LatentToGeneDecoder(**self.decoder_cfg) + + def on_load_checkpoint(self, checkpoint: dict[str, tp.Any]) -> None: + """ + Lightning calls this *before* the checkpoint's state_dict is loaded. 
+ Re-create the decoder using the exact hyper-parameters saved in the ckpt, + so that parameter shapes match and load_state_dict succeeds. + """ + if "decoder_cfg" in checkpoint["hyper_parameters"]: + self.decoder_cfg = checkpoint["hyper_parameters"]["decoder_cfg"] + self.gene_decoder = LatentToGeneDecoder(**self.decoder_cfg) + logger.info(f"Loaded decoder from checkpoint decoder_cfg: {self.decoder_cfg}") + else: + # Only fall back to old logic if no decoder_cfg was saved + self.decoder_cfg = None + self._build_decoder() + logger.info(f"DEBUG: output_space: {self.output_space}") + if self.gene_decoder is None: + gene_dim = self.hvg_dim if self.output_space == "gene" else self.gene_dim + logger.info(f"DEBUG: gene_dim: {gene_dim}") + if (self.embed_key and self.embed_key != "X_hvg" and self.output_space == "gene") or ( + self.embed_key and self.output_space == "all" + ): # we should be able to decode from hvg to all + logger.info(f"DEBUG: Creating gene_decoder, checking conditions...") + if gene_dim > 10000: + hidden_dims = [1024, 512, 256] + else: + if "DMSO_TF" in self.control_pert: + if self.residual_decoder: + hidden_dims = [2058, 2058, 2058, 2058, 2058] + else: + hidden_dims = [4096, 2048, 2048] + else: + hidden_dims = [1024, 1024, 512] # make this config + + self.gene_decoder = LatentToGeneDecoder( + latent_dim=self.output_dim, + gene_dim=gene_dim, + hidden_dims=hidden_dims, + dropout=self.dropout, + residual_decoder=self.residual_decoder, + ) + logger.info(f"Initialized gene decoder for embedding {self.embed_key} to gene space") + def training_step(self, batch: Dict[str, torch.Tensor], batch_idx: int) -> torch.Tensor: """Training step logic for both main model and decoder.""" # Get model predictions (in latent space) diff --git a/src/state_sets/sets/models/pert_sets.py b/src/state_sets/sets/models/pert_sets.py index a317b484..7c2c06d0 100644 --- a/src/state_sets/sets/models/pert_sets.py +++ b/src/state_sets/sets/models/pert_sets.py @@ -358,6 +358,8 @@ def 
forward(self, batch: dict, padded=True) -> torch.Tensor: # apply relu if specified and we output to HVG space is_gene_space = self.hparams["embed_key"] == "X_hvg" or self.hparams["embed_key"] is None + # logger.info(f"DEBUG: is_gene_space: {is_gene_space}") + # logger.info(f"DEBUG: self.gene_decoder: {self.gene_decoder}") if is_gene_space or self.gene_decoder is None: out_pred = self.relu(out_pred) diff --git a/src/state_sets/state/data/loader.py b/src/state_sets/state/data/loader.py index 6dba874a..40da2ec7 100644 --- a/src/state_sets/state/data/loader.py +++ b/src/state_sets/state/data/loader.py @@ -30,6 +30,7 @@ def create_dataloader( adata_name=None, shuffle=False, sentence_collator=None, + protein_embeds=None, ): """ Expected to be used for inference @@ -44,7 +45,9 @@ def create_dataloader( if data_dir: utils.get_dataset_cfg(cfg).data_dir = data_dir - dataset = FilteredGenesCounts(cfg, datasets=datasets, shape_dict=shape_dict, adata=adata, adata_name=adata_name) + dataset = FilteredGenesCounts( + cfg, datasets=datasets, shape_dict=shape_dict, adata=adata, adata_name=adata_name, protein_embeds=protein_embeds + ) if sentence_collator is None: sentence_collator = VCIDatasetSentenceCollator( cfg, valid_gene_mask=dataset.valid_gene_index, ds_emb_mapping_inference=dataset.ds_emb_map, is_train=False @@ -170,21 +173,19 @@ def get_dim(self) -> Dict[str, int]: class FilteredGenesCounts(H5adSentenceDataset): - def __init__(self, cfg, test=False, datasets=None, shape_dict=None, adata=None, adata_name=None) -> None: + def __init__( + self, cfg, test=False, datasets=None, shape_dict=None, adata=None, adata_name=None, protein_embeds=None + ) -> None: super(FilteredGenesCounts, self).__init__(cfg, test, datasets, shape_dict, adata, adata_name) self.valid_gene_index = {} + self.protein_embeds = protein_embeds # make sure we get training datasets - _, self.datasets, self.shapes_dict, self.dataset_path_map, self.dataset_group_map = utils.get_shapes_dict( - 
"/home/aadduri/state/h5ad_all.csv" - ) + self.datasets = [] + self.shapes_dict = {} + self.ds_emb_map = {} emb_cfg = utils.get_embedding_cfg(self.cfg) - try: - self.ds_emb_map = torch.load(emb_cfg.ds_emb_mapping, weights_only=False) - except (FileNotFoundError, IOError): - self.ds_emb_map = {} - # for inference, let's make sure this dataset's valid mask is available if adata_name is not None: # append it to self.datasets @@ -192,7 +193,7 @@ def __init__(self, cfg, test=False, datasets=None, shape_dict=None, adata=None, self.shapes_dict[adata_name] = adata.shape # compute its embedding‐index vector - esm_data = torch.load(emb_cfg.all_embeddings, weights_only=False) + esm_data = self.protein_embeds or torch.load(emb_cfg.all_embeddings, weights_only=False) valid_genes_list = list(esm_data.keys()) # make a gene→global‐index lookup global_pos = {g: i for i, g in enumerate(valid_genes_list)} @@ -204,6 +205,7 @@ def __init__(self, cfg, test=False, datasets=None, shape_dict=None, adata=None, new_mapping = np.array([global_pos.get(g, -1) for g in gene_names]) if (new_mapping == -1).all(): # probably it contains ensembl id's instead + assert "gene_name" in adata.var.keys() gene_names = adata.var["gene_name"].values new_mapping = np.array([global_pos.get(g, -1) for g in gene_names]) @@ -211,7 +213,7 @@ def __init__(self, cfg, test=False, datasets=None, shape_dict=None, adata=None, self.ds_emb_map[adata_name] = new_mapping if utils.get_embedding_cfg(self.cfg).ds_emb_mapping is not None: - esm_data = torch.load(utils.get_embedding_cfg(self.cfg)["all_embeddings"], weights_only=False) + esm_data = self.protein_embeds or torch.load(emb_cfg.ds_emb_mapping, weights_only=False) valid_genes_list = list(esm_data.keys()) for name in self.datasets: if not utils.is_valid_uuid( diff --git a/src/state_sets/state/finetune_decoder.py b/src/state_sets/state/finetune_decoder.py index ef346ef7..5d5dc700 100644 --- a/src/state_sets/state/finetune_decoder.py +++ 
b/src/state_sets/state/finetune_decoder.py @@ -2,7 +2,7 @@ import torch from torch import nn -from vci.nn.model import LitUCEModel +from vci.nn.model import StateEmbeddingModel from vci.train.trainer import get_embeddings from vci.utils import get_embedding_cfg @@ -44,7 +44,7 @@ def load_model(self, checkpoint): # Import locally to avoid circular imports # Load and initialize model for eval - self.model = LitUCEModel.load_from_checkpoint(checkpoint, strict=False) + self.model = StateEmbeddingModel.load_from_checkpoint(checkpoint, strict=False) self.device = self.model.device # Load protein embeddings diff --git a/src/state_sets/state/inference.py b/src/state_sets/state/inference.py index 6bd905e3..2f17dfb4 100644 --- a/src/state_sets/state/inference.py +++ b/src/state_sets/state/inference.py @@ -9,7 +9,7 @@ from tqdm import tqdm from torch import nn -from .nn.model import LitUCEModel +from .nn.model import StateEmbeddingModel from .train.trainer import get_embeddings from .data import create_dataloader from .utils import get_embedding_cfg @@ -18,10 +18,10 @@ class Inference: - def __init__(self, cfg): + def __init__(self, cfg=None, protein_embeds=None): self.model = None self.collator = None - self.protein_embeds = None + self.protein_embeds = protein_embeds self._vci_conf = cfg def __load_dataset_meta(self, adata_path): @@ -94,17 +94,17 @@ def load_model(self, checkpoint): raise ValueError("Model already initialized") # Load and initialize model for eval - self.model = LitUCEModel.load_from_checkpoint( - checkpoint, strict=False, cfg=self._vci_conf - ) ### THIS IS THE LINE THAT FAILS - all_pe = get_embeddings(self._vci_conf) - all_pe.requires_grad = False + self.model = StateEmbeddingModel.load_from_checkpoint(checkpoint, strict=False) + all_pe = self.protein_embeds or get_embeddings(self._vci_conf) + if isinstance(all_pe, dict): + all_pe = torch.vstack(list(all_pe.values())) self.model.pe_embedding = nn.Embedding.from_pretrained(all_pe) 
self.model.pe_embedding.to(self.model.device) self.model.binary_decoder.requires_grad = False self.model.eval() - self.protein_embeds = torch.load(get_embedding_cfg(self._vci_conf).all_embeddings, weights_only=False) + if self.protein_embeds is None: + self.protein_embeds = torch.load(get_embedding_cfg(self._vci_conf).all_embeddings, weights_only=False) def init_from_model(self, model, protein_embeds=None): """ @@ -149,10 +149,11 @@ def encode_adata( dataloader = create_dataloader( self._vci_conf, adata=adata, - adata_name=dataset_name, + adata_name=dataset_name or "inference", shape_dict=shape_dict, data_dir=os.path.dirname(input_adata_path), shuffle=False, + protein_embeds=self.protein_embeds, ) all_embeddings = [] @@ -173,14 +174,6 @@ def encode_adata( adata.obsm[emb_key] = all_embeddings adata.write_h5ad(output_adata_path) - # This streaming approach was not working (h5 files could not be opened) - # for embeddings in tqdm(self.encode(dataloader), - # total=len(dataloader), - # desc='Encoding'): - # self._save_data(input_adata_path, output_adata_path, emb_key, embeddings) - - # return output_adata_path - def decode_from_file(self, adata_path, emb_key: str, read_depth=None, batch_size=64): adata = anndata.read_h5ad(adata_path) genes = adata.var.index @@ -205,7 +198,7 @@ def decode_from_adata(self, adata, genes, emb_key: str, read_depth=None, batch_s task_counts = torch.full((cell_embeds_batch.shape[0],), read_depth, device=self.model.device) else: task_counts = None - merged_embs = LitUCEModel.resize_batch(cell_embeds_batch, gene_embeds, task_counts) + merged_embs = StateEmbeddingModel.resize_batch(cell_embeds_batch, gene_embeds, task_counts) logprobs_batch = self.model.binary_decoder(merged_embs) logprobs_batch = logprobs_batch.detach().cpu().numpy() yield logprobs_batch.squeeze() diff --git a/src/state_sets/state/nn/flash_transformer.py b/src/state_sets/state/nn/flash_transformer.py index cbb9e6b3..ba64b167 100644 --- 
a/src/state_sets/state/nn/flash_transformer.py +++ b/src/state_sets/state/nn/flash_transformer.py @@ -1,7 +1,6 @@ # File: vci/flash_transformer.py """ -This module implements a Transformer encoder layer that uses Flash Attention. -It provides a FlashTransformerEncoderLayer and a FlashTransformerEncoder. +This module implements a Transformer encoder layer. """ import torch diff --git a/src/state_sets/state/nn/model.py b/src/state_sets/state/nn/model.py index 86179870..dbbefb35 100644 --- a/src/state_sets/state/nn/model.py +++ b/src/state_sets/state/nn/model.py @@ -83,53 +83,11 @@ def forward(self, x: Tensor) -> Tensor: return self.dropout(x) -class RDAEncoder(nn.Module): - """ - Map a scalar read-depth d -> (mu, log_sigma). - latent_dim is whatever you used for z_count. - """ - - def __init__(self, latent_dim: int): - super().__init__() - self.latent_dim = latent_dim - self.net = nn.Sequential( - # nn.Linear(2, latent_dim, bias=True), - SkipBlock(self.latent_dim), - SkipBlock(self.latent_dim), - ) - - def forward(self, mu, sigma): - """ - Parameters - ---------- - mu : tensor or float -- shape (B,) or (B,1) or scalar - sigma : tensor or float -- same shape as `mu`; *must be ≥0* - - Returns - ------- - z_latent : tensor of shape (B, latent_dim) - Sampled from 𝒩(μ, σ²I) and processed by `self.net`. 
- """ - # Ensure tensors, move to same device as the network - mu = mu.view(-1, 1) - sigma = sigma.view(-1, 1) - - # Sample eps ~ N(0, I) and build z - eps = torch.randn(mu.size(0), self.latent_dim, device=mu.device) - z = mu.expand(-1, self.latent_dim) + sigma.expand(-1, self.latent_dim) * eps - - # z = torch.cat((mu, sigma), dim=1) # (B, 2) - - # Pass through two SkipBlocks (or identity if you prefer) - z = self.net(z) # (B, latent_dim) - return z - - def nanstd(x): return torch.sqrt(torch.nanmean(torch.pow(x - torch.nanmean(x, dim=-1).unsqueeze(-1), 2), dim=-1)) -class LitUCEModel(L.LightningModule): +class StateEmbeddingModel(L.LightningModule): def __init__( self, token_dim: int, @@ -195,11 +153,7 @@ def __init__( if compiled: self.decoder = torch.compile(self.decoder) - if self.cfg.model.get("sample_rda", False): - self.z_dim_rd = 128 - self.z_encoder = RDAEncoder(self.z_dim_rd) - else: - self.z_dim_rd = 1 if self.cfg.model.rda else 0 + self.z_dim_rd = 1 if self.cfg.model.rda else 0 self.z_dim_ds = 10 if self.cfg.model.get("dataset_correction", False) else 0 self.z_dim = self.z_dim_rd + self.z_dim_ds @@ -362,31 +316,22 @@ def _predict_exp_for_adata(self, adata, dataset_name, pert_col): _, _, _, emb, ds_emb = self._compute_embedding_for_batch(batch) # now decode from the embedding - if self.z_dim_rd > 1: - Y = batch[2].to(self.device) - nan_y = Y.masked_fill(Y == 0, float("nan"))[:, : self.cfg.dataset.P + self.cfg.dataset.N] - mu = torch.nanmean(nan_y, dim=1) if self.cfg.model.rda else None - sigma = nanstd(nan_y) if self.cfg.model.rda else None - sampled_rda = self.z_encoder(mu, sigma) - task_counts = None - elif self.z_dim_rd == 1: + task_counts = None + sampled_rda = None + if self.z_dim_rd == 1: Y = batch[2].to(self.device) nan_y = Y.masked_fill(Y == 0, float("nan"))[:, : self.cfg.dataset.P + self.cfg.dataset.N] task_counts = torch.nanmean(nan_y, dim=1) if self.cfg.model.rda else None sampled_rda = None - else: - task_counts = None - sampled_rda = None + 
ds_emb = None if self.dataset_token is not None: ds_emb = self.dataset_embedder(ds_emb) - else: - ds_emb = None emb_batches.append(emb.detach().cpu().numpy()) ds_emb_batches.append(ds_emb.detach().cpu().numpy()) - merged_embs = LitUCEModel.resize_batch(emb, gene_embeds, task_counts, sampled_rda, ds_emb) + merged_embs = StateEmbeddingModel.resize_batch(emb, gene_embeds, task_counts, sampled_rda, ds_emb) logprobs_batch = self.binary_decoder(merged_embs) logprobs_batch = logprobs_batch.detach().cpu().numpy() logprob_batches.append(logprobs_batch.squeeze()) @@ -450,25 +395,16 @@ def forward(self, src: Tensor, mask: Tensor, counts=None, dataset_nums=None): src + count_emb ) # should both be B x H x self.d_model, or B x H + 1 x self.d_model if dataset correction - if self.training: - # random chance 10% to set mask to None - if self.cfg.model.get("variable_masking", False) and np.random.rand() < 0.1: - mask = None - else: - mask = mask.to(self.device) - output = self.transformer_encoder(src, src_key_padding_mask=mask) - else: - output = self.transformer_encoder(src, src_key_padding_mask=None) + output = self.transformer_encoder(src, src_key_padding_mask=None) gene_output = self.decoder(output) # batch x seq_len x 128 # In the new format, the cls token, which is at the 0 index mark, is the output. embedding = gene_output[:, 0, :] # select only the CLS token. embedding = nn.functional.normalize(embedding, dim=1) # Normalize. 
# we must be in train mode to use dataset correction + dataset_emb = None if self.dataset_token is not None: dataset_emb = gene_output[:, -1, :] - else: - dataset_emb = None return gene_output, embedding, dataset_emb @@ -478,16 +414,7 @@ def shared_step(self, batch, batch_idx): z = embs.unsqueeze(1).repeat(1, X.shape[1], 1) # CLS token - if self.z_dim_rd > 1: - # your code here that computes mu and std dev from Y - nan_y = Y.masked_fill(Y == 0, float("nan"))[:, : self.cfg.dataset.P + self.cfg.dataset.N] - mu = torch.nanmean(nan_y, dim=1) if self.cfg.model.rda else None - sigma = nanstd(nan_y) if self.cfg.model.rda else None - - reshaped_counts = self.z_encoder(mu, sigma).unsqueeze(1) - reshaped_counts = reshaped_counts.expand(X.shape[0], X.shape[1], reshaped_counts.shape[2]) - combine = torch.cat((X, z, reshaped_counts), dim=2) - elif self.z_dim_rd == 1: + if self.z_dim_rd == 1: mu = torch.nanmean(Y.masked_fill(Y == 0, float("nan")), dim=1) if self.cfg.model.rda else None reshaped_counts = mu.unsqueeze(1).unsqueeze(2) reshaped_counts = reshaped_counts.repeat(1, X.shape[1], 1) diff --git a/src/state_sets/state/train/trainer.py b/src/state_sets/state/train/trainer.py index bcd5f615..604a6cbb 100644 --- a/src/state_sets/state/train/trainer.py +++ b/src/state_sets/state/train/trainer.py @@ -11,7 +11,7 @@ from lightning.pytorch.strategies import DDPStrategy from zclip import ZClipLightningCallback -from ..nn.model import LitUCEModel +from ..nn.model import StateEmbeddingModel from ..data import H5adSentenceDataset, VCIDatasetSentenceCollator from ..train.callbacks import LogLR, ProfilerCallback, ResumeCallback, EMACallback, PerfProfilerCallback from ..utils import get_latest_checkpoint, get_embedding_cfg, get_dataset_cfg @@ -74,7 +74,7 @@ def main(cfg): generator=generator, ) - model = LitUCEModel( + model = StateEmbeddingModel( token_dim=get_embedding_cfg(cfg).size, d_model=cfg.model.emsize, nhead=cfg.model.nhead, @@ -141,7 +141,7 @@ def main(cfg): 
accumulate_grad_batches=cfg.optimizer.gradient_accumulation_steps, precision="bf16-mixed", strategy=DDPStrategy( - process_group_backend="nccl", # if cfg.experiment.num_nodes == 1 else "gloo", + process_group_backend="nccl", find_unused_parameters=False, timeout=timedelta(seconds=cfg.experiment.get("ddp_timeout", 3600)), ),