From 98dae89be908e3bd928120df2f7e85b83351cdf7 Mon Sep 17 00:00:00 2001 From: dhruvgautam Date: Wed, 11 Jun 2025 15:30:18 -0700 Subject: [PATCH 01/16] decoder fix started --- src/state_sets/_cli/_sets/_train.py | 17 +++++++ .../configs/model/tahoe_decoder_test.yaml | 44 ++++++++++++++++++ .../configs/model/tahoe_llama_93133848.yaml | 1 + src/state_sets/sets/models/base.py | 46 +++++++++++++++---- 4 files changed, 98 insertions(+), 10 deletions(-) create mode 100644 src/state_sets/configs/model/tahoe_decoder_test.yaml diff --git a/src/state_sets/_cli/_sets/_train.py b/src/state_sets/_cli/_sets/_train.py index 987b8228..9f8ecc50 100644 --- a/src/state_sets/_cli/_sets/_train.py +++ b/src/state_sets/_cli/_sets/_train.py @@ -115,6 +115,23 @@ def run_sets_train(cfg: DictConfig): data_module.save_state(f) data_module.setup(stage="fit") + + var_dims = data_module.get_var_dims() # {"gene_dim": …, "hvg_dim": …} + gene_dim = var_dims.get("gene_dim", 5000) # fallback if key missing + latent_dim = cfg["model"]["kwargs"]["output_dim"] # same as model.output_dim + # optional: let user override from CLI/YAML + hidden_dims = cfg["model"]["kwargs"].get("decoder_hidden_dims", [1024, 1024, 512]) + + decoder_cfg = dict( + latent_dim = latent_dim, + gene_dim = gene_dim, + hidden_dims = hidden_dims, + dropout = cfg["model"]["kwargs"].get("decoder_dropout", 0.1), + residual_decoder = cfg["model"]["kwargs"].get("residual_decoder", False), + ) + + # tuck it into the kwargs that will reach the LightningModule + cfg["model"]["kwargs"]["decoder_cfg"] = decoder_cfg if cfg["model"]["name"].lower() in ["cpa", "scvi"] or cfg["model"]["name"].lower().startswith("scgpt"): cfg["model"]["kwargs"]["n_cell_types"] = len(data_module.celltype_onehot_map) diff --git a/src/state_sets/configs/model/tahoe_decoder_test.yaml b/src/state_sets/configs/model/tahoe_decoder_test.yaml new file mode 100644 index 00000000..1a0c985d --- /dev/null +++ b/src/state_sets/configs/model/tahoe_decoder_test.yaml @@ -0,0 +1,44 @@ +name: PertSets +checkpoint: null +device: cuda + +kwargs: + cell_set_len: 512 + decoder_hidden_dims: [2048, 2048, 2048] + blur: 0.05 + hidden_dim: 1488 # hidden dimension going into the transformer backbone + loss: energy + confidence_head: False + n_encoder_layers: 4 + n_decoder_layers: 4 + predict_residual: True + softplus: True + freeze_pert: False + transformer_decoder: False + finetune_vci_decoder: False + residual_decoder: False + decoder_loss_weight: 1.0 + batch_encoder: False + nb_decoder: False + mask_attn: False + distributional_loss: energy + init_from: null + transformer_backbone_key: llama + transformer_backbone_kwargs: + max_position_embeddings: ${model.kwargs.cell_set_len} + hidden_size: ${model.kwargs.hidden_dim} + intermediate_size: 5952 + num_hidden_layers: 6 + num_attention_heads: 12 + num_key_value_heads: 12 + head_dim: 124 + use_cache: false + attention_dropout: 0.0 + hidden_dropout: 0.0 + layer_norm_eps: 1e-6 + pad_token_id: 0 + bos_token_id: 1 + eos_token_id: 2 + tie_word_embeddings: false + rotary_dim: 0 + use_rotary_embeddings: false diff --git a/src/state_sets/configs/model/tahoe_llama_93133848.yaml b/src/state_sets/configs/model/tahoe_llama_93133848.yaml index f5152789..17b183d7 100644 --- a/src/state_sets/configs/model/tahoe_llama_93133848.yaml +++ b/src/state_sets/configs/model/tahoe_llama_93133848.yaml @@ -4,6 +4,7 @@ device: cuda kwargs: cell_set_len: 512 + decoder_hidden_dims: [1024, 1024, 512] blur: 0.05 hidden_dim: 696 # hidden dimension going into the transformer backbone loss: energy diff 
--git a/src/state_sets/sets/models/base.py b/src/state_sets/sets/models/base.py index 6c837541..92f5c473 100644 --- a/src/state_sets/sets/models/base.py +++ b/src/state_sets/sets/models/base.py @@ -5,6 +5,7 @@ import torch import torch.nn as nn from lightning.pytorch import LightningModule +import typing as tp from .utils import get_loss_fn @@ -147,9 +148,11 @@ def __init__( batch_size: int = 64, gene_dim: int = 5000, hvg_dim: int = 2001, + decoder_cfg: dict | None = None, **kwargs, ): super().__init__() + self.decoder_cfg = decoder_cfg self.save_hyperparameters() # Core architecture settings @@ -196,16 +199,22 @@ def __init__( elif "PBS" in self.control_pert: hidden_dims = [2048, 1024, 1024] else: - hidden_dims = [1024, 1024, 512] # make this config - - self.gene_decoder = LatentToGeneDecoder( - latent_dim=self.output_dim, - gene_dim=gene_dim, - hidden_dims=hidden_dims, - dropout=dropout, - residual_decoder=self.residual_decoder, - ) - logger.info(f"Initialized gene decoder for embedding {embed_key} to gene space") + if "DMSO_TF" in self.control_pert: + if self.residual_decoder: + hidden_dims = [2058, 2058, 2058, 2058, 2058] + else: + hidden_dims = [4096, 2048, 2048] + else: + hidden_dims = [1024, 1024, 512] # make this config + + self.gene_decoder = LatentToGeneDecoder( + latent_dim=self.output_dim, + gene_dim=gene_dim, + hidden_dims=hidden_dims, + dropout=dropout, + residual_decoder=self.residual_decoder, + ) + logger.info(f"Initialized gene decoder for embedding {embed_key} to gene space") def transfer_batch_to_device(self, batch, device, dataloader_idx: int): return {k: (v.to(device) if isinstance(v, torch.Tensor) else v) for k, v in batch.items()} @@ -215,6 +224,23 @@ def _build_networks(self): """Build the core neural network components.""" pass + def _build_decoder(self): + """Create self.gene_decoder from self.decoder_cfg (or leave None).""" + if self.decoder_cfg is None: + self.gene_decoder = None + return + self.gene_decoder = LatentToGeneDecoder(**self.decoder_cfg) + + def on_load_checkpoint(self, checkpoint: dict[str, tp.Any]) -> None: + """ + Lightning calls this *before* the checkpoint's state_dict is loaded. + Re-create the decoder using the exact hyper-parameters saved in the ckpt, + so that parameter shapes match and load_state_dict succeeds. 
+ """ + self.decoder_cfg = checkpoint["hyper_parameters"]["decoder_cfg"] + self._build_decoder() + + def training_step(self, batch: Dict[str, torch.Tensor], batch_idx: int) -> torch.Tensor: """Training step logic for both main model and decoder.""" # Get model predictions (in latent space) From 2890319af54cd773f148d4ebb0e5dc68106a03ee Mon Sep 17 00:00:00 2001 From: dhruvgautam Date: Fri, 13 Jun 2025 13:38:33 -0700 Subject: [PATCH 02/16] base infer not working --- src/state_sets/_cli/_sets/_infer.py | 94 +++++++++++++++++++++++++++++ src/state_sets/_cli/_sets/_train.py | 4 +- 2 files changed, 97 insertions(+), 1 deletion(-) create mode 100644 src/state_sets/_cli/_sets/_infer.py diff --git a/src/state_sets/_cli/_sets/_infer.py b/src/state_sets/_cli/_sets/_infer.py new file mode 100644 index 00000000..e8a71058 --- /dev/null +++ b/src/state_sets/_cli/_sets/_infer.py @@ -0,0 +1,94 @@ +import argparse +import scanpy as sc +import torch +import numpy as np +import os +import pandas as pd + +# Adjust this import to your project structure +from ...sets.models.pert_sets import PertSetsPerturbationModel + +def add_arguments_infer(parser: argparse.ArgumentParser): + parser.add_argument("--checkpoint", type=str, required=True, help="Path to model checkpoint (.ckpt)") + parser.add_argument("--adata", type=str, required=True, help="Path to input AnnData file (.h5ad)") + parser.add_argument("--embed_key", type=str, default="X_hvg", help="Key in adata.obsm for input features") + parser.add_argument("--pert_col", type=str, default="drugname_drugconc", help="Column in adata.obs for perturbation labels") + parser.add_argument("--output", type=str, default=None, help="Path to output AnnData file (.h5ad)") + + +def run_sets_infer(args): + import logging + logging.basicConfig(level=logging.INFO) + logger = logging.getLogger(__name__) + + # Load model + logger.info(f"Loading model from checkpoint: {args.checkpoint}") + model = PertSetsPerturbationModel.load_from_checkpoint(args.checkpoint) + model.eval() + device = next(model.parameters()).device + + # Use model's config for batch prep + pert_onehot_map = getattr(model, "pert_onehot_map", None) + pert_dim = model.pert_dim + + # Load AnnData + logger.info(f"Loading AnnData from: {args.adata}") + adata = sc.read_h5ad(args.adata) + + # Get input features + if args.embed_key in adata.obsm: + X = adata.obsm[args.embed_key] + logger.info(f"Using adata.obsm['{args.embed_key}'] as input features: shape {X.shape}") + else: + X = adata.X + logger.info(f"Using adata.X as input features: shape {X.shape}") + X = torch.tensor(X, dtype=torch.float32).to(device) + + # Prepare perturbation tensor using the model's map + pert_names = adata.obs[args.pert_col].values + pert_tensor = torch.zeros((len(pert_names), pert_dim), device=device) + if pert_onehot_map is not None: + for idx, name in enumerate(pert_names): + if name in pert_onehot_map: + pert_tensor[idx, pert_onehot_map[name]] = 1 + else: + # Optionally handle unknown perturbations + pass + else: + # Fallback: build map from AnnData (not recommended for production) + unique_perts = sorted(set(pert_names)) + pert_map = {name: i for i, name in enumerate(unique_perts)} + for idx, name in enumerate(pert_names): + pert_tensor[idx, pert_map[name]] = 1 + + # Prepare batch + batch = { + "ctrl_cell_emb": X, + "pert_emb": pert_tensor, + "pert_name": pert_names.tolist(), + # "batch": torch.zeros((1, cell_sentence_len), device=device) + } + # when do we need the batch num things + + # Inference + logger.info("Running inference...") + with 
torch.no_grad(): + preds = model.forward(batch) + preds_np = preds.cpu().numpy() + + # Save predictions to AnnData + pred_key = "model_preds" + adata.obsm[pred_key] = preds_np + output_path = args.output or args.adata.replace(".h5ad", "_with_preds.h5ad") + adata.write_h5ad(output_path) + logger.info(f"Saved predictions to {output_path} (in adata.obsm['{pred_key}'])") + + +def main(): + parser = argparse.ArgumentParser(description="Run inference on AnnData with a trained model checkpoint.") + add_arguments_infer(parser) + args = parser.parse_args() + run_sets_infer(args) + +if __name__ == "__main__": + main() diff --git a/src/state_sets/_cli/_sets/_train.py b/src/state_sets/_cli/_sets/_train.py index 9f8ecc50..586750cd 100644 --- a/src/state_sets/_cli/_sets/_train.py +++ b/src/state_sets/_cli/_sets/_train.py @@ -119,7 +119,6 @@ def run_sets_train(cfg: DictConfig): var_dims = data_module.get_var_dims() # {"gene_dim": …, "hvg_dim": …} gene_dim = var_dims.get("gene_dim", 5000) # fallback if key missing latent_dim = cfg["model"]["kwargs"]["output_dim"] # same as model.output_dim - # optional: let user override from CLI/YAML hidden_dims = cfg["model"]["kwargs"].get("decoder_hidden_dims", [1024, 1024, 512]) decoder_cfg = dict( @@ -133,6 +132,9 @@ def run_sets_train(cfg: DictConfig): # tuck it into the kwargs that will reach the LightningModule cfg["model"]["kwargs"]["decoder_cfg"] = decoder_cfg + cfg["data"]["kwargs"]["n_perts"] = len(data_module.pert_onehot_map) + cfg["model"]["kwargs"]["pert_onehot_map"] = data_module.pert_onehot_map + if cfg["model"]["name"].lower() in ["cpa", "scvi"] or cfg["model"]["name"].lower().startswith("scgpt"): cfg["model"]["kwargs"]["n_cell_types"] = len(data_module.celltype_onehot_map) cfg["model"]["kwargs"]["n_perts"] = len(data_module.pert_onehot_map) From 127bb1f0f06fc62394890c896c7d5855d304f8f0 Mon Sep 17 00:00:00 2001 From: dhruvgautam Date: Mon, 16 Jun 2025 10:22:35 -0700 Subject: [PATCH 03/16] src --- src/state_sets/_cli/_sets/_infer.py | 20 +++++++++++++++++--- 1 file changed, 17 insertions(+), 3 deletions(-) diff --git a/src/state_sets/_cli/_sets/_infer.py b/src/state_sets/_cli/_sets/_infer.py index e8a71058..2617151d 100644 --- a/src/state_sets/_cli/_sets/_infer.py +++ b/src/state_sets/_cli/_sets/_infer.py @@ -5,7 +5,6 @@ import os import pandas as pd -# Adjust this import to your project structure from ...sets.models.pert_sets import PertSetsPerturbationModel def add_arguments_infer(parser: argparse.ArgumentParser): @@ -14,6 +13,8 @@ def add_arguments_infer(parser: argparse.ArgumentParser): parser.add_argument("--embed_key", type=str, default="X_hvg", help="Key in adata.obsm for input features") parser.add_argument("--pert_col", type=str, default="drugname_drugconc", help="Column in adata.obs for perturbation labels") parser.add_argument("--output", type=str, default=None, help="Path to output AnnData file (.h5ad)") + parser.add_argument("--celltype_col", type=str, default=None, help="Column in adata.obs for cell type labels (optional)") + parser.add_argument("--celltypes", type=str, default=None, help="Comma-separated list of cell types to include (optional)") def run_sets_infer(args): @@ -25,6 +26,7 @@ def run_sets_infer(args): logger.info(f"Loading model from checkpoint: {args.checkpoint}") model = PertSetsPerturbationModel.load_from_checkpoint(args.checkpoint) model.eval() + cell_sentence_len = model.cell_sentence_len device = next(model.parameters()).device # Use model's config for batch prep @@ -35,6 +37,19 @@ def run_sets_infer(args): 
logger.info(f"Loading AnnData from: {args.adata}") adata = sc.read_h5ad(args.adata) + # Optionally filter by cell type + if args.celltype_col is not None and args.celltypes is not None: + celltypes = [ct.strip() for ct in args.celltypes.split(",")] + if args.celltype_col not in adata.obs: + raise ValueError(f"Column '{args.celltype_col}' not found in adata.obs.") + initial_n = adata.n_obs + adata = adata[adata.obs[args.celltype_col].isin(celltypes)].copy() + logger.info(f"Filtered AnnData to {adata.n_obs} cells of types {celltypes} (from {initial_n} cells)") + elif args.celltype_col is not None: + if args.celltype_col not in adata.obs: + raise ValueError(f"Column '{args.celltype_col}' not found in adata.obs.") + logger.info(f"No cell type filtering applied, but cell type column '{args.celltype_col}' is available.") + # Get input features if args.embed_key in adata.obsm: X = adata.obsm[args.embed_key] @@ -66,11 +81,10 @@ def run_sets_infer(args): "ctrl_cell_emb": X, "pert_emb": pert_tensor, "pert_name": pert_names.tolist(), - # "batch": torch.zeros((1, cell_sentence_len), device=device) + "batch": torch.zeros((1, cell_sentence_len), device=device) } # when do we need the batch num things - # Inference logger.info("Running inference...") with torch.no_grad(): preds = model.forward(batch) From 8d33569e072043276e1fd516fc476b44563163f9 Mon Sep 17 00:00:00 2001 From: dhruvgautam Date: Thu, 19 Jun 2025 22:07:48 -0700 Subject: [PATCH 04/16] infer working --- src/state_sets/__main__.py | 4 + src/state_sets/_cli/__init__.py | 3 +- src/state_sets/_cli/_sets/__init__.py | 4 +- src/state_sets/_cli/_sets/_infer.py | 180 +++++++++++++++++++----- src/state_sets/_cli/_sets/_infer2.py | 150 ++++++++++++++++++++ src/state_sets/sets/models/base.py | 36 ++++- src/state_sets/sets/models/pert_sets.py | 2 + 7 files changed, 344 insertions(+), 35 deletions(-) create mode 100644 src/state_sets/_cli/_sets/_infer2.py diff --git a/src/state_sets/__main__.py b/src/state_sets/__main__.py index ea4a7cbb..f0d27803 100644 --- a/src/state_sets/__main__.py +++ b/src/state_sets/__main__.py @@ -6,6 +6,7 @@ add_arguments_state, run_sets_predict, run_sets_train, + run_sets_infer, run_state_embed, run_state_train, ) @@ -60,6 +61,9 @@ def main(): case "predict": # For now, predict uses argparse and not hydra run_sets_predict(args) + case "infer": + # Run inference using argparse, similar to predict + run_sets_infer(args) if __name__ == "__main__": diff --git a/src/state_sets/_cli/__init__.py b/src/state_sets/_cli/__init__.py index 947c4960..5e057c70 100644 --- a/src/state_sets/_cli/__init__.py +++ b/src/state_sets/_cli/__init__.py @@ -1,4 +1,4 @@ -from ._sets import add_arguments_sets, run_sets_predict, run_sets_train +from ._sets import add_arguments_sets, run_sets_predict, run_sets_train, run_sets_infer from ._state import add_arguments_state, run_state_embed, run_state_train __all__ = [ @@ -7,5 +7,6 @@ "run_sets_train", "run_state_embed", "run_sets_predict", + "run_sets_infer", "run_state_train", ] diff --git a/src/state_sets/_cli/_sets/__init__.py b/src/state_sets/_cli/_sets/__init__.py index aede1b93..d64471bb 100644 --- a/src/state_sets/_cli/_sets/__init__.py +++ b/src/state_sets/_cli/_sets/__init__.py @@ -2,8 +2,9 @@ from ._predict import add_arguments_predict, run_sets_predict from ._train import add_arguments_train, run_sets_train +from ._infer import add_arguments_infer, run_sets_infer -__all__ = ["run_sets_train", "run_sets_predict", "add_arguments_sets"] +__all__ = ["run_sets_train", "run_sets_predict", 
"run_sets_infer", "add_arguments_sets"] def add_arguments_sets(parser: ap.ArgumentParser): @@ -11,3 +12,4 @@ def add_arguments_sets(parser: ap.ArgumentParser): subparsers = parser.add_subparsers(required=True, dest="subcommand") add_arguments_train(subparsers.add_parser("train")) add_arguments_predict(subparsers.add_parser("predict")) + add_arguments_infer(subparsers.add_parser("infer")) diff --git a/src/state_sets/_cli/_sets/_infer.py b/src/state_sets/_cli/_sets/_infer.py index 2617151d..60406c84 100644 --- a/src/state_sets/_cli/_sets/_infer.py +++ b/src/state_sets/_cli/_sets/_infer.py @@ -4,8 +4,16 @@ import numpy as np import os import pandas as pd +from tqdm import tqdm +import yaml from ...sets.models.pert_sets import PertSetsPerturbationModel +from cell_load.data_modules import PerturbationDataModule + +# state-sets sets infer --output_dir /home/aadduri/state-sets/test/ --checkpoint last.ckpt --adata /home/aadduri/state-sets/test/adata.h5ad --pert_col gene +# state-sets sets infer --output_dir /home/dhruvgautam/state-sets/test/ --checkpoint /large_storage/ctc/userspace/aadduri/preprint/replogle_llama_21712320_filtered_cs32_pretrained/hepg2/checkpoints/step=44000.ckpt --adata /large_storage/ctc/ML/state_sets/replogle/processed.h5 --pert_col gene + +# state-sets sets infer --output /home/dhruvgautam/state-sets/test/ --output_dir /large_storage/ctc/userspace/aadduri/preprint/replogle_state_proper_cs32_sm/hepg2 --checkpoint /large_storage/ctc/userspace/aadduri/preprint/replogle_state_proper_cs32_sm/hepg2/checkpoints/step=48000.ckpt --adata /large_storage/ctc/ML/state_sets/replogle/processed.h5 --pert_col gene --embed_key X_vci_1.5.2_4 def add_arguments_infer(parser: argparse.ArgumentParser): parser.add_argument("--checkpoint", type=str, required=True, help="Path to model checkpoint (.ckpt)") @@ -13,8 +21,10 @@ def add_arguments_infer(parser: argparse.ArgumentParser): parser.add_argument("--embed_key", type=str, default="X_hvg", help="Key in adata.obsm for input features") parser.add_argument("--pert_col", type=str, default="drugname_drugconc", help="Column in adata.obs for perturbation labels") parser.add_argument("--output", type=str, default=None, help="Path to output AnnData file (.h5ad)") + parser.add_argument("--output_dir", type=str, required=True, help="Path to the output_dir containing the config.yaml file that was saved during training.") parser.add_argument("--celltype_col", type=str, default=None, help="Column in adata.obs for cell type labels (optional)") parser.add_argument("--celltypes", type=str, default=None, help="Comma-separated list of cell types to include (optional)") + parser.add_argument("--batch_size", type=int, default=1000, help="Batch size for inference (default: 1000)") def run_sets_infer(args): @@ -22,6 +32,32 @@ def run_sets_infer(args): logging.basicConfig(level=logging.INFO) logger = logging.getLogger(__name__) + def load_config(cfg_path: str) -> dict: + """Load config from the YAML file that was dumped during training.""" + if not os.path.exists(cfg_path): + raise FileNotFoundError(f"Could not find config file: {cfg_path}") + with open(cfg_path, "r") as f: + cfg = yaml.safe_load(f) + return cfg + + # Load the config + config_path = os.path.join(args.output_dir, "config.yaml") + cfg = load_config(config_path) + logger.info(f"Loaded config from {config_path}") + + # Find run output directory & load data module + run_output_dir = os.path.join(cfg["output_dir"], cfg["name"]) + data_module_path = os.path.join(run_output_dir, "data_module.torch") + if not 
os.path.exists(data_module_path): + raise FileNotFoundError(f"Could not find data module at {data_module_path}") + data_module = PerturbationDataModule.load_state(data_module_path) + data_module.setup(stage="test") + logger.info(f"Loaded data module from {data_module_path}") + + # Get perturbation dimensions and mapping from data module + var_dims = data_module.get_var_dims() + pert_dim = var_dims["pert_dim"] + # Load model logger.info(f"Loading model from checkpoint: {args.checkpoint}") model = PertSetsPerturbationModel.load_from_checkpoint(args.checkpoint) @@ -29,10 +65,6 @@ def run_sets_infer(args): cell_sentence_len = model.cell_sentence_len device = next(model.parameters()).device - # Use model's config for batch prep - pert_onehot_map = getattr(model, "pert_onehot_map", None) - pert_dim = model.pert_dim - # Load AnnData logger.info(f"Loading AnnData from: {args.adata}") adata = sc.read_h5ad(args.adata) @@ -57,38 +89,122 @@ def run_sets_infer(args): else: X = adata.X logger.info(f"Using adata.X as input features: shape {X.shape}") - X = torch.tensor(X, dtype=torch.float32).to(device) - # Prepare perturbation tensor using the model's map + # Prepare perturbation tensor using the data module's mapping pert_names = adata.obs[args.pert_col].values - pert_tensor = torch.zeros((len(pert_names), pert_dim), device=device) - if pert_onehot_map is not None: - for idx, name in enumerate(pert_names): - if name in pert_onehot_map: - pert_tensor[idx, pert_onehot_map[name]] = 1 + pert_tensor = torch.zeros((len(pert_names), pert_dim), device='cpu') # Keep on CPU initially + logger.info(f"Perturbation tensor shape: {pert_tensor.shape}") + + # Use data module's perturbation mapping + pert_onehot_map = data_module.pert_onehot_map + + # Debug: check what's available + logger.info(f"Data module has {len(pert_onehot_map)} perturbations in mapping") + logger.info(f"First 10 perturbations in data module: {list(pert_onehot_map.keys())[:10]}") + + unique_pert_names = sorted(set(pert_names)) + logger.info(f"AnnData has {len(unique_pert_names)} unique perturbations") + logger.info(f"First 10 perturbations in AnnData: {unique_pert_names[:10]}") + + # Check overlap + overlap = set(unique_pert_names) & set(pert_onehot_map.keys()) + logger.info(f"Overlap between AnnData and data module: {len(overlap)} perturbations") + if len(overlap) < len(unique_pert_names): + missing = set(unique_pert_names) - set(pert_onehot_map.keys()) + logger.warning(f"Missing perturbations: {list(missing)[:10]}") + + # Check if there's a control perturbation that might match + control_pert = data_module.get_control_pert() + logger.info(f"Control perturbation in data module: '{control_pert}'") + + matched_count = 0 + for idx, name in enumerate(pert_names): + if name in pert_onehot_map: + pert_tensor[idx] = pert_onehot_map[name] + matched_count += 1 + else: + # For now, use control perturbation as fallback + if control_pert in pert_onehot_map: + pert_tensor[idx] = pert_onehot_map[control_pert] else: - # Optionally handle unknown perturbations - pass - else: - # Fallback: build map from AnnData (not recommended for production) - unique_perts = sorted(set(pert_names)) - pert_map = {name: i for i, name in enumerate(unique_perts)} - for idx, name in enumerate(pert_names): - pert_tensor[idx, pert_map[name]] = 1 - - # Prepare batch - batch = { - "ctrl_cell_emb": X, - "pert_emb": pert_tensor, - "pert_name": pert_names.tolist(), - "batch": torch.zeros((1, cell_sentence_len), device=device) - } - # when do we need the batch num things - - 
logger.info("Running inference...") + # Use first available perturbation as fallback + first_pert = list(pert_onehot_map.keys())[0] + pert_tensor[idx] = pert_onehot_map[first_pert] + + logger.info(f"Matched {matched_count} out of {len(pert_names)} perturbations") + + # Process in batches with progress bar + # Use cell_sentence_len as batch size since model expects this + n_samples = len(pert_names) + batch_size = cell_sentence_len # Model requires this exact batch size + n_batches = (n_samples + batch_size - 1) // batch_size # Ceiling division + + logger.info(f"Running inference on {n_samples} samples in {n_batches} batches of size {batch_size} (model's cell_sentence_len)...") + + all_preds = [] + with torch.no_grad(): - preds = model.forward(batch) - preds_np = preds.cpu().numpy() + progress_bar = tqdm(total=n_samples, desc="Processing samples", unit="samples") + + for batch_idx in range(n_batches): + start_idx = batch_idx * batch_size + end_idx = min(start_idx + batch_size, n_samples) + current_batch_size = end_idx - start_idx + + # Get batch data + X_batch = torch.tensor(X[start_idx:end_idx], dtype=torch.float32).to(device) + pert_batch = pert_tensor[start_idx:end_idx].to(device) + pert_names_batch = pert_names[start_idx:end_idx].tolist() + + # Pad the batch to cell_sentence_len if it's the last incomplete batch + if current_batch_size < cell_sentence_len: + # Pad with zeros for embeddings + padding_size = cell_sentence_len - current_batch_size + X_pad = torch.zeros((padding_size, X_batch.shape[1]), device=device) + X_batch = torch.cat([X_batch, X_pad], dim=0) + + # Pad perturbation tensor with control perturbation + pert_pad = torch.zeros((padding_size, pert_batch.shape[1]), device=device) + if control_pert in pert_onehot_map: + pert_pad[:] = pert_onehot_map[control_pert].to(device) + else: + pert_pad[:, 0] = 1 # Default to first perturbation + pert_batch = torch.cat([pert_batch, pert_pad], dim=0) + + # Extend perturbation names + pert_names_batch.extend([control_pert] * padding_size) + + # Prepare batch - use same format as working code + batch = { + "ctrl_cell_emb": X_batch, + "pert_emb": pert_batch, # Keep as 2D tensor + "pert_name": pert_names_batch, + "batch": torch.zeros((1, cell_sentence_len), device=device) # Use (1, cell_sentence_len) + } + + # Run inference on batch using padded=False like in working code + batch_preds = model.predict_step(batch, batch_idx=batch_idx, padded=False) + + # Extract predictions from the dictionary returned by predict_step + # Use gene decoder output if available, otherwise use latent predictions + if "pert_cell_counts_preds" in batch_preds and batch_preds["pert_cell_counts_preds"] is not None: + # Use gene space predictions (from decoder) + pred_tensor = batch_preds["pert_cell_counts_preds"] + else: + # Use latent space predictions + pred_tensor = batch_preds["preds"] + + # Only keep predictions for the actual samples (not padding) + actual_preds = pred_tensor[:current_batch_size] + all_preds.append(actual_preds.cpu().numpy()) + + # Update progress bar + progress_bar.update(current_batch_size) + + progress_bar.close() + + # Concatenate all predictions + preds_np = np.concatenate(all_preds, axis=0) # Save predictions to AnnData pred_key = "model_preds" diff --git a/src/state_sets/_cli/_sets/_infer2.py b/src/state_sets/_cli/_sets/_infer2.py new file mode 100644 index 00000000..7522f657 --- /dev/null +++ b/src/state_sets/_cli/_sets/_infer2.py @@ -0,0 +1,150 @@ +import argparse +import scanpy as sc +import torch +import numpy as np +import os +import 
pandas as pd +from tqdm import tqdm + +from ...sets.models.pert_sets import PertSetsPerturbationModel + +# state-sets sets infer --output_dir /home/aadduri/state-sets/test/ --checkpoint last.ckpt --adata /home/aadduri/state-sets/test/adata.h5ad --pert_col gene +# state-sets sets infer --output_dir /home/dhruvgautam/state-sets/test/ --checkpoint /large_storage/ctc/userspace/aadduri/preprint/replogle_llama_21712320_filtered_cs32_pretrained/hepg2/checkpoints/step=44000.ckpt --adata /large_storage/ctc/ML/state_sets/replogle/processed.h5 --pert_col gene + +# state-sets sets infer --output_dir /home/dhruvgautam/state-sets/test/ --checkpoint /large_storage/ctc/userspace/aadduri/preprint/replogle_llama_21712320_filtered_cs32_pretrained/hepg2/checkpoints/step=44000.ckpt --adata /large_storage/ctc/ML/state_sets/replogle/processed.h5 --pert_col gene --embed_key X_vci_1.5.2_4 + +def add_arguments_infer(parser: argparse.ArgumentParser): + parser.add_argument("--checkpoint", type=str, required=True, help="Path to model checkpoint (.ckpt)") + parser.add_argument("--adata", type=str, required=True, help="Path to input AnnData file (.h5ad)") + parser.add_argument("--embed_key", type=str, default="X_hvg", help="Key in adata.obsm for input features") + parser.add_argument("--pert_col", type=str, default="drugname_drugconc", help="Column in adata.obs for perturbation labels") + parser.add_argument("--output_dir", type=str, default=None, help="Path to output AnnData file (.h5ad)") + parser.add_argument("--celltype_col", type=str, default=None, help="Column in adata.obs for cell type labels (optional)") + parser.add_argument("--celltypes", type=str, default=None, help="Comma-separated list of cell types to include (optional)") + parser.add_argument("--batch_size", type=int, default=1000, help="Batch size for inference (default: 1000)") + + +def run_sets_infer(args): + import logging + logging.basicConfig(level=logging.INFO) + logger = logging.getLogger(__name__) + + # Load model + logger.info(f"Loading model from checkpoint: {args.checkpoint}") + model = PertSetsPerturbationModel.load_from_checkpoint(args.checkpoint) + model.eval() + cell_sentence_len = model.cell_sentence_len + device = next(model.parameters()).device + pert_dim = model.pert_dim + + logger.info(f"Using model's cell_sentence_len: {cell_sentence_len}") + logger.info(f"Using pert_dim: {pert_dim}") + + # Load AnnData + logger.info(f"Loading AnnData from: {args.adata}") + adata_full = sc.read_h5ad(args.adata) + + # Define control perturbations to look for + control_perts = ["DMSO_TF_24h", "non-targeting", "control", "DMSO"] + + # Find available control perturbation + available_perts = set(adata_full.obs[args.pert_col].unique()) + control_pert = None + for ctrl in control_perts: + if ctrl in available_perts: + control_pert = ctrl + logger.info(f"Using '{ctrl}' as control perturbation") + break + + if control_pert is None: + # Use the first available perturbation as fallback + control_pert = list(available_perts)[0] + logger.warning(f"No standard control found, using '{control_pert}' as control") + + # Get available cell types and select the most abundant one + if args.celltype_col is not None: + if args.celltype_col not in adata_full.obs: + raise ValueError(f"Column '{args.celltype_col}' not found in adata.obs.") + + if args.celltypes is not None: + celltypes = [ct.strip() for ct in args.celltypes.split(",")] + adata_full = adata_full[adata_full.obs[args.celltype_col].isin(celltypes)].copy() + logger.info(f"Filtered to specified cell types: 
{celltypes}") + + cell_type_counts = adata_full.obs[args.celltype_col].value_counts() + logger.info("Available cell types: %s", list(cell_type_counts.index)) + celltype1 = cell_type_counts.index[0] + logger.info(f"Selected cell type: {celltype1} ({cell_type_counts[celltype1]} available)") + + # Get control cells for this cell type + cells_type1 = adata_full[(adata_full.obs[args.pert_col] == control_pert) & + (adata_full.obs[args.celltype_col] == celltype1)].copy() + logger.info(f"Available control cells - {celltype1}: {cells_type1.n_obs}") + else: + # No cell type filtering, use all control cells + cells_type1 = adata_full[adata_full.obs[args.pert_col] == control_pert].copy() + logger.info(f"Available control cells: {cells_type1.n_obs}") + + # Use the model's actual cell_sentence_len + n_cells = cell_sentence_len + + if cells_type1.n_obs >= n_cells: + # Sample cells + idx1 = np.random.choice(cells_type1.n_obs, size=n_cells, replace=False) + sampled_cells = cells_type1[idx1].copy() + + logger.info(f"Sampled {sampled_cells.n_obs} cells for inference") + + # Extract embeddings based on available key + if args.embed_key in sampled_cells.obsm: + X_embed = torch.tensor(sampled_cells.obsm[args.embed_key], dtype=torch.float32).to(device) + logger.info(f"Using adata.obsm['{args.embed_key}'] as input features: shape {X_embed.shape}") + else: + X_data = sampled_cells.X.toarray() if hasattr(sampled_cells.X, 'toarray') else sampled_cells.X + X_embed = torch.tensor(X_data, dtype=torch.float32).to(device) + logger.info(f"Using adata.X as input features: shape {X_embed.shape}") + + # Create simple perturbation tensor - set first dimension to 1 for control + pert_tensor = torch.zeros((n_cells, pert_dim), device=device) + pert_tensor[:, 0] = 1 # Set first dimension to 1 for control perturbation + pert_names = [control_pert] * n_cells + + # Create batch dictionary + batch = { + "ctrl_cell_emb": X_embed, + "pert_emb": pert_tensor, + "pert_name": pert_names, + "batch": torch.zeros((1, cell_sentence_len), device=device) + } + + logger.info(f"Batch shapes - ctrl_cell_emb: {batch['ctrl_cell_emb'].shape}, pert_emb: {batch['pert_emb'].shape}") + logger.info(f"Running single forward pass with {n_cells} cells") + + # Single forward pass + with torch.no_grad(): + preds = model.forward(batch, padded=False) + + logger.info("Forward pass completed successfully") + preds_np = preds.cpu().numpy() + + # Save predictions to sampled cells + pred_key = "model_preds" + sampled_cells.obsm[pred_key] = preds_np + + else: + raise ValueError(f"Not enough control cells available. 
Need {n_cells}, but only have {cells_type1.n_obs}") + + # Save results + output_path = args.output_dir or args.adata.replace(".h5ad", "_with_preds.h5ad").replace(".h5", "_with_preds.h5ad") + sampled_cells.write_h5ad(output_path) + logger.info(f"Saved predictions to {output_path} (in adata.obsm['{pred_key}'])") + + +def main(): + parser = argparse.ArgumentParser(description="Run inference on AnnData with a trained model checkpoint.") + add_arguments_infer(parser) + args = parser.parse_args() + run_sets_infer(args) + +if __name__ == "__main__": + main() diff --git a/src/state_sets/sets/models/base.py b/src/state_sets/sets/models/base.py index 92f5c473..8ae32f27 100644 --- a/src/state_sets/sets/models/base.py +++ b/src/state_sets/sets/models/base.py @@ -161,6 +161,8 @@ def __init__( self.output_dim = output_dim self.pert_dim = pert_dim self.batch_dim = batch_dim + self.gene_dim = gene_dim + self.hvg_dim = hvg_dim if kwargs.get("batch_encoder", False): self.batch_dim = batch_dim @@ -237,8 +239,40 @@ def on_load_checkpoint(self, checkpoint: dict[str, tp.Any]) -> None: Re-create the decoder using the exact hyper-parameters saved in the ckpt, so that parameter shapes match and load_state_dict succeeds. """ - self.decoder_cfg = checkpoint["hyper_parameters"]["decoder_cfg"] + if "decoder_cfg" in checkpoint["hyper_parameters"]: + self.decoder_cfg = checkpoint["hyper_parameters"]["decoder_cfg"] + else: + self.decoder_cfg = None self._build_decoder() + logger.info(f"DEBUG: output_space: {self.output_space}") + if self.gene_decoder is None: + gene_dim = self.hvg_dim if self.output_space == "gene" else self.gene_dim + logger.info(f"DEBUG: gene_dim: {gene_dim}") + if (self.embed_key and self.embed_key != "X_hvg" and self.output_space == "gene") or ( + self.embed_key and self.output_space == "all" + ): # we should be able to decode from hvg to all + logger.info(f"DEBUG: Creating gene_decoder, checking conditions...") + if gene_dim > 10000: + hidden_dims = [1024, 512, 256] + elif self.embed_key in ["X_vci_1.5.2", "X_vci_1.5.2_4"]: + hidden_dims = [1024, 1024, 512] + else: + if "DMSO_TF" in self.control_pert: + if self.residual_decoder: + hidden_dims = [2058, 2058, 2058, 2058, 2058] + else: + hidden_dims = [4096, 2048, 2048] + else: + hidden_dims = [1024, 1024, 512] # make this config + + self.gene_decoder = LatentToGeneDecoder( + latent_dim=self.output_dim, + gene_dim=gene_dim, + hidden_dims=hidden_dims, + dropout=self.dropout, + residual_decoder=self.residual_decoder, + ) + logger.info(f"Initialized gene decoder for embedding {self.embed_key} to gene space") def training_step(self, batch: Dict[str, torch.Tensor], batch_idx: int) -> torch.Tensor: diff --git a/src/state_sets/sets/models/pert_sets.py b/src/state_sets/sets/models/pert_sets.py index a317b484..7c2c06d0 100644 --- a/src/state_sets/sets/models/pert_sets.py +++ b/src/state_sets/sets/models/pert_sets.py @@ -358,6 +358,8 @@ def forward(self, batch: dict, padded=True) -> torch.Tensor: # apply relu if specified and we output to HVG space is_gene_space = self.hparams["embed_key"] == "X_hvg" or self.hparams["embed_key"] is None + # logger.info(f"DEBUG: is_gene_space: {is_gene_space}") + # logger.info(f"DEBUG: self.gene_decoder: {self.gene_decoder}") if is_gene_space or self.gene_decoder is None: out_pred = self.relu(out_pred) From 49d5b2d5a39a67826392b1112abb144ac7e79d7e Mon Sep 17 00:00:00 2001 From: dhruvgautam Date: Thu, 19 Jun 2025 22:08:27 -0700 Subject: [PATCH 05/16] remove old infer --- src/state_sets/_cli/_sets/_infer2.py | 150 
--------------------------- 1 file changed, 150 deletions(-) delete mode 100644 src/state_sets/_cli/_sets/_infer2.py diff --git a/src/state_sets/_cli/_sets/_infer2.py b/src/state_sets/_cli/_sets/_infer2.py deleted file mode 100644 index 7522f657..00000000 --- a/src/state_sets/_cli/_sets/_infer2.py +++ /dev/null @@ -1,150 +0,0 @@ -import argparse -import scanpy as sc -import torch -import numpy as np -import os -import pandas as pd -from tqdm import tqdm - -from ...sets.models.pert_sets import PertSetsPerturbationModel - -# state-sets sets infer --output_dir /home/aadduri/state-sets/test/ --checkpoint last.ckpt --adata /home/aadduri/state-sets/test/adata.h5ad --pert_col gene -# state-sets sets infer --output_dir /home/dhruvgautam/state-sets/test/ --checkpoint /large_storage/ctc/userspace/aadduri/preprint/replogle_llama_21712320_filtered_cs32_pretrained/hepg2/checkpoints/step=44000.ckpt --adata /large_storage/ctc/ML/state_sets/replogle/processed.h5 --pert_col gene - -# state-sets sets infer --output_dir /home/dhruvgautam/state-sets/test/ --checkpoint /large_storage/ctc/userspace/aadduri/preprint/replogle_llama_21712320_filtered_cs32_pretrained/hepg2/checkpoints/step=44000.ckpt --adata /large_storage/ctc/ML/state_sets/replogle/processed.h5 --pert_col gene --embed_key X_vci_1.5.2_4 - -def add_arguments_infer(parser: argparse.ArgumentParser): - parser.add_argument("--checkpoint", type=str, required=True, help="Path to model checkpoint (.ckpt)") - parser.add_argument("--adata", type=str, required=True, help="Path to input AnnData file (.h5ad)") - parser.add_argument("--embed_key", type=str, default="X_hvg", help="Key in adata.obsm for input features") - parser.add_argument("--pert_col", type=str, default="drugname_drugconc", help="Column in adata.obs for perturbation labels") - parser.add_argument("--output_dir", type=str, default=None, help="Path to output AnnData file (.h5ad)") - parser.add_argument("--celltype_col", type=str, default=None, help="Column in adata.obs for cell type labels (optional)") - parser.add_argument("--celltypes", type=str, default=None, help="Comma-separated list of cell types to include (optional)") - parser.add_argument("--batch_size", type=int, default=1000, help="Batch size for inference (default: 1000)") - - -def run_sets_infer(args): - import logging - logging.basicConfig(level=logging.INFO) - logger = logging.getLogger(__name__) - - # Load model - logger.info(f"Loading model from checkpoint: {args.checkpoint}") - model = PertSetsPerturbationModel.load_from_checkpoint(args.checkpoint) - model.eval() - cell_sentence_len = model.cell_sentence_len - device = next(model.parameters()).device - pert_dim = model.pert_dim - - logger.info(f"Using model's cell_sentence_len: {cell_sentence_len}") - logger.info(f"Using pert_dim: {pert_dim}") - - # Load AnnData - logger.info(f"Loading AnnData from: {args.adata}") - adata_full = sc.read_h5ad(args.adata) - - # Define control perturbations to look for - control_perts = ["DMSO_TF_24h", "non-targeting", "control", "DMSO"] - - # Find available control perturbation - available_perts = set(adata_full.obs[args.pert_col].unique()) - control_pert = None - for ctrl in control_perts: - if ctrl in available_perts: - control_pert = ctrl - logger.info(f"Using '{ctrl}' as control perturbation") - break - - if control_pert is None: - # Use the first available perturbation as fallback - control_pert = list(available_perts)[0] - logger.warning(f"No standard control found, using '{control_pert}' as control") - - # Get available cell types and 
select the most abundant one - if args.celltype_col is not None: - if args.celltype_col not in adata_full.obs: - raise ValueError(f"Column '{args.celltype_col}' not found in adata.obs.") - - if args.celltypes is not None: - celltypes = [ct.strip() for ct in args.celltypes.split(",")] - adata_full = adata_full[adata_full.obs[args.celltype_col].isin(celltypes)].copy() - logger.info(f"Filtered to specified cell types: {celltypes}") - - cell_type_counts = adata_full.obs[args.celltype_col].value_counts() - logger.info("Available cell types: %s", list(cell_type_counts.index)) - celltype1 = cell_type_counts.index[0] - logger.info(f"Selected cell type: {celltype1} ({cell_type_counts[celltype1]} available)") - - # Get control cells for this cell type - cells_type1 = adata_full[(adata_full.obs[args.pert_col] == control_pert) & - (adata_full.obs[args.celltype_col] == celltype1)].copy() - logger.info(f"Available control cells - {celltype1}: {cells_type1.n_obs}") - else: - # No cell type filtering, use all control cells - cells_type1 = adata_full[adata_full.obs[args.pert_col] == control_pert].copy() - logger.info(f"Available control cells: {cells_type1.n_obs}") - - # Use the model's actual cell_sentence_len - n_cells = cell_sentence_len - - if cells_type1.n_obs >= n_cells: - # Sample cells - idx1 = np.random.choice(cells_type1.n_obs, size=n_cells, replace=False) - sampled_cells = cells_type1[idx1].copy() - - logger.info(f"Sampled {sampled_cells.n_obs} cells for inference") - - # Extract embeddings based on available key - if args.embed_key in sampled_cells.obsm: - X_embed = torch.tensor(sampled_cells.obsm[args.embed_key], dtype=torch.float32).to(device) - logger.info(f"Using adata.obsm['{args.embed_key}'] as input features: shape {X_embed.shape}") - else: - X_data = sampled_cells.X.toarray() if hasattr(sampled_cells.X, 'toarray') else sampled_cells.X - X_embed = torch.tensor(X_data, dtype=torch.float32).to(device) - logger.info(f"Using adata.X as input features: shape {X_embed.shape}") - - # Create simple perturbation tensor - set first dimension to 1 for control - pert_tensor = torch.zeros((n_cells, pert_dim), device=device) - pert_tensor[:, 0] = 1 # Set first dimension to 1 for control perturbation - pert_names = [control_pert] * n_cells - - # Create batch dictionary - batch = { - "ctrl_cell_emb": X_embed, - "pert_emb": pert_tensor, - "pert_name": pert_names, - "batch": torch.zeros((1, cell_sentence_len), device=device) - } - - logger.info(f"Batch shapes - ctrl_cell_emb: {batch['ctrl_cell_emb'].shape}, pert_emb: {batch['pert_emb'].shape}") - logger.info(f"Running single forward pass with {n_cells} cells") - - # Single forward pass - with torch.no_grad(): - preds = model.forward(batch, padded=False) - - logger.info("Forward pass completed successfully") - preds_np = preds.cpu().numpy() - - # Save predictions to sampled cells - pred_key = "model_preds" - sampled_cells.obsm[pred_key] = preds_np - - else: - raise ValueError(f"Not enough control cells available. 
Need {n_cells}, but only have {cells_type1.n_obs}") - - # Save results - output_path = args.output_dir or args.adata.replace(".h5ad", "_with_preds.h5ad").replace(".h5", "_with_preds.h5ad") - sampled_cells.write_h5ad(output_path) - logger.info(f"Saved predictions to {output_path} (in adata.obsm['{pred_key}'])") - - -def main(): - parser = argparse.ArgumentParser(description="Run inference on AnnData with a trained model checkpoint.") - add_arguments_infer(parser) - args = parser.parse_args() - run_sets_infer(args) - -if __name__ == "__main__": - main() From 8d8a075fdf3683b09ba88a15facd661c9031634e Mon Sep 17 00:00:00 2001 From: dhruvgautam Date: Fri, 20 Jun 2025 00:41:37 -0700 Subject: [PATCH 06/16] train working i think --- src/state_sets/_cli/_sets/_train.py | 7 +++- src/state_sets/sets/models/base.py | 63 +++++++++++++++-------------- 2 files changed, 38 insertions(+), 32 deletions(-) diff --git a/src/state_sets/_cli/_sets/_train.py b/src/state_sets/_cli/_sets/_train.py index 586750cd..093c5827 100644 --- a/src/state_sets/_cli/_sets/_train.py +++ b/src/state_sets/_cli/_sets/_train.py @@ -117,8 +117,11 @@ def run_sets_train(cfg: DictConfig): data_module.setup(stage="fit") var_dims = data_module.get_var_dims() # {"gene_dim": …, "hvg_dim": …} - gene_dim = var_dims.get("gene_dim", 5000) # fallback if key missing - latent_dim = cfg["model"]["kwargs"]["output_dim"] # same as model.output_dim + if cfg["data"]["kwargs"]["output_space"] == "gene": + gene_dim = var_dims.get("hvg_dim", 2000) # fallback if key missing + else: + gene_dim = var_dims.get("gene_dim", 2000) # fallback if key missing + latent_dim = var_dims["output_dim"] # same as model.output_dim hidden_dims = cfg["model"]["kwargs"].get("decoder_hidden_dims", [1024, 1024, 512]) decoder_cfg = dict( diff --git a/src/state_sets/sets/models/base.py b/src/state_sets/sets/models/base.py index 8ae32f27..13dee2c5 100644 --- a/src/state_sets/sets/models/base.py +++ b/src/state_sets/sets/models/base.py @@ -241,38 +241,41 @@ def on_load_checkpoint(self, checkpoint: dict[str, tp.Any]) -> None: """ if "decoder_cfg" in checkpoint["hyper_parameters"]: self.decoder_cfg = checkpoint["hyper_parameters"]["decoder_cfg"] + self._build_decoder() + logger.info(f"Loaded decoder from checkpoint decoder_cfg: {self.decoder_cfg}") else: - self.decoder_cfg = None - self._build_decoder() - logger.info(f"DEBUG: output_space: {self.output_space}") - if self.gene_decoder is None: - gene_dim = self.hvg_dim if self.output_space == "gene" else self.gene_dim - logger.info(f"DEBUG: gene_dim: {gene_dim}") - if (self.embed_key and self.embed_key != "X_hvg" and self.output_space == "gene") or ( - self.embed_key and self.output_space == "all" - ): # we should be able to decode from hvg to all - logger.info(f"DEBUG: Creating gene_decoder, checking conditions...") - if gene_dim > 10000: - hidden_dims = [1024, 512, 256] - elif self.embed_key in ["X_vci_1.5.2", "X_vci_1.5.2_4"]: - hidden_dims = [1024, 1024, 512] - else: - if "DMSO_TF" in self.control_pert: - if self.residual_decoder: - hidden_dims = [2058, 2058, 2058, 2058, 2058] - else: - hidden_dims = [4096, 2048, 2048] + # Only fall back to old logic if no decoder_cfg was saved + self.decoder_cfg = None + self._build_decoder() + logger.info(f"DEBUG: output_space: {self.output_space}") + if self.gene_decoder is None: + gene_dim = self.hvg_dim if self.output_space == "gene" else self.gene_dim + logger.info(f"DEBUG: gene_dim: {gene_dim}") + if (self.embed_key and self.embed_key != "X_hvg" and self.output_space == "gene") or ( + 
self.embed_key and self.output_space == "all" + ): # we should be able to decode from hvg to all + logger.info(f"DEBUG: Creating gene_decoder, checking conditions...") + if gene_dim > 10000: + hidden_dims = [1024, 512, 256] + elif self.embed_key in ["X_vci_1.5.2", "X_vci_1.5.2_4"]: + hidden_dims = [1024, 1024, 512] else: - hidden_dims = [1024, 1024, 512] # make this config - - self.gene_decoder = LatentToGeneDecoder( - latent_dim=self.output_dim, - gene_dim=gene_dim, - hidden_dims=hidden_dims, - dropout=self.dropout, - residual_decoder=self.residual_decoder, - ) - logger.info(f"Initialized gene decoder for embedding {self.embed_key} to gene space") + if "DMSO_TF" in self.control_pert: + if self.residual_decoder: + hidden_dims = [2058, 2058, 2058, 2058, 2058] + else: + hidden_dims = [4096, 2048, 2048] + else: + hidden_dims = [1024, 1024, 512] # make this config + + self.gene_decoder = LatentToGeneDecoder( + latent_dim=self.output_dim, + gene_dim=gene_dim, + hidden_dims=hidden_dims, + dropout=self.dropout, + residual_decoder=self.residual_decoder, + ) + logger.info(f"Initialized gene decoder for embedding {self.embed_key} to gene space") def training_step(self, batch: Dict[str, torch.Tensor], batch_idx: int) -> torch.Tensor: From 301ac95b9647b4539fea0d1a0b15af66c512d295 Mon Sep 17 00:00:00 2001 From: dhruvgautam Date: Fri, 20 Jun 2025 01:19:10 -0700 Subject: [PATCH 07/16] merge --- src/state_sets/sets/models/base.py | 36 ------------------------------ 1 file changed, 36 deletions(-) diff --git a/src/state_sets/sets/models/base.py b/src/state_sets/sets/models/base.py index 13dee2c5..8da26fb8 100644 --- a/src/state_sets/sets/models/base.py +++ b/src/state_sets/sets/models/base.py @@ -182,42 +182,6 @@ def __init__( self.lr = lr self.loss_fn = get_loss_fn(loss_fn) - # this will either decode to hvg space if output space is a gene, - # or to transcriptome space if output space is all. 
done this way to maintain - # backwards compatibility with the old models - self.gene_decoder = None - gene_dim = hvg_dim if output_space == "gene" else gene_dim - if (embed_key and embed_key != "X_hvg" and output_space == "gene") or ( - embed_key and output_space == "all" - ): # we should be able to decode from hvg to all - if gene_dim > 10000: - hidden_dims = [1024, 512, 256] - else: - if "DMSO_TF" in self.control_pert: - if self.residual_decoder: - hidden_dims = [2058, 2058, 2058, 2058, 2058] - else: - hidden_dims = [4096, 2048, 2048] - elif "PBS" in self.control_pert: - hidden_dims = [2048, 1024, 1024] - else: - if "DMSO_TF" in self.control_pert: - if self.residual_decoder: - hidden_dims = [2058, 2058, 2058, 2058, 2058] - else: - hidden_dims = [4096, 2048, 2048] - else: - hidden_dims = [1024, 1024, 512] # make this config - - self.gene_decoder = LatentToGeneDecoder( - latent_dim=self.output_dim, - gene_dim=gene_dim, - hidden_dims=hidden_dims, - dropout=dropout, - residual_decoder=self.residual_decoder, - ) - logger.info(f"Initialized gene decoder for embedding {embed_key} to gene space") - def transfer_batch_to_device(self, batch, device, dataloader_idx: int): return {k: (v.to(device) if isinstance(v, torch.Tensor) else v) for k, v in batch.items()} From 46a0eab4b588a18604522f1022905e31df02bc02 Mon Sep 17 00:00:00 2001 From: dhruvgautam Date: Fri, 20 Jun 2025 11:36:50 -0700 Subject: [PATCH 08/16] restart fix --- src/state_sets/configs/model/tahoe_decoder_test.yaml | 8 ++++---- src/state_sets/sets/models/base.py | 3 ++- 2 files changed, 6 insertions(+), 5 deletions(-) diff --git a/src/state_sets/configs/model/tahoe_decoder_test.yaml b/src/state_sets/configs/model/tahoe_decoder_test.yaml index 1a0c985d..680e1812 100644 --- a/src/state_sets/configs/model/tahoe_decoder_test.yaml +++ b/src/state_sets/configs/model/tahoe_decoder_test.yaml @@ -6,7 +6,7 @@ kwargs: cell_set_len: 512 decoder_hidden_dims: [2048, 2048, 2048] blur: 0.05 - hidden_dim: 1488 # hidden dimension going into the transformer backbone + hidden_dim: 696 # hidden dimension going into the transformer backbone loss: energy confidence_head: False n_encoder_layers: 4 @@ -27,11 +27,11 @@ kwargs: transformer_backbone_kwargs: max_position_embeddings: ${model.kwargs.cell_set_len} hidden_size: ${model.kwargs.hidden_dim} - intermediate_size: 5952 - num_hidden_layers: 6 + intermediate_size: 2784 + num_hidden_layers: 8 num_attention_heads: 12 num_key_value_heads: 12 - head_dim: 124 + head_dim: 58 use_cache: false attention_dropout: 0.0 hidden_dropout: 0.0 diff --git a/src/state_sets/sets/models/base.py b/src/state_sets/sets/models/base.py index 8da26fb8..8b522b47 100644 --- a/src/state_sets/sets/models/base.py +++ b/src/state_sets/sets/models/base.py @@ -181,6 +181,7 @@ def __init__( self.dropout = dropout self.lr = lr self.loss_fn = get_loss_fn(loss_fn) + self._build_decoder() def transfer_batch_to_device(self, batch, device, dataloader_idx: int): return {k: (v.to(device) if isinstance(v, torch.Tensor) else v) for k, v in batch.items()} @@ -205,7 +206,7 @@ def on_load_checkpoint(self, checkpoint: dict[str, tp.Any]) -> None: """ if "decoder_cfg" in checkpoint["hyper_parameters"]: self.decoder_cfg = checkpoint["hyper_parameters"]["decoder_cfg"] - self._build_decoder() + self.gene_decoder = LatentToGeneDecoder(**self.decoder_cfg) logger.info(f"Loaded decoder from checkpoint decoder_cfg: {self.decoder_cfg}") else: # Only fall back to old logic if no decoder_cfg was saved From 38f6a2e58beca5108581dcc9014a342e5b1b58e6 Mon Sep 17 00:00:00 2001 
From: dhruvgautam Date: Fri, 20 Jun 2025 14:38:03 -0700 Subject: [PATCH 09/16] train fixed --- src/state_sets/_cli/_sets/_train.py | 35 ++++++++++++++++++++++++++--- 1 file changed, 32 insertions(+), 3 deletions(-) diff --git a/src/state_sets/_cli/_sets/_train.py b/src/state_sets/_cli/_sets/_train.py index 093c5827..ed41f801 100644 --- a/src/state_sets/_cli/_sets/_train.py +++ b/src/state_sets/_cli/_sets/_train.py @@ -109,12 +109,29 @@ def run_sets_train(cfg: DictConfig): batch_size=cfg["training"]["batch_size"], cell_sentence_len=sentence_len, ) + with open(join(run_output_dir, "data_module.torch"), "wb") as f: # TODO-Abhi: only save necessary data data_module.save_state(f) data_module.setup(stage="fit") + dl = data_module.train_dataloader() + print("num_workers:", dl.num_workers) + print("batch size:", dl.batch_size) + # # Test loading a single batch to identify potential data loading issues + # try: + # print("DEBUG: Testing data loading with one batch...") + # train_loader = data_module.train_dataloader() + # first_batch = next(iter(train_loader)) + # print(f"DEBUG: Successfully loaded first batch. Batch keys: {list(first_batch.keys()) if isinstance(first_batch, dict) else 'Not a dict'}") + # if isinstance(first_batch, dict): + # for key, value in first_batch.items(): + # if hasattr(value, 'shape'): + # print(f"DEBUG: {key} shape: {value.shape}") + # except Exception as e: + # print(f"DEBUG: Error loading first batch: {e}") + # print("DEBUG: This might be the source of the hanging issue") var_dims = data_module.get_var_dims() # {"gene_dim": …, "hvg_dim": …} if cfg["data"]["kwargs"]["output_space"] == "gene": @@ -135,9 +152,6 @@ def run_sets_train(cfg: DictConfig): # tuck it into the kwargs that will reach the LightningModule cfg["model"]["kwargs"]["decoder_cfg"] = decoder_cfg - cfg["data"]["kwargs"]["n_perts"] = len(data_module.pert_onehot_map) - cfg["model"]["kwargs"]["pert_onehot_map"] = data_module.pert_onehot_map - if cfg["model"]["name"].lower() in ["cpa", "scvi"] or cfg["model"]["name"].lower().startswith("scgpt"): cfg["model"]["kwargs"]["n_cell_types"] = len(data_module.celltype_onehot_map) cfg["model"]["kwargs"]["n_perts"] = len(data_module.pert_onehot_map) @@ -152,6 +166,8 @@ def run_sets_train(cfg: DictConfig): data_module.get_var_dims(), ) + print(f"DEBUG: Model created. Estimated params size: {sum(p.numel() * p.element_size() for p in model.parameters()) / 1024**3:.2f} GB") + # print(f"DEBUG: Num workers: {data_module.train_dataloader().num_workers}") # Set up logging loggers = get_loggers( output_dir=cfg["output_dir"], @@ -220,7 +236,9 @@ def run_sets_train(cfg: DictConfig): del trainer_kwargs["max_steps"] # Build trainer + print(f"DEBUG: Building trainer with kwargs: {trainer_kwargs}") trainer = pl.Trainer(**trainer_kwargs) + print("DEBUG: Trainer built successfully") # Load checkpoint if exists checkpoint_path = join(ckpt_callbacks[0].dirpath, "last.ckpt") @@ -229,12 +247,17 @@ def run_sets_train(cfg: DictConfig): else: logging.info(f"!! 
Resuming training from {checkpoint_path} !!") + print(f"DEBUG: Model device: {next(model.parameters()).device}") + print(f"DEBUG: CUDA memory allocated: {torch.cuda.memory_allocated() / 1024**3:.2f} GB") + print(f"DEBUG: CUDA memory reserved: {torch.cuda.memory_reserved() / 1024**3:.2f} GB") + logger.info("Starting trainer fit.") # if a checkpoint does not exist, start with the provided checkpoint # this is mainly used for pretrain -> finetune workflows manual_init = cfg["model"]["kwargs"].get("init_from", None) if checkpoint_path is None and manual_init is not None: + print(f"DEBUG: Loading manual checkpoint from {manual_init}") checkpoint_path = manual_init device = torch.device("cuda" if torch.cuda.is_available() else "cpu") checkpoint = torch.load(checkpoint_path, map_location=device) @@ -275,6 +298,7 @@ def run_sets_train(cfg: DictConfig): # Load the filtered state dict model.load_state_dict(filtered_state, strict=False) + print("DEBUG: About to call trainer.fit() with manual checkpoint...") # Train - for clarity we pass None trainer.fit( @@ -282,13 +306,18 @@ def run_sets_train(cfg: DictConfig): datamodule=data_module, ckpt_path=None, ) + print("DEBUG: trainer.fit() completed with manual checkpoint") else: + print(f"DEBUG: About to call trainer.fit() with checkpoint_path={checkpoint_path}") # Train trainer.fit( model, datamodule=data_module, ckpt_path=checkpoint_path, ) + print("DEBUG: trainer.fit() completed") + + print("DEBUG: Training completed, saving final checkpoint...") # at this point if checkpoint_path does not exist, manually create one checkpoint_path = join(ckpt_callbacks[0].dirpath, "final.ckpt") From b8dd2d3d9c5b76c984800fc9b7cfc87b651c6c15 Mon Sep 17 00:00:00 2001 From: dhruvgautam Date: Fri, 20 Jun 2025 14:55:13 -0700 Subject: [PATCH 10/16] readme --- README.md | 5 +++++ src/state_sets/_cli/_sets/_infer.py | 5 ----- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/README.md b/README.md index e4d9d77f..2ee25d6f 100644 --- a/README.md +++ b/README.md @@ -60,6 +60,11 @@ An example evaluation command for a sets model: state-sets sets predict --output_dir /home/aadduri/state-sets/test/ --checkpoint last.ckpt ``` +An example inference command for a sets model: +```bash +state-sets sets infer --output /home/dhruvgautam/state-sets/test/ --output_dir /path/to/model/ --checkpoint /path/to/model/checkpoints/last.ckpt --adata /path/to/anndata/processed.h5 --pert_col gene --embed_key X_hvg +``` + The toml files should be setup to define perturbation splits, if running fewshot experiments. 
Here are some examples: ```toml diff --git a/src/state_sets/_cli/_sets/_infer.py b/src/state_sets/_cli/_sets/_infer.py index 60406c84..5ec61f61 100644 --- a/src/state_sets/_cli/_sets/_infer.py +++ b/src/state_sets/_cli/_sets/_infer.py @@ -10,11 +10,6 @@ from ...sets.models.pert_sets import PertSetsPerturbationModel from cell_load.data_modules import PerturbationDataModule -# state-sets sets infer --output_dir /home/aadduri/state-sets/test/ --checkpoint last.ckpt --adata /home/aadduri/state-sets/test/adata.h5ad --pert_col gene -# state-sets sets infer --output_dir /home/dhruvgautam/state-sets/test/ --checkpoint /large_storage/ctc/userspace/aadduri/preprint/replogle_llama_21712320_filtered_cs32_pretrained/hepg2/checkpoints/step=44000.ckpt --adata /large_storage/ctc/ML/state_sets/replogle/processed.h5 --pert_col gene - -# state-sets sets infer --output /home/dhruvgautam/state-sets/test/ --output_dir /large_storage/ctc/userspace/aadduri/preprint/replogle_state_proper_cs32_sm/hepg2 --checkpoint /large_storage/ctc/userspace/aadduri/preprint/replogle_state_proper_cs32_sm/hepg2/checkpoints/step=48000.ckpt --adata /large_storage/ctc/ML/state_sets/replogle/processed.h5 --pert_col gene --embed_key X_vci_1.5.2_4 - def add_arguments_infer(parser: argparse.ArgumentParser): parser.add_argument("--checkpoint", type=str, required=True, help="Path to model checkpoint (.ckpt)") parser.add_argument("--adata", type=str, required=True, help="Path to input AnnData file (.h5ad)") From c4356e41141c6b649d5331925895cced570357da Mon Sep 17 00:00:00 2001 From: dhruvgautam Date: Fri, 20 Jun 2025 14:55:29 -0700 Subject: [PATCH 11/16] space --- README.md | 1 + 1 file changed, 1 insertion(+) diff --git a/README.md b/README.md index 2ee25d6f..4377ca2b 100644 --- a/README.md +++ b/README.md @@ -61,6 +61,7 @@ state-sets sets predict --output_dir /home/aadduri/state-sets/test/ --checkpoint ``` An example inference command for a sets model: + ```bash state-sets sets infer --output /home/dhruvgautam/state-sets/test/ --output_dir /path/to/model/ --checkpoint /path/to/model/checkpoints/last.ckpt --adata /path/to/anndata/processed.h5 --pert_col gene --embed_key X_hvg ``` From 0c68252ecabb5f0c3c1c71117452be0705a6ebd4 Mon Sep 17 00:00:00 2001 From: dhruvgautam Date: Sat, 21 Jun 2025 12:06:29 -0700 Subject: [PATCH 12/16] changes --- src/state_sets/_cli/_sets/_infer.py | 73 +++++++++++++++++------------ src/state_sets/_cli/_sets/_train.py | 62 ++++++++++-------------- src/state_sets/sets/models/base.py | 3 +- 3 files changed, 69 insertions(+), 69 deletions(-) diff --git a/src/state_sets/_cli/_sets/_infer.py b/src/state_sets/_cli/_sets/_infer.py index 5ec61f61..b79faf80 100644 --- a/src/state_sets/_cli/_sets/_infer.py +++ b/src/state_sets/_cli/_sets/_infer.py @@ -10,20 +10,33 @@ from ...sets.models.pert_sets import PertSetsPerturbationModel from cell_load.data_modules import PerturbationDataModule + def add_arguments_infer(parser: argparse.ArgumentParser): parser.add_argument("--checkpoint", type=str, required=True, help="Path to model checkpoint (.ckpt)") parser.add_argument("--adata", type=str, required=True, help="Path to input AnnData file (.h5ad)") parser.add_argument("--embed_key", type=str, default="X_hvg", help="Key in adata.obsm for input features") - parser.add_argument("--pert_col", type=str, default="drugname_drugconc", help="Column in adata.obs for perturbation labels") + parser.add_argument( + "--pert_col", type=str, default="drugname_drugconc", help="Column in adata.obs for perturbation labels" + ) 
parser.add_argument("--output", type=str, default=None, help="Path to output AnnData file (.h5ad)") - parser.add_argument("--output_dir", type=str, required=True, help="Path to the output_dir containing the config.yaml file that was saved during training.") - parser.add_argument("--celltype_col", type=str, default=None, help="Column in adata.obs for cell type labels (optional)") - parser.add_argument("--celltypes", type=str, default=None, help="Comma-separated list of cell types to include (optional)") + parser.add_argument( + "--output_dir", + type=str, + required=True, + help="Path to the output_dir containing the config.yaml file that was saved during training.", + ) + parser.add_argument( + "--celltype_col", type=str, default=None, help="Column in adata.obs for cell type labels (optional)" + ) + parser.add_argument( + "--celltypes", type=str, default=None, help="Comma-separated list of cell types to include (optional)" + ) parser.add_argument("--batch_size", type=int, default=1000, help="Batch size for inference (default: 1000)") def run_sets_infer(args): import logging + logging.basicConfig(level=logging.INFO) logger = logging.getLogger(__name__) @@ -87,31 +100,30 @@ def load_config(cfg_path: str) -> dict: # Prepare perturbation tensor using the data module's mapping pert_names = adata.obs[args.pert_col].values - pert_tensor = torch.zeros((len(pert_names), pert_dim), device='cpu') # Keep on CPU initially + pert_tensor = torch.zeros((len(pert_names), pert_dim), device="cpu") # Keep on CPU initially logger.info(f"Perturbation tensor shape: {pert_tensor.shape}") - + # Use data module's perturbation mapping pert_onehot_map = data_module.pert_onehot_map - - # Debug: check what's available + logger.info(f"Data module has {len(pert_onehot_map)} perturbations in mapping") logger.info(f"First 10 perturbations in data module: {list(pert_onehot_map.keys())[:10]}") - + unique_pert_names = sorted(set(pert_names)) logger.info(f"AnnData has {len(unique_pert_names)} unique perturbations") logger.info(f"First 10 perturbations in AnnData: {unique_pert_names[:10]}") - + # Check overlap overlap = set(unique_pert_names) & set(pert_onehot_map.keys()) logger.info(f"Overlap between AnnData and data module: {len(overlap)} perturbations") if len(overlap) < len(unique_pert_names): missing = set(unique_pert_names) - set(pert_onehot_map.keys()) logger.warning(f"Missing perturbations: {list(missing)[:10]}") - + # Check if there's a control perturbation that might match control_pert = data_module.get_control_pert() logger.info(f"Control perturbation in data module: '{control_pert}'") - + matched_count = 0 for idx, name in enumerate(pert_names): if name in pert_onehot_map: @@ -125,7 +137,7 @@ def load_config(cfg_path: str) -> dict: # Use first available perturbation as fallback first_pert = list(pert_onehot_map.keys())[0] pert_tensor[idx] = pert_onehot_map[first_pert] - + logger.info(f"Matched {matched_count} out of {len(pert_names)} perturbations") # Process in batches with progress bar @@ -133,31 +145,33 @@ def load_config(cfg_path: str) -> dict: n_samples = len(pert_names) batch_size = cell_sentence_len # Model requires this exact batch size n_batches = (n_samples + batch_size - 1) // batch_size # Ceiling division - - logger.info(f"Running inference on {n_samples} samples in {n_batches} batches of size {batch_size} (model's cell_sentence_len)...") - + + logger.info( + f"Running inference on {n_samples} samples in {n_batches} batches of size {batch_size} (model's cell_sentence_len)..." 
+ ) + all_preds = [] - + with torch.no_grad(): progress_bar = tqdm(total=n_samples, desc="Processing samples", unit="samples") - + for batch_idx in range(n_batches): start_idx = batch_idx * batch_size end_idx = min(start_idx + batch_size, n_samples) current_batch_size = end_idx - start_idx - + # Get batch data X_batch = torch.tensor(X[start_idx:end_idx], dtype=torch.float32).to(device) pert_batch = pert_tensor[start_idx:end_idx].to(device) pert_names_batch = pert_names[start_idx:end_idx].tolist() - + # Pad the batch to cell_sentence_len if it's the last incomplete batch if current_batch_size < cell_sentence_len: # Pad with zeros for embeddings padding_size = cell_sentence_len - current_batch_size X_pad = torch.zeros((padding_size, X_batch.shape[1]), device=device) X_batch = torch.cat([X_batch, X_pad], dim=0) - + # Pad perturbation tensor with control perturbation pert_pad = torch.zeros((padding_size, pert_batch.shape[1]), device=device) if control_pert in pert_onehot_map: @@ -165,21 +179,21 @@ def load_config(cfg_path: str) -> dict: else: pert_pad[:, 0] = 1 # Default to first perturbation pert_batch = torch.cat([pert_batch, pert_pad], dim=0) - + # Extend perturbation names pert_names_batch.extend([control_pert] * padding_size) - + # Prepare batch - use same format as working code batch = { "ctrl_cell_emb": X_batch, "pert_emb": pert_batch, # Keep as 2D tensor "pert_name": pert_names_batch, - "batch": torch.zeros((1, cell_sentence_len), device=device) # Use (1, cell_sentence_len) + "batch": torch.zeros((1, cell_sentence_len), device=device), # Use (1, cell_sentence_len) } - + # Run inference on batch using padded=False like in working code batch_preds = model.predict_step(batch, batch_idx=batch_idx, padded=False) - + # Extract predictions from the dictionary returned by predict_step # Use gene decoder output if available, otherwise use latent predictions if "pert_cell_counts_preds" in batch_preds and batch_preds["pert_cell_counts_preds"] is not None: @@ -188,14 +202,14 @@ def load_config(cfg_path: str) -> dict: else: # Use latent space predictions pred_tensor = batch_preds["preds"] - + # Only keep predictions for the actual samples (not padding) actual_preds = pred_tensor[:current_batch_size] all_preds.append(actual_preds.cpu().numpy()) - + # Update progress bar progress_bar.update(current_batch_size) - + progress_bar.close() # Concatenate all predictions @@ -215,5 +229,6 @@ def main(): args = parser.parse_args() run_sets_infer(args) + if __name__ == "__main__": main() diff --git a/src/state_sets/_cli/_sets/_train.py b/src/state_sets/_cli/_sets/_train.py index ed41f801..1a7cf50e 100644 --- a/src/state_sets/_cli/_sets/_train.py +++ b/src/state_sets/_cli/_sets/_train.py @@ -109,7 +109,6 @@ def run_sets_train(cfg: DictConfig): batch_size=cfg["training"]["batch_size"], cell_sentence_len=sentence_len, ) - with open(join(run_output_dir, "data_module.torch"), "wb") as f: # TODO-Abhi: only save necessary data @@ -119,34 +118,21 @@ def run_sets_train(cfg: DictConfig): dl = data_module.train_dataloader() print("num_workers:", dl.num_workers) print("batch size:", dl.batch_size) - # # Test loading a single batch to identify potential data loading issues - # try: - # print("DEBUG: Testing data loading with one batch...") - # train_loader = data_module.train_dataloader() - # first_batch = next(iter(train_loader)) - # print(f"DEBUG: Successfully loaded first batch. 
Batch keys: {list(first_batch.keys()) if isinstance(first_batch, dict) else 'Not a dict'}") - # if isinstance(first_batch, dict): - # for key, value in first_batch.items(): - # if hasattr(value, 'shape'): - # print(f"DEBUG: {key} shape: {value.shape}") - # except Exception as e: - # print(f"DEBUG: Error loading first batch: {e}") - # print("DEBUG: This might be the source of the hanging issue") - - var_dims = data_module.get_var_dims() # {"gene_dim": …, "hvg_dim": …} + + var_dims = data_module.get_var_dims() # {"gene_dim": …, "hvg_dim": …} if cfg["data"]["kwargs"]["output_space"] == "gene": - gene_dim = var_dims.get("hvg_dim", 2000) # fallback if key missing + gene_dim = var_dims.get("hvg_dim", 2000) # fallback if key missing else: - gene_dim = var_dims.get("gene_dim", 2000) # fallback if key missing - latent_dim = var_dims["output_dim"] # same as model.output_dim + gene_dim = var_dims.get("gene_dim", 2000) # fallback if key missing + latent_dim = var_dims["output_dim"] # same as model.output_dim hidden_dims = cfg["model"]["kwargs"].get("decoder_hidden_dims", [1024, 1024, 512]) decoder_cfg = dict( - latent_dim = latent_dim, - gene_dim = gene_dim, - hidden_dims = hidden_dims, - dropout = cfg["model"]["kwargs"].get("decoder_dropout", 0.1), - residual_decoder = cfg["model"]["kwargs"].get("residual_decoder", False), + latent_dim=latent_dim, + gene_dim=gene_dim, + hidden_dims=hidden_dims, + dropout=cfg["model"]["kwargs"].get("decoder_dropout", 0.1), + residual_decoder=cfg["model"]["kwargs"].get("residual_decoder", False), ) # tuck it into the kwargs that will reach the LightningModule @@ -166,9 +152,9 @@ def run_sets_train(cfg: DictConfig): data_module.get_var_dims(), ) - print(f"DEBUG: Model created. Estimated params size: {sum(p.numel() * p.element_size() for p in model.parameters()) / 1024**3:.2f} GB") - # print(f"DEBUG: Num workers: {data_module.train_dataloader().num_workers}") - # Set up logging + print( + f"Model created. Estimated params size: {sum(p.numel() * p.element_size() for p in model.parameters()) / 1024**3:.2f} GB" + ) loggers = get_loggers( output_dir=cfg["output_dir"], name=cfg["name"], @@ -236,9 +222,9 @@ def run_sets_train(cfg: DictConfig): del trainer_kwargs["max_steps"] # Build trainer - print(f"DEBUG: Building trainer with kwargs: {trainer_kwargs}") + print(f"Building trainer with kwargs: {trainer_kwargs}") trainer = pl.Trainer(**trainer_kwargs) - print("DEBUG: Trainer built successfully") + print("Trainer built successfully") # Load checkpoint if exists checkpoint_path = join(ckpt_callbacks[0].dirpath, "last.ckpt") @@ -247,9 +233,9 @@ def run_sets_train(cfg: DictConfig): else: logging.info(f"!! 
Resuming training from {checkpoint_path} !!") - print(f"DEBUG: Model device: {next(model.parameters()).device}") - print(f"DEBUG: CUDA memory allocated: {torch.cuda.memory_allocated() / 1024**3:.2f} GB") - print(f"DEBUG: CUDA memory reserved: {torch.cuda.memory_reserved() / 1024**3:.2f} GB") + print(f"Model device: {next(model.parameters()).device}") + print(f"CUDA memory allocated: {torch.cuda.memory_allocated() / 1024**3:.2f} GB") + print(f"CUDA memory reserved: {torch.cuda.memory_reserved() / 1024**3:.2f} GB") logger.info("Starting trainer fit.") @@ -257,7 +243,7 @@ def run_sets_train(cfg: DictConfig): # this is mainly used for pretrain -> finetune workflows manual_init = cfg["model"]["kwargs"].get("init_from", None) if checkpoint_path is None and manual_init is not None: - print(f"DEBUG: Loading manual checkpoint from {manual_init}") + print(f"Loading manual checkpoint from {manual_init}") checkpoint_path = manual_init device = torch.device("cuda" if torch.cuda.is_available() else "cpu") checkpoint = torch.load(checkpoint_path, map_location=device) @@ -298,7 +284,7 @@ def run_sets_train(cfg: DictConfig): # Load the filtered state dict model.load_state_dict(filtered_state, strict=False) - print("DEBUG: About to call trainer.fit() with manual checkpoint...") + print("About to call trainer.fit() with manual checkpoint...") # Train - for clarity we pass None trainer.fit( @@ -306,18 +292,18 @@ def run_sets_train(cfg: DictConfig): datamodule=data_module, ckpt_path=None, ) - print("DEBUG: trainer.fit() completed with manual checkpoint") + print("trainer.fit() completed with manual checkpoint") else: - print(f"DEBUG: About to call trainer.fit() with checkpoint_path={checkpoint_path}") + print(f"About to call trainer.fit() with checkpoint_path={checkpoint_path}") # Train trainer.fit( model, datamodule=data_module, ckpt_path=checkpoint_path, ) - print("DEBUG: trainer.fit() completed") + print("trainer.fit() completed") - print("DEBUG: Training completed, saving final checkpoint...") + print("Training completed, saving final checkpoint...") # at this point if checkpoint_path does not exist, manually create one checkpoint_path = join(ckpt_callbacks[0].dirpath, "final.ckpt") diff --git a/src/state_sets/sets/models/base.py b/src/state_sets/sets/models/base.py index 8b522b47..f95e2c1d 100644 --- a/src/state_sets/sets/models/base.py +++ b/src/state_sets/sets/models/base.py @@ -210,7 +210,7 @@ def on_load_checkpoint(self, checkpoint: dict[str, tp.Any]) -> None: logger.info(f"Loaded decoder from checkpoint decoder_cfg: {self.decoder_cfg}") else: # Only fall back to old logic if no decoder_cfg was saved - self.decoder_cfg = None + self.decoder_cfg = None self._build_decoder() logger.info(f"DEBUG: output_space: {self.output_space}") if self.gene_decoder is None: @@ -242,7 +242,6 @@ def on_load_checkpoint(self, checkpoint: dict[str, tp.Any]) -> None: ) logger.info(f"Initialized gene decoder for embedding {self.embed_key} to gene space") - def training_step(self, batch: Dict[str, torch.Tensor], batch_idx: int) -> torch.Tensor: """Training step logic for both main model and decoder.""" # Get model predictions (in latent space) From fa46d24d2e684c10db2918a87d8de522958204bc Mon Sep 17 00:00:00 2001 From: Abhinav Adduri Date: Sat, 21 Jun 2025 21:10:37 -0700 Subject: [PATCH 13/16] version --- pyproject.toml | 4 ++-- src/state_sets/state/data/loader.py | 1 + 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index a88d814d..17d6ba82 100644 --- a/pyproject.toml +++ 
b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "state-sets" -version = "0.6.5" +version = "0.7.0" description = "Add your description here" readme = "README.md" authors = [] @@ -29,7 +29,7 @@ dev = ["ruff>=0.11.11", "vulture>=2.14"] [tool.uv.sources] cell-load = { git = "ssh://github.com/arcinstitute/cell-load.git" } -cell-eval = { git = "ssh://github.com/arcinstitute/cell-eval.git", branch = "pin_with_pr_curves" } +cell-eval = { git = "ssh://github.com/arcinstitute/cell-eval.git" } zclip = { git = "https://github.com/bluorion-com/ZClip.git" } [build-system] diff --git a/src/state_sets/state/data/loader.py b/src/state_sets/state/data/loader.py index 6dba874a..2d2f5b8a 100644 --- a/src/state_sets/state/data/loader.py +++ b/src/state_sets/state/data/loader.py @@ -204,6 +204,7 @@ def __init__(self, cfg, test=False, datasets=None, shape_dict=None, adata=None, new_mapping = np.array([global_pos.get(g, -1) for g in gene_names]) if (new_mapping == -1).all(): # probably it contains ensembl id's instead + assert "gene_name" in adata.var.keys() gene_names = adata.var["gene_name"].values new_mapping = np.array([global_pos.get(g, -1) for g in gene_names]) From d9b7632234df6f6cd35ded2370ea00d3cdc026d1 Mon Sep 17 00:00:00 2001 From: Abhinav Adduri Date: Mon, 23 Jun 2025 01:04:55 -0700 Subject: [PATCH 14/16] changed state embed to use model folder ckpt for now --- src/state_sets/_cli/_sets/_predict.py | 1 + src/state_sets/_cli/_state/_embed.py | 29 +++++++++++++++++---------- src/state_sets/sets/models/base.py | 2 -- src/state_sets/state/data/loader.py | 21 +++++++++---------- src/state_sets/state/inference.py | 27 ++++++++++--------------- 5 files changed, 39 insertions(+), 41 deletions(-) diff --git a/src/state_sets/_cli/_sets/_predict.py b/src/state_sets/_cli/_sets/_predict.py index 08982062..8164dcba 100644 --- a/src/state_sets/_cli/_sets/_predict.py +++ b/src/state_sets/_cli/_sets/_predict.py @@ -366,6 +366,7 @@ def load_config(cfg_path: str) -> dict: outdir=results_dir, prefix=ct, pdex_kwargs=pdex_kwargs, + batch_size=2048, ) evaluator.compute( diff --git a/src/state_sets/_cli/_state/_embed.py b/src/state_sets/_cli/_state/_embed.py index 0e2a3df1..396999ce 100644 --- a/src/state_sets/_cli/_state/_embed.py +++ b/src/state_sets/_cli/_state/_embed.py @@ -1,16 +1,13 @@ import argparse as ap +import torch def add_arguments_embed(parser: ap.ArgumentParser): """Add arguments for state embedding CLI.""" - parser.add_argument("--checkpoint", required=True, help="Path to the model checkpoint file") - parser.add_argument("--config", required=True, help="Path to the model training config") + parser.add_argument("--model-folder", required=True, help="Path to the model checkpoint folder") parser.add_argument("--input", required=True, help="Path to input anndata file (h5ad)") parser.add_argument("--output", required=True, help="Path to output embedded anndata file (h5ad)") - parser.add_argument("--dataset-name", default="perturbation", help="Dataset name to be used in dataloader creation") - parser.add_argument("--gpu", action="store_true", help="Use GPU if available") - parser.add_argument("--filter", action="store_true", help="Filter gene set to our esm embeddings only.") - parser.add_argument("--embed-key", help="Name of key to store embeddings") + parser.add_argument("--embed-key", default="X_state", help="Name of key to store embeddings") def run_state_embed(args: ap.ArgumentParser): @@ -18,6 +15,7 @@ def run_state_embed(args: ap.ArgumentParser): Compute embeddings for an input anndata file using a 
pre-trained VCI model checkpoint. """ import os + import glob import logging from omegaconf import OmegaConf @@ -27,13 +25,23 @@ def run_state_embed(args: ap.ArgumentParser): from ...state.inference import Inference - # Load configuration - logger.info(f"Loading config from {args.config}") - conf = OmegaConf.load(args.config) + # look in the model folder with glob for *.ckpt, get the first one, and print it + model_files = glob.glob(os.path.join(args.model_folder, "*.ckpt")) + if not model_files: + logger.error(f"No model checkpoint found in {args.model_folder}") + raise FileNotFoundError(f"No model checkpoint found in {args.model_folder}") + args.checkpoint = model_files[0] + logger.info(f"Using model checkpoint: {args.checkpoint}") # Create inference object logger.info("Creating inference object") - inferer = Inference(conf) + embedding_file = os.path.join(args.model_folder, "protein_embeddings.pt") + protein_embeds = torch.load(embedding_file, weights_only=False, map_location="cpu") + + config_file = os.path.join(args.model_folder, "config.yaml") + conf = OmegaConf.load(config_file) + + inferer = Inference(cfg=conf, protein_embeds=protein_embeds) # Load model from checkpoint logger.info(f"Loading model from checkpoint: {args.checkpoint}") @@ -53,7 +61,6 @@ def run_state_embed(args: ap.ArgumentParser): input_adata_path=args.input, output_adata_path=args.output, emb_key=args.embed_key, - dataset_name=args.dataset_name, ) logger.info("Embedding computation completed successfully!") diff --git a/src/state_sets/sets/models/base.py b/src/state_sets/sets/models/base.py index f95e2c1d..50ab1d3f 100644 --- a/src/state_sets/sets/models/base.py +++ b/src/state_sets/sets/models/base.py @@ -222,8 +222,6 @@ def on_load_checkpoint(self, checkpoint: dict[str, tp.Any]) -> None: logger.info(f"DEBUG: Creating gene_decoder, checking conditions...") if gene_dim > 10000: hidden_dims = [1024, 512, 256] - elif self.embed_key in ["X_vci_1.5.2", "X_vci_1.5.2_4"]: - hidden_dims = [1024, 1024, 512] else: if "DMSO_TF" in self.control_pert: if self.residual_decoder: diff --git a/src/state_sets/state/data/loader.py b/src/state_sets/state/data/loader.py index 2d2f5b8a..1ef5832d 100644 --- a/src/state_sets/state/data/loader.py +++ b/src/state_sets/state/data/loader.py @@ -30,6 +30,7 @@ def create_dataloader( adata_name=None, shuffle=False, sentence_collator=None, + protein_embeds=None, ): """ Expected to be used for inference @@ -44,7 +45,7 @@ def create_dataloader( if data_dir: utils.get_dataset_cfg(cfg).data_dir = data_dir - dataset = FilteredGenesCounts(cfg, datasets=datasets, shape_dict=shape_dict, adata=adata, adata_name=adata_name) + dataset = FilteredGenesCounts(cfg, datasets=datasets, shape_dict=shape_dict, adata=adata, adata_name=adata_name, protein_embeds=protein_embeds) if sentence_collator is None: sentence_collator = VCIDatasetSentenceCollator( cfg, valid_gene_mask=dataset.valid_gene_index, ds_emb_mapping_inference=dataset.ds_emb_map, is_train=False @@ -170,21 +171,17 @@ def get_dim(self) -> Dict[str, int]: class FilteredGenesCounts(H5adSentenceDataset): - def __init__(self, cfg, test=False, datasets=None, shape_dict=None, adata=None, adata_name=None) -> None: + def __init__(self, cfg, test=False, datasets=None, shape_dict=None, adata=None, adata_name=None, protein_embeds=None) -> None: super(FilteredGenesCounts, self).__init__(cfg, test, datasets, shape_dict, adata, adata_name) self.valid_gene_index = {} + self.protein_embeds = protein_embeds # make sure we get training datasets - _, self.datasets, 
self.shapes_dict, self.dataset_path_map, self.dataset_group_map = utils.get_shapes_dict( - "/home/aadduri/state/h5ad_all.csv" - ) + self.datasets = [] + self.shapes_dict = {} + self.ds_emb_map = {} emb_cfg = utils.get_embedding_cfg(self.cfg) - try: - self.ds_emb_map = torch.load(emb_cfg.ds_emb_mapping, weights_only=False) - except (FileNotFoundError, IOError): - self.ds_emb_map = {} - # for inference, let's make sure this dataset's valid mask is available if adata_name is not None: # append it to self.datasets @@ -192,7 +189,7 @@ def __init__(self, cfg, test=False, datasets=None, shape_dict=None, adata=None, self.shapes_dict[adata_name] = adata.shape # compute its embedding‐index vector - esm_data = torch.load(emb_cfg.all_embeddings, weights_only=False) + esm_data = self.protein_embeds or torch.load(emb_cfg.all_embeddings, weights_only=False) valid_genes_list = list(esm_data.keys()) # make a gene→global‐index lookup global_pos = {g: i for i, g in enumerate(valid_genes_list)} @@ -212,7 +209,7 @@ def __init__(self, cfg, test=False, datasets=None, shape_dict=None, adata=None, self.ds_emb_map[adata_name] = new_mapping if utils.get_embedding_cfg(self.cfg).ds_emb_mapping is not None: - esm_data = torch.load(utils.get_embedding_cfg(self.cfg)["all_embeddings"], weights_only=False) + esm_data = self.protein_embeds or torch.load(emb_cfg.ds_emb_mapping, weights_only=False) valid_genes_list = list(esm_data.keys()) for name in self.datasets: if not utils.is_valid_uuid( diff --git a/src/state_sets/state/inference.py b/src/state_sets/state/inference.py index 6bd905e3..5810bf81 100644 --- a/src/state_sets/state/inference.py +++ b/src/state_sets/state/inference.py @@ -18,10 +18,10 @@ class Inference: - def __init__(self, cfg): + def __init__(self, cfg=None, protein_embeds=None): self.model = None self.collator = None - self.protein_embeds = None + self.protein_embeds = protein_embeds self._vci_conf = cfg def __load_dataset_meta(self, adata_path): @@ -95,16 +95,18 @@ def load_model(self, checkpoint): # Load and initialize model for eval self.model = LitUCEModel.load_from_checkpoint( - checkpoint, strict=False, cfg=self._vci_conf - ) ### THIS IS THE LINE THAT FAILS - all_pe = get_embeddings(self._vci_conf) - all_pe.requires_grad = False + checkpoint, strict=False + ) + all_pe = self.protein_embeds or get_embeddings(self._vci_conf) + if isinstance(all_pe, dict): + all_pe = torch.vstack(list(all_pe.values())) self.model.pe_embedding = nn.Embedding.from_pretrained(all_pe) self.model.pe_embedding.to(self.model.device) self.model.binary_decoder.requires_grad = False self.model.eval() - self.protein_embeds = torch.load(get_embedding_cfg(self._vci_conf).all_embeddings, weights_only=False) + if self.protein_embeds is None: + self.protein_embeds = torch.load(get_embedding_cfg(self._vci_conf).all_embeddings, weights_only=False) def init_from_model(self, model, protein_embeds=None): """ @@ -149,10 +151,11 @@ def encode_adata( dataloader = create_dataloader( self._vci_conf, adata=adata, - adata_name=dataset_name, + adata_name=dataset_name or "inference", shape_dict=shape_dict, data_dir=os.path.dirname(input_adata_path), shuffle=False, + protein_embeds=self.protein_embeds, ) all_embeddings = [] @@ -173,14 +176,6 @@ def encode_adata( adata.obsm[emb_key] = all_embeddings adata.write_h5ad(output_adata_path) - # This streaming approach was not working (h5 files could not be opened) - # for embeddings in tqdm(self.encode(dataloader), - # total=len(dataloader), - # desc='Encoding'): - # self._save_data(input_adata_path, 
output_adata_path, emb_key, embeddings) - - # return output_adata_path - def decode_from_file(self, adata_path, emb_key: str, read_depth=None, batch_size=64): adata = anndata.read_h5ad(adata_path) genes = adata.var.index From 4582ce37a73de18834e8178e7ad2e6346301abdc Mon Sep 17 00:00:00 2001 From: Abhinav Adduri Date: Mon, 23 Jun 2025 01:50:17 -0700 Subject: [PATCH 15/16] updated so inference works with no data dependency beyond checkpoints --- src/state_sets/state/data/loader.py | 8 +- src/state_sets/state/finetune_decoder.py | 4 +- src/state_sets/state/inference.py | 8 +- src/state_sets/state/nn/flash_transformer.py | 3 +- src/state_sets/state/nn/model.py | 93 +++----------------- src/state_sets/state/train/trainer.py | 6 +- 6 files changed, 25 insertions(+), 97 deletions(-) diff --git a/src/state_sets/state/data/loader.py b/src/state_sets/state/data/loader.py index 1ef5832d..40da2ec7 100644 --- a/src/state_sets/state/data/loader.py +++ b/src/state_sets/state/data/loader.py @@ -45,7 +45,9 @@ def create_dataloader( if data_dir: utils.get_dataset_cfg(cfg).data_dir = data_dir - dataset = FilteredGenesCounts(cfg, datasets=datasets, shape_dict=shape_dict, adata=adata, adata_name=adata_name, protein_embeds=protein_embeds) + dataset = FilteredGenesCounts( + cfg, datasets=datasets, shape_dict=shape_dict, adata=adata, adata_name=adata_name, protein_embeds=protein_embeds + ) if sentence_collator is None: sentence_collator = VCIDatasetSentenceCollator( cfg, valid_gene_mask=dataset.valid_gene_index, ds_emb_mapping_inference=dataset.ds_emb_map, is_train=False @@ -171,7 +173,9 @@ def get_dim(self) -> Dict[str, int]: class FilteredGenesCounts(H5adSentenceDataset): - def __init__(self, cfg, test=False, datasets=None, shape_dict=None, adata=None, adata_name=None, protein_embeds=None) -> None: + def __init__( + self, cfg, test=False, datasets=None, shape_dict=None, adata=None, adata_name=None, protein_embeds=None + ) -> None: super(FilteredGenesCounts, self).__init__(cfg, test, datasets, shape_dict, adata, adata_name) self.valid_gene_index = {} self.protein_embeds = protein_embeds diff --git a/src/state_sets/state/finetune_decoder.py b/src/state_sets/state/finetune_decoder.py index ef346ef7..5d5dc700 100644 --- a/src/state_sets/state/finetune_decoder.py +++ b/src/state_sets/state/finetune_decoder.py @@ -2,7 +2,7 @@ import torch from torch import nn -from vci.nn.model import LitUCEModel +from vci.nn.model import StateEmbeddingModel from vci.train.trainer import get_embeddings from vci.utils import get_embedding_cfg @@ -44,7 +44,7 @@ def load_model(self, checkpoint): # Import locally to avoid circular imports # Load and initialize model for eval - self.model = LitUCEModel.load_from_checkpoint(checkpoint, strict=False) + self.model = StateEmbeddingModel.load_from_checkpoint(checkpoint, strict=False) self.device = self.model.device # Load protein embeddings diff --git a/src/state_sets/state/inference.py b/src/state_sets/state/inference.py index 5810bf81..2f17dfb4 100644 --- a/src/state_sets/state/inference.py +++ b/src/state_sets/state/inference.py @@ -9,7 +9,7 @@ from tqdm import tqdm from torch import nn -from .nn.model import LitUCEModel +from .nn.model import StateEmbeddingModel from .train.trainer import get_embeddings from .data import create_dataloader from .utils import get_embedding_cfg @@ -94,9 +94,7 @@ def load_model(self, checkpoint): raise ValueError("Model already initialized") # Load and initialize model for eval - self.model = LitUCEModel.load_from_checkpoint( - checkpoint, strict=False - ) + 
self.model = StateEmbeddingModel.load_from_checkpoint(checkpoint, strict=False) all_pe = self.protein_embeds or get_embeddings(self._vci_conf) if isinstance(all_pe, dict): all_pe = torch.vstack(list(all_pe.values())) @@ -200,7 +198,7 @@ def decode_from_adata(self, adata, genes, emb_key: str, read_depth=None, batch_s task_counts = torch.full((cell_embeds_batch.shape[0],), read_depth, device=self.model.device) else: task_counts = None - merged_embs = LitUCEModel.resize_batch(cell_embeds_batch, gene_embeds, task_counts) + merged_embs = StateEmbeddingModel.resize_batch(cell_embeds_batch, gene_embeds, task_counts) logprobs_batch = self.model.binary_decoder(merged_embs) logprobs_batch = logprobs_batch.detach().cpu().numpy() yield logprobs_batch.squeeze() diff --git a/src/state_sets/state/nn/flash_transformer.py b/src/state_sets/state/nn/flash_transformer.py index cbb9e6b3..ba64b167 100644 --- a/src/state_sets/state/nn/flash_transformer.py +++ b/src/state_sets/state/nn/flash_transformer.py @@ -1,7 +1,6 @@ # File: vci/flash_transformer.py """ -This module implements a Transformer encoder layer that uses Flash Attention. -It provides a FlashTransformerEncoderLayer and a FlashTransformerEncoder. +This module implements a Transformer encoder layer. """ import torch diff --git a/src/state_sets/state/nn/model.py b/src/state_sets/state/nn/model.py index 86179870..dbbefb35 100644 --- a/src/state_sets/state/nn/model.py +++ b/src/state_sets/state/nn/model.py @@ -83,53 +83,11 @@ def forward(self, x: Tensor) -> Tensor: return self.dropout(x) -class RDAEncoder(nn.Module): - """ - Map a scalar read-depth d -> (mu, log_sigma). - latent_dim is whatever you used for z_count. - """ - - def __init__(self, latent_dim: int): - super().__init__() - self.latent_dim = latent_dim - self.net = nn.Sequential( - # nn.Linear(2, latent_dim, bias=True), - SkipBlock(self.latent_dim), - SkipBlock(self.latent_dim), - ) - - def forward(self, mu, sigma): - """ - Parameters - ---------- - mu : tensor or float -- shape (B,) or (B,1) or scalar - sigma : tensor or float -- same shape as `mu`; *must be ≥0* - - Returns - ------- - z_latent : tensor of shape (B, latent_dim) - Sampled from 𝒩(μ, σ²I) and processed by `self.net`. 
- """ - # Ensure tensors, move to same device as the network - mu = mu.view(-1, 1) - sigma = sigma.view(-1, 1) - - # Sample eps ~ N(0, I) and build z - eps = torch.randn(mu.size(0), self.latent_dim, device=mu.device) - z = mu.expand(-1, self.latent_dim) + sigma.expand(-1, self.latent_dim) * eps - - # z = torch.cat((mu, sigma), dim=1) # (B, 2) - - # Pass through two SkipBlocks (or identity if you prefer) - z = self.net(z) # (B, latent_dim) - return z - - def nanstd(x): return torch.sqrt(torch.nanmean(torch.pow(x - torch.nanmean(x, dim=-1).unsqueeze(-1), 2), dim=-1)) -class LitUCEModel(L.LightningModule): +class StateEmbeddingModel(L.LightningModule): def __init__( self, token_dim: int, @@ -195,11 +153,7 @@ def __init__( if compiled: self.decoder = torch.compile(self.decoder) - if self.cfg.model.get("sample_rda", False): - self.z_dim_rd = 128 - self.z_encoder = RDAEncoder(self.z_dim_rd) - else: - self.z_dim_rd = 1 if self.cfg.model.rda else 0 + self.z_dim_rd = 1 if self.cfg.model.rda else 0 self.z_dim_ds = 10 if self.cfg.model.get("dataset_correction", False) else 0 self.z_dim = self.z_dim_rd + self.z_dim_ds @@ -362,31 +316,22 @@ def _predict_exp_for_adata(self, adata, dataset_name, pert_col): _, _, _, emb, ds_emb = self._compute_embedding_for_batch(batch) # now decode from the embedding - if self.z_dim_rd > 1: - Y = batch[2].to(self.device) - nan_y = Y.masked_fill(Y == 0, float("nan"))[:, : self.cfg.dataset.P + self.cfg.dataset.N] - mu = torch.nanmean(nan_y, dim=1) if self.cfg.model.rda else None - sigma = nanstd(nan_y) if self.cfg.model.rda else None - sampled_rda = self.z_encoder(mu, sigma) - task_counts = None - elif self.z_dim_rd == 1: + task_counts = None + sampled_rda = None + if self.z_dim_rd == 1: Y = batch[2].to(self.device) nan_y = Y.masked_fill(Y == 0, float("nan"))[:, : self.cfg.dataset.P + self.cfg.dataset.N] task_counts = torch.nanmean(nan_y, dim=1) if self.cfg.model.rda else None sampled_rda = None - else: - task_counts = None - sampled_rda = None + ds_emb = None if self.dataset_token is not None: ds_emb = self.dataset_embedder(ds_emb) - else: - ds_emb = None emb_batches.append(emb.detach().cpu().numpy()) ds_emb_batches.append(ds_emb.detach().cpu().numpy()) - merged_embs = LitUCEModel.resize_batch(emb, gene_embeds, task_counts, sampled_rda, ds_emb) + merged_embs = StateEmbeddingModel.resize_batch(emb, gene_embeds, task_counts, sampled_rda, ds_emb) logprobs_batch = self.binary_decoder(merged_embs) logprobs_batch = logprobs_batch.detach().cpu().numpy() logprob_batches.append(logprobs_batch.squeeze()) @@ -450,25 +395,16 @@ def forward(self, src: Tensor, mask: Tensor, counts=None, dataset_nums=None): src + count_emb ) # should both be B x H x self.d_model, or B x H + 1 x self.d_model if dataset correction - if self.training: - # random chance 10% to set mask to None - if self.cfg.model.get("variable_masking", False) and np.random.rand() < 0.1: - mask = None - else: - mask = mask.to(self.device) - output = self.transformer_encoder(src, src_key_padding_mask=mask) - else: - output = self.transformer_encoder(src, src_key_padding_mask=None) + output = self.transformer_encoder(src, src_key_padding_mask=None) gene_output = self.decoder(output) # batch x seq_len x 128 # In the new format, the cls token, which is at the 0 index mark, is the output. embedding = gene_output[:, 0, :] # select only the CLS token. embedding = nn.functional.normalize(embedding, dim=1) # Normalize. 
# we must be in train mode to use dataset correction + dataset_emb = None if self.dataset_token is not None: dataset_emb = gene_output[:, -1, :] - else: - dataset_emb = None return gene_output, embedding, dataset_emb @@ -478,16 +414,7 @@ def shared_step(self, batch, batch_idx): z = embs.unsqueeze(1).repeat(1, X.shape[1], 1) # CLS token - if self.z_dim_rd > 1: - # your code here that computes mu and std dev from Y - nan_y = Y.masked_fill(Y == 0, float("nan"))[:, : self.cfg.dataset.P + self.cfg.dataset.N] - mu = torch.nanmean(nan_y, dim=1) if self.cfg.model.rda else None - sigma = nanstd(nan_y) if self.cfg.model.rda else None - - reshaped_counts = self.z_encoder(mu, sigma).unsqueeze(1) - reshaped_counts = reshaped_counts.expand(X.shape[0], X.shape[1], reshaped_counts.shape[2]) - combine = torch.cat((X, z, reshaped_counts), dim=2) - elif self.z_dim_rd == 1: + if self.z_dim_rd == 1: mu = torch.nanmean(Y.masked_fill(Y == 0, float("nan")), dim=1) if self.cfg.model.rda else None reshaped_counts = mu.unsqueeze(1).unsqueeze(2) reshaped_counts = reshaped_counts.repeat(1, X.shape[1], 1) diff --git a/src/state_sets/state/train/trainer.py b/src/state_sets/state/train/trainer.py index bcd5f615..604a6cbb 100644 --- a/src/state_sets/state/train/trainer.py +++ b/src/state_sets/state/train/trainer.py @@ -11,7 +11,7 @@ from lightning.pytorch.strategies import DDPStrategy from zclip import ZClipLightningCallback -from ..nn.model import LitUCEModel +from ..nn.model import StateEmbeddingModel from ..data import H5adSentenceDataset, VCIDatasetSentenceCollator from ..train.callbacks import LogLR, ProfilerCallback, ResumeCallback, EMACallback, PerfProfilerCallback from ..utils import get_latest_checkpoint, get_embedding_cfg, get_dataset_cfg @@ -74,7 +74,7 @@ def main(cfg): generator=generator, ) - model = LitUCEModel( + model = StateEmbeddingModel( token_dim=get_embedding_cfg(cfg).size, d_model=cfg.model.emsize, nhead=cfg.model.nhead, @@ -141,7 +141,7 @@ def main(cfg): accumulate_grad_batches=cfg.optimizer.gradient_accumulation_steps, precision="bf16-mixed", strategy=DDPStrategy( - process_group_backend="nccl", # if cfg.experiment.num_nodes == 1 else "gloo", + process_group_backend="nccl", find_unused_parameters=False, timeout=timedelta(seconds=cfg.experiment.get("ddp_timeout", 3600)), ), From 9c104ff0c65a09be43a2f1cb59d3ac0b932be987 Mon Sep 17 00:00:00 2001 From: Abhinav Adduri Date: Mon, 23 Jun 2025 01:53:15 -0700 Subject: [PATCH 16/16] updated versioning --- pyproject.toml | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 17d6ba82..f234664a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "state-sets" -version = "0.7.0" +version = "0.8.0" description = "Add your description here" readme = "README.md" authors = [] @@ -28,8 +28,8 @@ dependencies = [ dev = ["ruff>=0.11.11", "vulture>=2.14"] [tool.uv.sources] -cell-load = { git = "ssh://github.com/arcinstitute/cell-load.git" } -cell-eval = { git = "ssh://github.com/arcinstitute/cell-eval.git" } +cell-load = { git = "https://github.com/arcinstitute/cell-load.git" } +cell-eval = { git = "https://github.com/arcinstitute/cell-eval.git" } zclip = { git = "https://github.com/bluorion-com/ZClip.git" } [build-system]
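For reference, the model-folder embedding flow introduced in the `_embed.py` patch above can be pieced together outside the CLI roughly as follows. This is a minimal sketch, not part of the patches: the paths are hypothetical placeholders, and it assumes the model folder is laid out the way the patch expects (a `*.ckpt` checkpoint, `protein_embeddings.pt`, and the `config.yaml` saved during training), with `encode_adata` accepting the keyword arguments shown in the updated `run_state_embed`.

```python
# Sketch of the --model-folder embedding workflow from _embed.py (paths are placeholders).
import glob
import os

import torch
from omegaconf import OmegaConf

from state_sets.state.inference import Inference

model_folder = "/path/to/model"  # hypothetical; must contain *.ckpt, protein_embeddings.pt, config.yaml

# Pick up the first checkpoint in the folder, as the CLI does.
checkpoint = glob.glob(os.path.join(model_folder, "*.ckpt"))[0]

# Protein embeddings and training config are loaded from the same folder,
# so inference has no data dependency beyond the checkpoint directory.
protein_embeds = torch.load(
    os.path.join(model_folder, "protein_embeddings.pt"), weights_only=False, map_location="cpu"
)
conf = OmegaConf.load(os.path.join(model_folder, "config.yaml"))

inferer = Inference(cfg=conf, protein_embeds=protein_embeds)
inferer.load_model(checkpoint)

# Write embeddings into adata.obsm under the requested key.
inferer.encode_adata(
    input_adata_path="/path/to/input.h5ad",
    output_adata_path="/path/to/output_embedded.h5ad",
    emb_key="X_state",
)
```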