From 0ff67a0d479a1f910361ba9872da6f39e560392b Mon Sep 17 00:00:00 2001 From: nick-youngblut Date: Fri, 19 Dec 2025 13:12:03 -0800 Subject: [PATCH 01/10] docs: add repository guidelines for project structure, development commands, coding style, testing, and security --- AGENTS.md | 33 +++++++++++++++++++++++++++++++++ 1 file changed, 33 insertions(+) create mode 100644 AGENTS.md diff --git a/AGENTS.md b/AGENTS.md new file mode 100644 index 00000000..41642cdb --- /dev/null +++ b/AGENTS.md @@ -0,0 +1,33 @@ +# Repository Guidelines + +## Project Structure & Module Organization +- Core library lives in `src/state`, with CLI entrypoints in `src/state/__main__.py` and subcommands under `src/state/_cli`. +- Model configs and resources sit in `src/state/configs`, embedding helpers in `src/state/emb`, and transition utilities in `src/state/tx`. +- Tests reside in `tests/` (`test_*.py`), runnable without extra fixtures. Example TOML configs are in `examples/`. Helper scripts live in `scripts/` for inference and embedding. +- Artifacts or scratch outputs should go in `tmp/` or a user-created path; keep `assets/` for checked-in visuals/resources only. + +## Build, Test, and Development Commands +- Create/activate env and install in editable mode: `uv tool install -e .`. +- Run the CLI: `uv run state --help` (entrypoints `emb` and `tx`). +- Format/lint: `uv run ruff check .` (auto-fixes enabled by default config). +- Run tests: `uv run pytest` (adds `src/` to `PYTHONPATH` via standard layout). + +## Coding Style & Naming Conventions +- Python 3.10–3.12; prefer type hints on public functions. +- Use 4-space indentation, 120-char max line length (`ruff.toml`), and avoid bare `except` (E722 is explicitly ignored—only use when necessary). +- Modules and files use `snake_case`; classes `CamelCase`; constants `UPPER_SNAKE_CASE`. +- Keep CLI options descriptive and align new configs with the existing TOML examples. 
+ +## Testing Guidelines +- Add unit tests alongside new features in `tests/` with filenames `test_*.py` and functions `test_*`. +- Cover edge cases around data loading, config parsing, and checkpoint handling; favor small fixtures over large data blobs. +- For regressions, reproduce with a failing test first, then implement the fix. + +## Commit & Pull Request Guidelines +- Follow the short, imperative style seen in history (`chore: …`, `patch: …`, or focused message without trailing punctuation). Reference issue/PR numbers where applicable. +- PRs should explain the change, risks, and testing done (`uv run pytest`, `uv run ruff check .`). Include CLI examples if you changed commands or configs. +- Keep diffs scoped; split unrelated changes into separate PRs. Include screenshots or logs only when UI/output changes are relevant. + +## Security & Configuration Tips +- Do not commit dataset paths or secrets; use environment variables or local config files kept out of git. +- Validate file paths in new CLI options and prefer existing config loaders under `state/_cli` to avoid duplicating logic. From 03b8511d8e8a4a44f6cc58590f95f07bcf47bbe9 Mon Sep 17 00:00:00 2001 From: nick-youngblut Date: Fri, 19 Dec 2025 14:26:25 -0800 Subject: [PATCH 02/10] feat: enhance HVG handling in TX workflows - Added a `--verbose` flag to the inference CLI for detailed gene name mapping output. - Implemented HVG name retrieval and validation in preprocessing and prediction scripts. - Introduced constants for HVG variable names and updated relevant functions to utilize them. - Enhanced logging for HVG name availability and warnings for missing data. - Updated dataset class to default to the new HVG names key. 
--- src/state/_cli/_tx/_infer.py | 26 +++++++ src/state/_cli/_tx/_predict.py | 26 ++++++- src/state/_cli/_tx/_preprocess_infer.py | 12 +++ src/state/_cli/_tx/_preprocess_train.py | 13 +++- src/state/tx/constants.py | 4 + .../dataset/scgpt_perturbation_dataset.py | 3 +- src/state/tx/utils/hvg.py | 78 +++++++++++++++++++ 7 files changed, 155 insertions(+), 7 deletions(-) create mode 100644 src/state/tx/constants.py create mode 100644 src/state/tx/utils/hvg.py diff --git a/src/state/_cli/_tx/_infer.py b/src/state/_cli/_tx/_infer.py index a4f05d04..233fe6f1 100644 --- a/src/state/_cli/_tx/_infer.py +++ b/src/state/_cli/_tx/_infer.py @@ -76,6 +76,11 @@ def add_arguments_infer(parser: argparse.ArgumentParser): action="store_true", help="Reduce logging verbosity.", ) + parser.add_argument( + "--verbose", + action="store_true", + help="Show extra details about gene name mapping.", + ) parser.add_argument( "--tsv", type=str, @@ -96,6 +101,7 @@ def run_tx_infer(args: argparse.Namespace): from tqdm import tqdm from ...tx.models.state_transition import StateTransitionPerturbationModel + from ...tx.utils.hvg import get_hvg_var_names # ----------------------- # Helpers @@ -396,6 +402,25 @@ def pad_adata_with_tsv( # ----------------------- adata = sc.read_h5ad(args.adata) + hvg_names_status = "n/a" + if args.embed_key == "X_hvg": + hvg_names = get_hvg_var_names(adata, obsm_key="X_hvg") + if hvg_names is None and not args.quiet: + print( + "Warning: adata.uns['X_hvg_var_names'] not found. " + "Downstream analysis (e.g., pdex) may not be able to map predictions to gene names. " + "Consider re-running preprocess_train with the latest STATE version." 
+ ) + if hvg_names is not None: + hvg_names_status = "present" + else: + hvg_names_status = "missing" + if args.verbose and not args.quiet: + if hvg_names is not None: + print(f"HVG gene names found for X_hvg: {len(hvg_names)} entries.") + else: + print("HVG gene names not found for X_hvg.") + # optional TSV padding mode - pad with additional perturbation cells if args.tsv: if not args.quiet: @@ -669,3 +694,4 @@ def group_control_indices(group_name: str) -> np.ndarray: print(f"Treated simulated: {n_nonctl}") print(f"Wrote predictions to adata.{out_target}") print(f"Saved: {output_path}") + print(f"HVG names: {hvg_names_status}") diff --git a/src/state/_cli/_tx/_predict.py b/src/state/_cli/_tx/_predict.py index 6967978f..8711d227 100644 --- a/src/state/_cli/_tx/_predict.py +++ b/src/state/_cli/_tx/_predict.py @@ -65,6 +65,8 @@ def run_tx_predict(args: ap.ArgumentParser): import torch import yaml + from state.tx.constants import HVG_VAR_NAMES_KEY + # Cell-eval for metrics computation from cell_eval import MetricsEvaluator from cell_eval.utils import split_anndata_on_celltype @@ -325,6 +327,9 @@ def load_config(cfg_path: str) -> dict: obs = pd.DataFrame(df_dict) gene_names = var_dims["gene_names"] + hvg_uns_names = None + if data_module.embed_key == "X_hvg" or cfg["data"]["kwargs"]["output_space"] == "gene": + hvg_uns_names = gene_names var = pd.DataFrame({"gene_names": gene_names}) if final_X_hvg is not None: @@ -332,6 +337,7 @@ def load_config(cfg_path: str) -> dict: gene_names = np.load( "/large_storage/ctc/userspace/aadduri/datasets/tahoe_19k_to_2k_names.npy", allow_pickle=True ) + hvg_uns_names = gene_names var = pd.DataFrame({"gene_names": gene_names}) # Create adata for predictions - using the decoded gene expression values @@ -343,6 +349,11 @@ def load_config(cfg_path: str) -> dict: adata_pred.obsm[data_module.embed_key] = final_preds adata_real.obsm[data_module.embed_key] = final_reals logger.info(f"Added predicted embeddings to 
adata.obsm['{data_module.embed_key}']") + + if hvg_uns_names is not None: + hvg_uns_array = np.array(hvg_uns_names, dtype=object) + adata_pred.uns[HVG_VAR_NAMES_KEY] = hvg_uns_array + adata_real.uns[HVG_VAR_NAMES_KEY] = np.array(hvg_uns_names, dtype=object) else: # if len(gene_names) != final_preds.shape[1]: # gene_names = np.load( @@ -350,12 +361,19 @@ def load_config(cfg_path: str) -> dict: # ) # var = pd.DataFrame({"gene_names": gene_names}) + var = None + if len(gene_names) == final_preds.shape[1]: + var = pd.DataFrame({"gene_names": gene_names}) + # Create adata for predictions - model was trained on gene expression space already - # adata_pred = anndata.AnnData(X=final_preds, obs=obs, var=var) - adata_pred = anndata.AnnData(X=final_preds, obs=obs) + adata_pred = anndata.AnnData(X=final_preds, obs=obs, var=var) # Create adata for real - using the true gene expression values - # adata_real = anndata.AnnData(X=final_reals, obs=obs, var=var) - adata_real = anndata.AnnData(X=final_reals, obs=obs) + adata_real = anndata.AnnData(X=final_reals, obs=obs, var=var) + + if hvg_uns_names is not None: + hvg_uns_array = np.array(hvg_uns_names, dtype=object) + adata_pred.uns[HVG_VAR_NAMES_KEY] = hvg_uns_array + adata_real.uns[HVG_VAR_NAMES_KEY] = np.array(hvg_uns_names, dtype=object) # Optionally filter to perturbations seen in at least one training context if args.shared_only: diff --git a/src/state/_cli/_tx/_preprocess_infer.py b/src/state/_cli/_tx/_preprocess_infer.py index e9f5df38..e8c6d46f 100644 --- a/src/state/_cli/_tx/_preprocess_infer.py +++ b/src/state/_cli/_tx/_preprocess_infer.py @@ -79,11 +79,17 @@ def run_tx_preprocess_infer( import numpy as np # tqdm removed from the hot path; the main speed-up is vectorization, not progress bars. 
+ from state.tx.utils.hvg import get_hvg_var_names + logger = logging.getLogger(__name__) print(f"Loading AnnData from {adata_path}") adata = ad.read_h5ad(adata_path) + hvg_names = get_hvg_var_names(adata) + if hvg_names is not None: + logger.info("Found %d HVG names in adata.uns for X_hvg", len(hvg_names)) + # Set random seed for reproducibility rng = np.random.default_rng(seed) print(f"Set random seed to {seed}") @@ -94,6 +100,12 @@ def run_tx_preprocess_infer( if embed_key is not None and embed_key not in adata.obsm: raise KeyError(f"obsm key '{embed_key}' not found in adata.obsm") + if embed_key == "X_hvg" and hvg_names is None: + logger.warning( + "Warning: adata.uns['X_hvg_var_names'] not found. " + "Downstream analysis (e.g., pdex) may not be able to map predictions to gene names. " + "Consider re-running preprocess_train with the latest STATE version." + ) # Identify control cells print(f"Identifying control cells with condition: {control_condition!r}") diff --git a/src/state/_cli/_tx/_preprocess_train.py b/src/state/_cli/_tx/_preprocess_train.py index 8d9d9f8e..3911e5a6 100644 --- a/src/state/_cli/_tx/_preprocess_train.py +++ b/src/state/_cli/_tx/_preprocess_train.py @@ -26,6 +26,7 @@ def add_arguments_preprocess_train(parser: ap.ArgumentParser): def run_tx_preprocess_train(adata_path: str, output_path: str, num_hvgs: int): """ Preprocess training data by normalizing, log-transforming, and selecting highly variable genes. + Stores HVG names in .uns["X_hvg_var_names"] for downstream mapping. 
Args: adata_path: Path to input AnnData file @@ -35,8 +36,11 @@ def run_tx_preprocess_train(adata_path: str, output_path: str, num_hvgs: int): import logging import anndata as ad + import numpy as np import scanpy as sc + from state.tx.constants import HVG_OBSM_KEY, HVG_VAR_NAMES_KEY + logger = logging.getLogger(__name__) logger.info(f"Loading AnnData from {adata_path}") @@ -51,8 +55,13 @@ def run_tx_preprocess_train(adata_path: str, output_path: str, num_hvgs: int): logger.info(f"Finding top {num_hvgs} highly variable genes") sc.pp.highly_variable_genes(adata, n_top_genes=num_hvgs) - logger.info("Storing highly variable genes in .obsm['X_hvg']") - adata.obsm["X_hvg"] = adata[:, adata.var.highly_variable].X.toarray() + logger.info(f"Storing highly variable genes in .obsm['{HVG_OBSM_KEY}']") + adata.obsm[HVG_OBSM_KEY] = adata[:, adata.var.highly_variable].X.toarray() + + # Store HVG names alongside X_hvg for downstream gene mapping. + hvg_gene_names = adata.var_names[adata.var.highly_variable].tolist() + adata.uns[HVG_VAR_NAMES_KEY] = np.array(hvg_gene_names, dtype=object) + logger.info(f"Stored {len(hvg_gene_names)} HVG names in adata.uns['{HVG_VAR_NAMES_KEY}']") logger.info(f"Saving preprocessed data to {output_path}") adata.write_h5ad(output_path) diff --git a/src/state/tx/constants.py b/src/state/tx/constants.py new file mode 100644 index 00000000..7a47d146 --- /dev/null +++ b/src/state/tx/constants.py @@ -0,0 +1,4 @@ +"""Shared constants for TX workflows and storage conventions.""" + +HVG_VAR_NAMES_KEY = "X_hvg_var_names" +HVG_OBSM_KEY = "X_hvg" diff --git a/src/state/tx/data/dataset/scgpt_perturbation_dataset.py b/src/state/tx/data/dataset/scgpt_perturbation_dataset.py index 96a47fd2..2424aac6 100644 --- a/src/state/tx/data/dataset/scgpt_perturbation_dataset.py +++ b/src/state/tx/data/dataset/scgpt_perturbation_dataset.py @@ -53,7 +53,7 @@ def __init__( should_yield_control_cells: bool = True, store_raw_basal: bool = False, vocab: Optional[Dict[str, int]] = 
None, - hvg_names_uns_key: Optional[str] = None, + hvg_names_uns_key: Optional[str] = "X_hvg_var_names", perturbation_type: Literal["chemical", "genetic"] = "chemical", **kwargs, ): @@ -73,6 +73,7 @@ def __init__( random_state: Random seed for reproducibility pert_tracker: PerturbationTracker instance for tracking valid perturbations should_yield_control_cells: If True, control cells will be included in the dataset + hvg_names_uns_key: Optional uns key holding HVG names (default: "X_hvg_var_names") """ super().__init__( name=name, diff --git a/src/state/tx/utils/hvg.py b/src/state/tx/utils/hvg.py new file mode 100644 index 00000000..0c62a4e0 --- /dev/null +++ b/src/state/tx/utils/hvg.py @@ -0,0 +1,78 @@ +"""Helpers for retrieving and validating HVG gene names.""" + +from __future__ import annotations + +import logging +from anndata import AnnData + +from state.tx.constants import HVG_VAR_NAMES_KEY + + +logger = logging.getLogger(__name__) + + +def get_hvg_var_names(adata: AnnData, obsm_key: str = "X_hvg") -> list[str] | None: + """Return HVG gene names for an embedding. + + Args: + adata: AnnData to inspect. + obsm_key: Embedding key to resolve gene names for. + + Returns: + List of gene names if available, otherwise None. + """ + version = detect_preprocessing_version(adata, obsm_key=obsm_key) + if version in {"legacy_uns", "var_only"}: + logger.warning( + "Detected legacy HVG metadata for %s. 
Consider re-running preprocess_train with the latest STATE version.", + obsm_key, + ) + + derived_key = f"{obsm_key}_var_names" + if derived_key in adata.uns: + logger.info("Using HVG var names from adata.uns['%s']", derived_key) + return list(adata.uns[derived_key]) + + if HVG_VAR_NAMES_KEY in adata.uns: + logger.info("Using HVG var names from adata.uns['%s']", HVG_VAR_NAMES_KEY) + return list(adata.uns[HVG_VAR_NAMES_KEY]) + + if "highly_variable" in adata.var: + logger.info("Using HVG var names from adata.var['highly_variable']") + return adata.var_names[adata.var["highly_variable"]].tolist() + + logger.info("No HVG var names available for adata.obsm['%s']", obsm_key) + return None + + +def detect_preprocessing_version(adata: AnnData, obsm_key: str = "X_hvg") -> str: + """Detect the preprocessing metadata format based on uns/var keys. + + Args: + adata: AnnData to inspect. + obsm_key: Embedding key to resolve gene names for. + + Returns: + One of: "current", "legacy_uns", "var_only", or "unknown". + """ + derived_key = f"{obsm_key}_var_names" + if derived_key in adata.uns: + return "current" + if HVG_VAR_NAMES_KEY in adata.uns: + return "legacy_uns" + if "highly_variable" in adata.var: + return "var_only" + return "unknown" + + +def validate_hvg_var_names(adata: AnnData, obsm_key: str = "X_hvg") -> bool: + """Validate whether HVG gene names can be resolved for an embedding. + + Args: + adata: AnnData to inspect. + obsm_key: Embedding key to validate gene names for. + + Returns: + True when gene names are available, otherwise False. + """ + return get_hvg_var_names(adata, obsm_key=obsm_key) is not None From 916c07cce9b3239d1772c1da061d1950db291391 Mon Sep 17 00:00:00 2001 From: nick-youngblut Date: Fri, 19 Dec 2025 14:36:02 -0800 Subject: [PATCH 03/10] feat: store and access HVG gene names in AnnData - Added storage of highly variable gene (HVG) names in `adata.uns["X_hvg_var_names"]` for improved downstream mapping. 
- Updated documentation to reflect changes in HVG gene name access and backward compatibility. - Introduced tests to validate HVG name retrieval and ensure compatibility with existing workflows. - Enhanced inference and preprocessing scripts to preserve HVG names during data processing. --- README.md | 15 +++ docs/migration/hvg_var_names.md | 64 +++++++++++++ tests/test_hvg_utils.py | 38 ++++++++ tests/test_inference_pipeline.py | 139 ++++++++++++++++++++++++++++ tests/test_predict_hvg_names.py | 153 +++++++++++++++++++++++++++++++ tests/test_preprocess_infer.py | 29 ++++++ tests/test_preprocess_train.py | 40 ++++++++ 7 files changed, 478 insertions(+) create mode 100644 docs/migration/hvg_var_names.md create mode 100644 tests/test_hvg_utils.py create mode 100644 tests/test_inference_pipeline.py create mode 100644 tests/test_predict_hvg_names.py create mode 100644 tests/test_preprocess_infer.py create mode 100644 tests/test_preprocess_train.py diff --git a/README.md b/README.md index 85a5528b..92609689 100644 --- a/README.md +++ b/README.md @@ -129,6 +129,21 @@ This command: - Applies log1p transformation (`sc.pp.log1p`) - Identifies highly variable genes (`sc.pp.highly_variable_genes`) - Stores the HVG expression matrix in `.obsm['X_hvg']` +- Stores HVG gene names in `.uns['X_hvg_var_names']` for downstream mapping + +#### Accessing HVG Gene Names + +The HVG gene names associated with `adata.obsm["X_hvg"]` are stored in `adata.uns["X_hvg_var_names"]`. 
+This makes it easy to construct downstream AnnData objects for tools like `pdex`: + +```python +hvg_names = adata.uns.get("X_hvg_var_names") +adata_for_pdex = ad.AnnData( + X=adata.obsm["X_hvg"], + obs=adata.obs, + var=pd.DataFrame(index=hvg_names), +) +``` #### Inference Data Preprocessing diff --git a/docs/migration/hvg_var_names.md b/docs/migration/hvg_var_names.md new file mode 100644 index 00000000..29dcd690 --- /dev/null +++ b/docs/migration/hvg_var_names.md @@ -0,0 +1,64 @@ +# Migration: HVG Gene Names Stored in AnnData Uns + +## Summary + +Recent versions of STATE store highly variable gene (HVG) names in `adata.uns["X_hvg_var_names"]`. +This makes it possible for downstream tools to map `adata.obsm["X_hvg"]` columns back to gene IDs. + +If you have preprocessed data created before this change, you can backfill the HVG names with the +script below. + +## Backward Compatibility + +This change is fully backward compatible: + +- **Existing preprocessed data**: Inference commands continue to work without modification. A + non-blocking warning is emitted recommending re-preprocessing, but execution proceeds normally. +- **Existing trained models**: Model checkpoints do not depend on this uns key. Gene names are + already captured in `var_dims.pkl` at training time. +- **Downstream code**: Code unaware of `X_hvg_var_names` simply ignores it. The obsm matrix + structure is unchanged. + +### Fallback Behavior + +When `X_hvg_var_names` is absent, STATE attempts to recover gene names from +`adata.var_names[adata.var.highly_variable]`. This fallback succeeds as long as the +`highly_variable` boolean column remains in `adata.var`. + +### When Gene Names Are Unrecoverable + +Gene names cannot be recovered if an h5ad file has `X_hvg` in obsm but: + +1. No `X_hvg_var_names` in uns, AND +2. No `highly_variable` column in var (e.g., var was subset or modified) + +This edge case would already be broken prior to this change. 
The new feature makes the mapping +explicit rather than implicit. + +## Backfill Script + +For existing preprocessed files, run the following to add `X_hvg_var_names`: +```python +import anndata as ad +import numpy as np + +adata = ad.read_h5ad("your_preprocessed_data.h5ad") + +if "X_hvg" in adata.obsm and "X_hvg_var_names" not in adata.uns: + if "highly_variable" in adata.var.columns: + hvg_names = adata.var_names[adata.var.highly_variable].tolist() + adata.uns["X_hvg_var_names"] = np.array(hvg_names, dtype=object) + adata.write_h5ad("your_preprocessed_data.h5ad") + print(f"Added {len(hvg_names)} HVG names to uns") + else: + print("Cannot backfill: 'highly_variable' column not found in adata.var") +else: + print("Backfill not needed or X_hvg not present") +``` + +## Notes + +- The uns key is stored as a NumPy array of Python strings for h5ad compatibility. +- Re-running `state tx preprocess_train` with the latest version will populate this automatically. +- The naming convention `{obsm_key}_var_names` allows for multiple obsm matrices with associated + gene names (e.g., `X_pca_var_names` if needed in the future). 
\ No newline at end of file diff --git a/tests/test_hvg_utils.py b/tests/test_hvg_utils.py new file mode 100644 index 00000000..30714589 --- /dev/null +++ b/tests/test_hvg_utils.py @@ -0,0 +1,38 @@ +import numpy as np +import anndata as ad + +from state.tx.utils.hvg import get_hvg_var_names, validate_hvg_var_names + + +def test_get_hvg_var_names_prefers_obsm_key(): + adata = ad.AnnData(X=np.ones((2, 2))) + adata.var_names = ["g1", "g2"] + adata.uns["X_custom_var_names"] = np.array(["a", "b"], dtype=object) + + names = get_hvg_var_names(adata, obsm_key="X_custom") + assert names == ["a", "b"] + + +def test_get_hvg_var_names_falls_back_to_highly_variable(): + adata = ad.AnnData(X=np.ones((3, 3))) + adata.var_names = ["g1", "g2", "g3"] + adata.var["highly_variable"] = [True, False, True] + + names = get_hvg_var_names(adata) + assert names == ["g1", "g3"] + + +def test_get_hvg_var_names_returns_none_when_missing(): + adata = ad.AnnData(X=np.ones((2, 2))) + adata.var_names = ["g1", "g2"] + + assert get_hvg_var_names(adata) is None + + +def test_validate_hvg_var_names(): + adata = ad.AnnData(X=np.ones((2, 2))) + adata.var_names = ["g1", "g2"] + assert validate_hvg_var_names(adata) is False + + adata.uns["X_hvg_var_names"] = np.array(["g1", "g2"], dtype=object) + assert validate_hvg_var_names(adata) is True diff --git a/tests/test_inference_pipeline.py b/tests/test_inference_pipeline.py new file mode 100644 index 00000000..01adbee9 --- /dev/null +++ b/tests/test_inference_pipeline.py @@ -0,0 +1,139 @@ +import argparse +import pickle +from pathlib import Path + +import anndata as ad +import numpy as np +import torch +import yaml + +from state._cli._tx._infer import run_tx_infer +from state._cli._tx._preprocess_train import run_tx_preprocess_train +from state.tx.constants import HVG_VAR_NAMES_KEY + + +class DummyModel(torch.nn.Module): + def __init__(self, output_dim: int): + super().__init__() + self.weight = torch.nn.Parameter(torch.zeros(1)) + self.batch_encoder = None + 
self.cell_sentence_len = 2 + self.output_space = "gene" + self._output_dim = output_dim + + def predict_step(self, batch, batch_idx=0, padded=False): + ctrl = batch["ctrl_cell_emb"] + return {"preds": ctrl.clone()} + + +def _write_model_assets(model_dir: Path, output_dim: int, pert_dim: int = 2): + config = { + "data": { + "kwargs": { + "control_pert": "ctrl", + "output_space": "gene", + "cell_type_key": "cell_type", + } + }, + "model": {"kwargs": {}}, + } + with open(model_dir / "config.yaml", "w") as f: + yaml.safe_dump(config, f) + + var_dims = {"pert_dim": pert_dim, "batch_dim": None, "output_dim": output_dim} + with open(model_dir / "var_dims.pkl", "wb") as f: + pickle.dump(var_dims, f) + + pert_onehot_map = { + "ctrl": torch.tensor([1.0, 0.0]), + "pert": torch.tensor([0.0, 1.0]), + } + torch.save(pert_onehot_map, model_dir / "pert_onehot_map.pt") + + +def _make_args(model_dir: Path, adata_path: Path, output_path: Path, quiet: bool, verbose: bool): + return argparse.Namespace( + checkpoint=str(model_dir / "checkpoints" / "final.ckpt"), + adata=str(adata_path), + embed_key="X_hvg", + pert_col="pert", + output=str(output_path), + model_dir=str(model_dir), + celltype_col="cell_type", + celltypes=None, + batch_col=None, + control_pert="ctrl", + seed=42, + max_set_len=2, + quiet=quiet, + tsv=None, + verbose=verbose, + ) + + +def test_infer_preserves_hvg_names(monkeypatch, tmp_path): + model_dir = tmp_path / "model" + model_dir.mkdir() + (model_dir / "checkpoints").mkdir() + (model_dir / "checkpoints" / "final.ckpt").touch() + + X = np.array([[1.0, 2.0, 3.0], [3.0, 4.0, 5.0], [5.0, 6.0, 7.0]]) + raw = ad.AnnData(X=X) + raw.obs["pert"] = ["ctrl", "pert", "pert"] + raw.obs["cell_type"] = ["A", "A", "B"] + raw.var_names = ["g1", "g2", "g3"] + + raw_path = tmp_path / "raw.h5ad" + preprocessed_path = tmp_path / "preprocessed.h5ad" + output_path = tmp_path / "output.h5ad" + raw.write_h5ad(raw_path) + + run_tx_preprocess_train(str(raw_path), str(preprocessed_path), 
num_hvgs=2) + + _write_model_assets(model_dir, output_dim=2) + + dummy = DummyModel(output_dim=2) + monkeypatch.setattr( + "state.tx.models.state_transition.StateTransitionPerturbationModel.load_from_checkpoint", + lambda *args, **kwargs: dummy, + ) + + args = _make_args(model_dir, preprocessed_path, output_path, quiet=True, verbose=False) + run_tx_infer(args) + + out = ad.read_h5ad(output_path) + assert HVG_VAR_NAMES_KEY in out.uns + assert len(out.uns[HVG_VAR_NAMES_KEY]) == out.obsm["X_hvg"].shape[1] + + +def test_infer_warns_when_hvg_missing(monkeypatch, tmp_path, capsys): + model_dir = tmp_path / "model" + model_dir.mkdir() + (model_dir / "checkpoints").mkdir() + (model_dir / "checkpoints" / "final.ckpt").touch() + + X = np.array([[1.0, 2.0], [3.0, 4.0]]) + adata = ad.AnnData(X=X) + adata.obs["pert"] = ["ctrl", "pert"] + adata.obs["cell_type"] = ["A", "A"] + adata.obsm["X_hvg"] = X.copy() + + adata_path = tmp_path / "input_missing.h5ad" + output_path = tmp_path / "output_missing.h5ad" + adata.write_h5ad(adata_path) + + _write_model_assets(model_dir, output_dim=X.shape[1]) + + dummy = DummyModel(output_dim=X.shape[1]) + monkeypatch.setattr( + "state.tx.models.state_transition.StateTransitionPerturbationModel.load_from_checkpoint", + lambda *args, **kwargs: dummy, + ) + + args = _make_args(model_dir, adata_path, output_path, quiet=False, verbose=True) + run_tx_infer(args) + + captured = capsys.readouterr() + combined = captured.out + captured.err + assert "Warning: adata.uns['X_hvg_var_names'] not found" in combined + assert "HVG names:" in combined diff --git a/tests/test_predict_hvg_names.py b/tests/test_predict_hvg_names.py new file mode 100644 index 00000000..9ab8baf4 --- /dev/null +++ b/tests/test_predict_hvg_names.py @@ -0,0 +1,153 @@ +import types +from pathlib import Path + +import numpy as np +import anndata as ad +import torch +import yaml + +from state._cli._tx._predict import run_tx_predict +from state.tx.constants import HVG_VAR_NAMES_KEY + + +class 
DummyBatchSampler: + def __init__(self, tot_num: int): + self.tot_num = tot_num + + +class DummyLoader: + def __init__(self, batch): + self.batch_sampler = DummyBatchSampler(batch["pert_cell_emb"].shape[0]) + self._batch = batch + + def __iter__(self): + yield self._batch + + +class DummyDataModule: + def __init__(self, gene_names): + self.embed_key = "X_hvg" + self.pert_col = "pert" + self.cell_type_key = "cell_type" + self.batch_col = "batch" + self._gene_names = gene_names + self.batch_size = 1 + + def setup(self, stage="test"): + return None + + def get_var_dims(self): + return { + "input_dim": 2, + "gene_dim": 2, + "hvg_dim": 2, + "output_dim": 2, + "pert_dim": 2, + "gene_names": self._gene_names, + } + + def test_dataloader(self): + batch = { + "pert_cell_emb": torch.tensor([[1.0, 2.0], [3.0, 4.0]]), + "ctrl_cell_emb": torch.tensor([[1.0, 2.0], [3.0, 4.0]]), + } + return DummyLoader(batch) + + def train_dataloader(self, test=False): + return self.test_dataloader() + + def get_control_pert(self): + return "ctrl" + + def get_shared_perturbations(self): + return [] + + +class DummyModel(torch.nn.Module): + def __init__(self): + super().__init__() + self.weight = torch.nn.Parameter(torch.zeros(1)) + + def predict_step(self, batch, batch_idx, padded=False): + preds = batch["pert_cell_emb"].clone() + return { + "pert_name": ["ctrl", "pert"], + "celltype_name": ["A", "A"], + "batch": torch.tensor([0, 0]), + "preds": preds, + "pert_cell_emb": batch["pert_cell_emb"], + } + + +def _install_dummy_modules(monkeypatch, data_module): + cell_eval = types.ModuleType("cell_eval") + cell_eval.MetricsEvaluator = object + cell_eval_utils = types.ModuleType("cell_eval.utils") + cell_eval_utils.split_anndata_on_celltype = lambda adata, celltype_col: {"all": adata} + monkeypatch.setitem(__import__("sys").modules, "cell_eval", cell_eval) + monkeypatch.setitem(__import__("sys").modules, "cell_eval.utils", cell_eval_utils) + + cell_load = types.ModuleType("cell_load") + cell_load_data 
= types.ModuleType("cell_load.data_modules") + + class DummyPerturbationDataModule: + @staticmethod + def load_state(_path): + return data_module + + cell_load_data.PerturbationDataModule = DummyPerturbationDataModule + monkeypatch.setitem(__import__("sys").modules, "cell_load", cell_load) + monkeypatch.setitem(__import__("sys").modules, "cell_load.data_modules", cell_load_data) + + +def test_predict_outputs_hvg_names(monkeypatch, tmp_path): + output_dir = tmp_path / "run" + output_dir.mkdir() + + cfg = { + "output_dir": str(tmp_path), + "name": "run", + "training": {"train_seed": 0}, + "model": {"name": "state", "kwargs": {"hidden_dim": 4}}, + "data": {"kwargs": {"output_space": "gene"}}, + } + config_path = tmp_path / "config.yaml" + with open(config_path, "w") as f: + yaml.safe_dump(cfg, f) + + run_output_dir = tmp_path / "run" + run_output_dir.mkdir(exist_ok=True) + (run_output_dir / "data_module.torch").touch() + checkpoints_dir = run_output_dir / "checkpoints" + checkpoints_dir.mkdir() + (checkpoints_dir / "last.ckpt").touch() + + gene_names = ["g1", "g2"] + data_module = DummyDataModule(gene_names) + _install_dummy_modules(monkeypatch, data_module) + + monkeypatch.setattr( + "state.tx.models.state_transition.StateTransitionPerturbationModel.load_from_checkpoint", + lambda *args, **kwargs: DummyModel(), + ) + + args = types.SimpleNamespace( + output_dir=str(tmp_path), + checkpoint="last.ckpt", + test_time_finetune=0, + profile="anndata", + predict_only=True, + shared_only=False, + eval_train_data=False, + ) + + run_tx_predict(args) + + results_dir = Path(tmp_path) / "eval_last.ckpt" + adata_pred = ad.read_h5ad(results_dir / "adata_pred.h5ad") + adata_real = ad.read_h5ad(results_dir / "adata_real.h5ad") + + assert HVG_VAR_NAMES_KEY in adata_pred.uns + assert HVG_VAR_NAMES_KEY in adata_real.uns + assert adata_pred.uns[HVG_VAR_NAMES_KEY].tolist() == gene_names + assert adata_real.uns[HVG_VAR_NAMES_KEY].tolist() == gene_names diff --git 
a/tests/test_preprocess_infer.py b/tests/test_preprocess_infer.py new file mode 100644 index 00000000..e7ad9279 --- /dev/null +++ b/tests/test_preprocess_infer.py @@ -0,0 +1,29 @@ +import numpy as np +import anndata as ad + +from state._cli._tx._preprocess_infer import run_tx_preprocess_infer +from state.tx.constants import HVG_VAR_NAMES_KEY + + +def test_preprocess_infer_preserves_uns(tmp_path): + X = np.array([[1.0, 2.0], [3.0, 4.0]]) + adata = ad.AnnData(X=X) + adata.obs["pert"] = ["ctrl", "pert"] + adata.uns[HVG_VAR_NAMES_KEY] = np.array(["g1", "g2"], dtype=object) + + input_path = tmp_path / "input.h5ad" + output_path = tmp_path / "output.h5ad" + adata.write_h5ad(input_path) + + run_tx_preprocess_infer( + adata_path=str(input_path), + output_path=str(output_path), + control_condition="ctrl", + pert_col="pert", + seed=0, + embed_key=None, + ) + + processed = ad.read_h5ad(output_path) + assert HVG_VAR_NAMES_KEY in processed.uns + assert processed.uns[HVG_VAR_NAMES_KEY].tolist() == ["g1", "g2"] diff --git a/tests/test_preprocess_train.py b/tests/test_preprocess_train.py new file mode 100644 index 00000000..2d4c0830 --- /dev/null +++ b/tests/test_preprocess_train.py @@ -0,0 +1,40 @@ +import numpy as np +import anndata as ad + +from state._cli._tx._preprocess_train import run_tx_preprocess_train +from state.tx.constants import HVG_OBSM_KEY, HVG_VAR_NAMES_KEY + + +def test_preprocess_train_stores_hvg_names(tmp_path): + X = np.array( + [ + [1, 0, 3, 4, 0, 2], + [2, 1, 0, 1, 0, 3], + [0, 2, 1, 0, 4, 1], + [3, 0, 2, 2, 1, 0], + [1, 1, 1, 1, 1, 1], + ], + dtype=float, + ) + var_names = [f"gene_{i}" for i in range(X.shape[1])] + adata = ad.AnnData(X=X) + adata.var_names = var_names + + input_path = tmp_path / "input.h5ad" + output_path = tmp_path / "output.h5ad" + adata.write_h5ad(input_path) + + run_tx_preprocess_train(str(input_path), str(output_path), num_hvgs=3) + + processed = ad.read_h5ad(output_path) + assert HVG_OBSM_KEY in processed.obsm + assert 
HVG_VAR_NAMES_KEY in processed.uns + + hvg_names = processed.uns[HVG_VAR_NAMES_KEY] + assert isinstance(hvg_names, np.ndarray) + assert hvg_names.dtype == object + assert len(hvg_names) == processed.obsm[HVG_OBSM_KEY].shape[1] + assert all(isinstance(name, str) for name in hvg_names) + + expected_names = processed.var_names[processed.var.highly_variable].tolist() + assert set(hvg_names.tolist()) == set(expected_names) From 4750d7ab760781494575aef999f9ee37d7b3147d Mon Sep 17 00:00:00 2001 From: nick-youngblut Date: Fri, 19 Dec 2025 14:58:18 -0800 Subject: [PATCH 04/10] refactor: streamline HVG name handling in prediction script - Moved HVG name assignment to a single conditional block for clarity and consistency. - Removed redundant code for HVG name storage in `adata.uns` to enhance maintainability. - Cleaned up test file by removing unused numpy import. --- src/state/_cli/_tx/_predict.py | 13 ++++--------- tests/test_predict_hvg_names.py | 1 - 2 files changed, 4 insertions(+), 10 deletions(-) diff --git a/src/state/_cli/_tx/_predict.py b/src/state/_cli/_tx/_predict.py index 8711d227..f8809f0b 100644 --- a/src/state/_cli/_tx/_predict.py +++ b/src/state/_cli/_tx/_predict.py @@ -349,11 +349,6 @@ def load_config(cfg_path: str) -> dict: adata_pred.obsm[data_module.embed_key] = final_preds adata_real.obsm[data_module.embed_key] = final_reals logger.info(f"Added predicted embeddings to adata.obsm['{data_module.embed_key}']") - - if hvg_uns_names is not None: - hvg_uns_array = np.array(hvg_uns_names, dtype=object) - adata_pred.uns[HVG_VAR_NAMES_KEY] = hvg_uns_array - adata_real.uns[HVG_VAR_NAMES_KEY] = np.array(hvg_uns_names, dtype=object) else: # if len(gene_names) != final_preds.shape[1]: # gene_names = np.load( @@ -370,10 +365,10 @@ def load_config(cfg_path: str) -> dict: # Create adata for real - using the true gene expression values adata_real = anndata.AnnData(X=final_reals, obs=obs, var=var) - if hvg_uns_names is not None: - hvg_uns_array = np.array(hvg_uns_names, 
dtype=object) - adata_pred.uns[HVG_VAR_NAMES_KEY] = hvg_uns_array - adata_real.uns[HVG_VAR_NAMES_KEY] = np.array(hvg_uns_names, dtype=object) + if hvg_uns_names is not None: + hvg_uns_array = np.array(hvg_uns_names, dtype=object) + adata_pred.uns[HVG_VAR_NAMES_KEY] = hvg_uns_array + adata_real.uns[HVG_VAR_NAMES_KEY] = hvg_uns_array # Optionally filter to perturbations seen in at least one training context if args.shared_only: diff --git a/tests/test_predict_hvg_names.py b/tests/test_predict_hvg_names.py index 9ab8baf4..5e58b056 100644 --- a/tests/test_predict_hvg_names.py +++ b/tests/test_predict_hvg_names.py @@ -1,7 +1,6 @@ import types from pathlib import Path -import numpy as np import anndata as ad import torch import yaml From 7274ef960a90a5222cf668a9ead1f42dd992105d Mon Sep 17 00:00:00 2001 From: nick-youngblut Date: Tue, 6 Jan 2026 10:02:38 -0800 Subject: [PATCH 05/10] feat: enhance HVG name handling in inference and prediction scripts - Added a `--verbose` flag to the inference CLI for detailed output on gene name mapping. - Implemented checks and logging for the presence of highly variable gene (HVG) names during inference. - Updated prediction script to store HVG names in `adata.uns` for improved data consistency. - Refactored code to streamline HVG name handling and ensure compatibility with existing workflows. 
--- src/state/_cli/_tx/_infer.py | 26 ++++++++++++++++++++++++++ src/state/_cli/_tx/_predict.py | 25 +++++++++++++++++++++---- 2 files changed, 47 insertions(+), 4 deletions(-) diff --git a/src/state/_cli/_tx/_infer.py b/src/state/_cli/_tx/_infer.py index c7c15b60..621a8713 100644 --- a/src/state/_cli/_tx/_infer.py +++ b/src/state/_cli/_tx/_infer.py @@ -76,6 +76,11 @@ def add_arguments_infer(parser: argparse.ArgumentParser): action="store_true", help="Reduce logging verbosity.", ) + parser.add_argument( + "--verbose", + action="store_true", + help="Show extra details about gene name mapping.", + ) parser.add_argument( "--tsv", type=str, @@ -119,6 +124,7 @@ def run_tx_infer(args: argparse.Namespace): from tqdm import tqdm from ...tx.models.state_transition import StateTransitionPerturbationModel + from ...tx.utils.hvg import get_hvg_var_names # ----------------------- # Helpers @@ -422,6 +428,25 @@ def pad_adata_with_tsv( # ----------------------- adata = sc.read_h5ad(args.adata) + hvg_names_status = "n/a" + if args.embed_key == "X_hvg": + hvg_names = get_hvg_var_names(adata, obsm_key="X_hvg") + if hvg_names is None and not args.quiet: + print( + "Warning: adata.uns['X_hvg_var_names'] not found. " + "Downstream analysis (e.g., pdex) may not be able to map predictions to gene names. " + "Consider re-running preprocess_train with the latest STATE version." 
+ ) + if hvg_names is not None: + hvg_names_status = "present" + else: + hvg_names_status = "missing" + if args.verbose and not args.quiet: + if hvg_names is not None: + print(f"HVG gene names found for X_hvg: {len(hvg_names)} entries.") + else: + print("HVG gene names not found for X_hvg.") + # optional TSV padding mode - pad with additional perturbation cells if args.tsv: if not args.quiet: @@ -927,3 +952,4 @@ def group_control_indices(group_name: str) -> np.ndarray: print(f"Saved: {output_path}") if counts_written and counts_out_target: print(f"Saved count predictions to adata.{counts_out_target}") + print(f"HVG names: {hvg_names_status}") diff --git a/src/state/_cli/_tx/_predict.py b/src/state/_cli/_tx/_predict.py index c9d953da..646982dc 100644 --- a/src/state/_cli/_tx/_predict.py +++ b/src/state/_cli/_tx/_predict.py @@ -72,6 +72,8 @@ def run_tx_predict(args: ap.ArgumentParser): import torch import yaml + from state.tx.constants import HVG_VAR_NAMES_KEY + # Cell-eval for metrics computation from cell_eval import MetricsEvaluator from cell_eval.utils import split_anndata_on_celltype @@ -428,6 +430,9 @@ def normalize_batch_labels(values): obs = pd.DataFrame(df_dict) gene_names = var_dims["gene_names"] + hvg_uns_names = None + if data_module.embed_key == "X_hvg" or cfg["data"]["kwargs"]["output_space"] == "gene": + hvg_uns_names = gene_names var = pd.DataFrame({"gene_names": gene_names}) if final_X_hvg is not None: @@ -446,6 +451,11 @@ def normalize_batch_labels(values): adata_pred.obsm[data_module.embed_key] = final_preds adata_real.obsm[data_module.embed_key] = final_reals logger.info(f"Added predicted embeddings to adata.obsm['{data_module.embed_key}']") + + if hvg_uns_names is not None: + hvg_uns_array = np.array(hvg_uns_names, dtype=object) + adata_pred.uns[HVG_VAR_NAMES_KEY] = hvg_uns_array + adata_real.uns[HVG_VAR_NAMES_KEY] = np.array(hvg_uns_names, dtype=object) else: # if len(gene_names) != final_preds.shape[1]: # gene_names = np.load( @@ -453,12 
+463,19 @@ def normalize_batch_labels(values): # ) # var = pd.DataFrame({"gene_names": gene_names}) + var = None + if len(gene_names) == final_preds.shape[1]: + var = pd.DataFrame({"gene_names": gene_names}) + # Create adata for predictions - model was trained on gene expression space already - # adata_pred = anndata.AnnData(X=final_preds, obs=obs, var=var) - adata_pred = anndata.AnnData(X=final_preds, obs=obs) + adata_pred = anndata.AnnData(X=final_preds, obs=obs, var=var) # Create adata for real - using the true gene expression values - # adata_real = anndata.AnnData(X=final_reals, obs=obs, var=var) - adata_real = anndata.AnnData(X=final_reals, obs=obs) + adata_real = anndata.AnnData(X=final_reals, obs=obs, var=var) + + if hvg_uns_names is not None: + hvg_uns_array = np.array(hvg_uns_names, dtype=object) + adata_pred.uns[HVG_VAR_NAMES_KEY] = hvg_uns_array + adata_real.uns[HVG_VAR_NAMES_KEY] = np.array(hvg_uns_names, dtype=object) # Clip extreme values to keep cell-eval log1p checks happy. clip_anndata_values(adata_pred, max_value=14.0) From 4746be4c3aa59875ad62dc3285903c6af3d23776 Mon Sep 17 00:00:00 2001 From: nick-youngblut Date: Tue, 6 Jan 2026 10:13:47 -0800 Subject: [PATCH 06/10] feat: store HVG names in AnnData if available during inference - Added logic to store highly variable gene (HVG) names in `adata.uns` using the defined constant `HVG_VAR_NAMES_KEY`. - Initialized `hvg_names` variable to handle cases where HVG names may not be present. - Enhanced the inference process to ensure HVG names are preserved for downstream analysis. 
--- src/state/_cli/_tx/_infer.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/src/state/_cli/_tx/_infer.py b/src/state/_cli/_tx/_infer.py index 621a8713..8043457b 100644 --- a/src/state/_cli/_tx/_infer.py +++ b/src/state/_cli/_tx/_infer.py @@ -124,6 +124,7 @@ def run_tx_infer(args: argparse.Namespace): from tqdm import tqdm from ...tx.models.state_transition import StateTransitionPerturbationModel + from ...tx.constants import HVG_VAR_NAMES_KEY from ...tx.utils.hvg import get_hvg_var_names # ----------------------- @@ -428,6 +429,7 @@ def pad_adata_with_tsv( # ----------------------- adata = sc.read_h5ad(args.adata) + hvg_names = None hvg_names_status = "n/a" if args.embed_key == "X_hvg": hvg_names = get_hvg_var_names(adata, obsm_key="X_hvg") @@ -929,6 +931,10 @@ def group_control_indices(group_name: str) -> np.ndarray: elif output_space == "all": adata.X = sim_counts + # Store HVG names if available + if hvg_names is not None: + adata.uns[HVG_VAR_NAMES_KEY] = np.array(hvg_names, dtype=object) + if output_is_npy: if pred_matrix is None: raise ValueError("Predictions matrix is unavailable; cannot write .npy output") From 0a8bd53f6500bfde21a9d0ccc3b302b292e2e09f Mon Sep 17 00:00:00 2001 From: nick-youngblut Date: Tue, 6 Jan 2026 10:14:18 -0800 Subject: [PATCH 07/10] chore: update .gitignore to include tasks directory - Added 'tasks/' to .gitignore to prevent tracking of task-related files. - Ensured that temporary files related to tasks are excluded from version control. 
--- .gitignore | 1 + 1 file changed, 1 insertion(+) diff --git a/.gitignore b/.gitignore index 922c724a..165f8f38 100644 --- a/.gitignore +++ b/.gitignore @@ -12,3 +12,4 @@ notebooks/ *.slurm temp wandb/ +tasks/ \ No newline at end of file From 537114745fec46743b5f894bc376e2cff3433979 Mon Sep 17 00:00:00 2001 From: nick-youngblut Date: Tue, 6 Jan 2026 10:20:37 -0800 Subject: [PATCH 08/10] fix: ensure HVG names are only stored when counts match predictions - Updated conditions for storing highly variable gene (HVG) names in `adata.uns` to check that the length of `hvg_uns_names` matches the shape of prediction arrays. - This change prevents potential mismatches and ensures data integrity during the prediction process. --- src/state/_cli/_tx/_predict.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/state/_cli/_tx/_predict.py b/src/state/_cli/_tx/_predict.py index 646982dc..bec6d4c4 100644 --- a/src/state/_cli/_tx/_predict.py +++ b/src/state/_cli/_tx/_predict.py @@ -452,7 +452,7 @@ def normalize_batch_labels(values): adata_real.obsm[data_module.embed_key] = final_reals logger.info(f"Added predicted embeddings to adata.obsm['{data_module.embed_key}']") - if hvg_uns_names is not None: + if hvg_uns_names is not None and len(hvg_uns_names) == final_pert_cell_counts_preds.shape[1]: hvg_uns_array = np.array(hvg_uns_names, dtype=object) adata_pred.uns[HVG_VAR_NAMES_KEY] = hvg_uns_array adata_real.uns[HVG_VAR_NAMES_KEY] = np.array(hvg_uns_names, dtype=object) @@ -472,7 +472,7 @@ def normalize_batch_labels(values): # Create adata for real - using the true gene expression values adata_real = anndata.AnnData(X=final_reals, obs=obs, var=var) - if hvg_uns_names is not None: + if hvg_uns_names is not None and len(hvg_uns_names) == final_preds.shape[1]: hvg_uns_array = np.array(hvg_uns_names, dtype=object) adata_pred.uns[HVG_VAR_NAMES_KEY] = hvg_uns_array adata_real.uns[HVG_VAR_NAMES_KEY] = np.array(hvg_uns_names, dtype=object) From 
57cb76eac89159f0f7818c3ce25f7ff6fa3ba134 Mon Sep 17 00:00:00 2001 From: nick-youngblut Date: Tue, 6 Jan 2026 10:25:59 -0800 Subject: [PATCH 09/10] refactor: update argument types in prediction functions - Changed the argument type of `run_tx_predict` from `ap.ArgumentParser` to `ap.Namespace` for better clarity and functionality. - Added additional parameters in the `_make_args` function to enhance flexibility in test cases. - Updated test cases to include a new `toml` parameter for improved configuration handling. --- README.md | 8 ++++++++ src/state/_cli/_tx/_predict.py | 2 +- tests/test_inference_pipeline.py | 4 ++++ tests/test_predict_hvg_names.py | 1 + 4 files changed, 14 insertions(+), 1 deletion(-) diff --git a/README.md b/README.md index 92609689..7192c01f 100644 --- a/README.md +++ b/README.md @@ -130,10 +130,18 @@ This command: - Identifies highly variable genes (`sc.pp.highly_variable_genes`) - Stores the HVG expression matrix in `.obsm['X_hvg']` - Stores HVG gene names in `.uns['X_hvg_var_names']` for downstream mapping +- **Inference/Prediction**: `tx infer` and `tx predict` commands preserve HVG gene names in output AnnData when using `--embed-key X_hvg` #### Accessing HVG Gene Names The HVG gene names associated with `adata.obsm["X_hvg"]` are stored in `adata.uns["X_hvg_var_names"]`. +This applies to: +- **Training data**: Output from `preprocess_train` +- **Inference results**: Output from `tx infer` with `--embed-key X_hvg` +- **Prediction results**: Output from `tx predict` (both prediction and ground truth AnnData) + +**Note**: If HVG gene names are missing during inference, a warning will be displayed recommending re-running `preprocess_train` with the latest STATE version. 
+ This makes it easy to construct downstream AnnData objects for tools like `pdex`: ```python diff --git a/src/state/_cli/_tx/_predict.py b/src/state/_cli/_tx/_predict.py index bec6d4c4..34e8f7e9 100644 --- a/src/state/_cli/_tx/_predict.py +++ b/src/state/_cli/_tx/_predict.py @@ -59,7 +59,7 @@ def add_arguments_predict(parser: ap.ArgumentParser): ) -def run_tx_predict(args: ap.ArgumentParser): +def run_tx_predict(args: ap.Namespace): import logging import os import sys diff --git a/tests/test_inference_pipeline.py b/tests/test_inference_pipeline.py index 01adbee9..591c7228 100644 --- a/tests/test_inference_pipeline.py +++ b/tests/test_inference_pipeline.py @@ -68,6 +68,10 @@ def _make_args(model_dir: Path, adata_path: Path, output_path: Path, quiet: bool quiet=quiet, tsv=None, verbose=verbose, + all_perts=False, + virtual_cells_per_pert=None, + min_cells=None, + max_cells=None, ) diff --git a/tests/test_predict_hvg_names.py b/tests/test_predict_hvg_names.py index 5e58b056..e02cf944 100644 --- a/tests/test_predict_hvg_names.py +++ b/tests/test_predict_hvg_names.py @@ -138,6 +138,7 @@ def test_predict_outputs_hvg_names(monkeypatch, tmp_path): predict_only=True, shared_only=False, eval_train_data=False, + toml=None, ) run_tx_predict(args) From 98d205db2eb9fdb69e02ea822552de969bb2deaf Mon Sep 17 00:00:00 2001 From: nick-youngblut Date: Tue, 6 Jan 2026 10:31:17 -0800 Subject: [PATCH 10/10] docs: update README with command syntax and configuration details - Revised command examples for model inference and embedding transformation to reflect updated argument names and paths. - Enhanced clarity in data splitting logic and configuration validation sections, including required and optional parameters. - Added new sections for preprocessing datasets and evaluating embedding models, providing users with comprehensive guidance on usage. 
--- README.md | 117 +++++++++++++++++++++++++++++++++++++++--------------- 1 file changed, 85 insertions(+), 32 deletions(-) diff --git a/README.md b/README.md index 7192c01f..c7afd92a 100644 --- a/README.md +++ b/README.md @@ -104,7 +104,13 @@ in the TOML file: ```bash -state tx infer --output $HOME/state/test/ --output_dir /path/to/model/ --checkpoint /path/to/model/final.ckpt --adata /path/to/anndata/processed.h5 --pert_col gene --embed_key X_hvg +state tx infer \ + --model-dir /path/to/model/ \ + --checkpoint /path/to/model/final.ckpt \ + --adata /path/to/anndata/processed.h5ad \ + --pert-col gene \ + --embed-key X_hvg \ + --output /path/to/output/simulated.h5ad ``` Here, `/path/to/model/` is the folder downloaded from [HuggingFace](https://huggingface.co/arcinstitute). @@ -161,8 +167,8 @@ Use `preprocess_infer` to create a "control template" for model inference: state tx preprocess_infer \ --adata /path/to/real_data.h5ad \ --output /path/to/control_template.h5ad \ - --control_condition "DMSO" \ - --pert_col "treatment" \ + --control-condition "DMSO" \ + --pert-col "treatment" \ --seed 42 ``` @@ -269,19 +275,26 @@ test = ["MYC", "TP53"] ### Important Notes -- **Automatic training assignment**: Any cell type not mentioned in `[zeroshot]` automatically participates in training, with perturbations not listed in `[fewshot]` going to the training set -- **Overlapping splits**: Perturbations can appear in both validation and test sets within fewshot configurations -- **Dataset naming**: Use the format `"dataset_name.cell_type"` when specifying cell types in zeroshot and fewshot sections -- **Path requirements**: Dataset paths should point to directories containing h5ad files -- **Control perturbations**: Ensure your control condition (specified via `control_pert` parameter) is available across all splits +#### Data Splitting Logic +- **Automatic training assignment**: Cell types not mentioned in `[zeroshot]` participate in training. 
Within each cell type, perturbations not listed in `[fewshot]` go to the training set. +- **Overlapping splits**: Perturbations can appear in both validation and test sets within fewshot configurations. +- **Dataset naming convention**: Use `"dataset_name.cell_type"` format when specifying cell types in zeroshot and fewshot sections. +- **File requirements**: Dataset paths should point to directories containing `.h5ad` files. +- **Control perturbations**: Ensure your control condition (specified via `control_pert` parameter) exists across all splits. -### Validation +#### Required vs Optional Sections +- **`[datasets]`**: Required - defines dataset paths +- **`[training]`**: Required - specifies which datasets participate in training +- **`[zeroshot]`**: Optional - reserves entire cell types for evaluation +- **`[fewshot]`**: Optional - specifies perturbation-level splits within cell types -The configuration system will validate that: -- All referenced datasets exist at the specified paths -- Cell types mentioned in zeroshot/fewshot sections exist in the datasets -- Perturbations listed in fewshot sections are present in the corresponding cell types -- No conflicts exist between zeroshot and fewshot assignments for the same cell type +### Configuration Validation + +The system validates configurations and will raise errors if: +- Referenced datasets don't exist at specified paths +- Cell types mentioned in `[zeroshot]` or `[fewshot]` don't exist in the datasets +- Perturbations listed in `[fewshot]` are not present in the corresponding cell types +- There are conflicts between zeroshot and fewshot assignments for the same cell type ## State Embedding Model (SE) @@ -296,16 +309,55 @@ To run inference with a trained State checkpoint, e.g., the State trained to 16 ```bash state emb transform \ - --model-folder /large_storage/ctc/userspace/aadduri/SE-600M \ - --checkpoint /large_storage/ctc/userspace/aadduri/SE-600M/se600m_epoch15.ckpt \ - --input 
/large_storage/ctc/datasets/replogle/rpe1_raw_singlecell_01.h5ad \ - --output /home/aadduri/vci_pretrain/test_output.h5ad + --checkpoint /path/to/model.ckpt \ + --input /path/to/input_data.h5ad \ + --output /path/to/output_embeddings.h5ad \ + --embed-key X_state +``` + +Or using a model folder (automatically finds latest checkpoint): + +```bash +state emb transform \ + --model-folder /path/to/model_directory \ + --input /path/to/input_data.h5ad \ + --output /path/to/output_embeddings.h5ad ``` Notes on the h5ad file format: - CSR matrix format is required - `gene_name` is required in the `var` dataframe +#### Preprocess datasets for embedding training + +Create embedding profiles and prepare datasets for training: + +```bash +state emb preprocess \ + --profile-name my_profile \ + --train-csv train_datasets.csv \ + --val-csv val_datasets.csv \ + --output-dir embeddings/ \ + --num-threads 4 +``` + +This creates embedding profiles and prepares CSV files with dataset mappings for training. + +#### Evaluate embedding models + +Evaluate trained embedding models on differential expression prediction: + +```bash +state emb eval \ + --checkpoint /path/to/model.ckpt \ + --adata /path/to/test_data.h5ad \ + --pert-col gene \ + --control-pert non-targeting \ + --gene-column gene_name +``` + +This computes gene overlap metrics, ROC curves, and precision-recall curves for evaluating embedding model performance. 
+ ### Vector Database Install the optional dependencies: @@ -324,34 +376,35 @@ uv sync --extra vectordb ```bash state emb transform \ - --model-folder /large_storage/ctc/userspace/aadduri/SE-600M \ - --input /large_storage/ctc/public/scBasecamp/GeneFull_Ex50pAS/GeneFull_Ex50pAS/Homo_sapiens/SRX27532045.h5ad \ - --lancedb tmp/state_embeddings.lancedb \ - --gene-column gene_symbols + --checkpoint /path/to/model.ckpt \ + --input /path/to/dataset.h5ad \ + --lancedb /path/to/vector_database.lancedb \ + --gene-column gene_name ``` -Running this command multiple times with the same lancedb appends the new data to the provided database. +Running this command multiple times with the same `--lancedb` path appends the new data to the existing database. #### Query the database -Obtain the embeddings: +First, obtain embeddings for your query cells: ```bash state emb transform \ - --model-folder /large_storage/ctc/userspace/aadduri/SE-600M \ - --input /large_storage/ctc/public/scBasecamp/GeneFull_Ex50pAS/GeneFull_Ex50pAS/Homo_sapiens/SRX27532046.h5ad \ - --output tmp/SRX27532046.h5ad \ - --gene-column gene_symbols + --checkpoint /path/to/model.ckpt \ + --input /path/to/query_cells.h5ad \ + --output /path/to/query_embeddings.h5ad \ + --gene-column gene_name ``` -Query the database with the embeddings: +Then query the database for similar cells: ```bash state emb query \ - --lancedb tmp/state_embeddings.lancedb \ - --input tmp/SRX27532046.h5ad \ - --output tmp/similar_cells.csv \ + --lancedb /path/to/vector_database.lancedb \ + --input /path/to/query_embeddings.h5ad \ + --output /path/to/similar_cells.csv \ --k 3 +``` # Singularity