1 change: 1 addition & 0 deletions .gitignore
@@ -12,3 +12,4 @@ notebooks/
*.slurm
temp
wandb/
tasks/
33 changes: 33 additions & 0 deletions AGENTS.md
@@ -0,0 +1,33 @@
# Repository Guidelines

## Project Structure & Module Organization
- Core library lives in `src/state`, with CLI entrypoints in `state/__main__.py` and subcommands under `state/_cli`.
- Model configs and resources sit in `src/state/configs`, embeddings helpers in `src/state/emb`, and transition utilities in `src/state/tx`.
- Tests reside in `tests/` (`test_*.py`), runnable without extra fixtures. Example TOML configs are in `examples/`. Helper scripts live in `scripts/` for inference and embedding.
- Artifacts or scratch outputs should go in `tmp/` or a user-created path; keep `assets/` for checked-in visuals/resources only.

## Build, Test, and Development Commands
- Create/activate env and install in editable mode: `uv tool install -e .`.
- Run the CLI: `uv run state --help` (entrypoints `emb` and `tx`).
- Format/lint: `uv run ruff check .` (auto-fixes enabled by default config).
- Run tests: `uv run pytest` (adds `src/` to `PYTHONPATH` via standard layout).

## Coding Style & Naming Conventions
- Python 3.10–3.12; prefer type hints on public functions.
- Use 4-space indentation and a 120-char max line length (`ruff.toml`). Bare `except` should still be avoided even though E722 is explicitly ignored; use it only when genuinely necessary.
- Modules and files use `snake_case`; classes `CamelCase`; constants `UPPER_SNAKE_CASE`.
- Keep CLI options descriptive and align new configs with the existing TOML examples.

## Testing Guidelines
- Add unit tests alongside new features in `tests/` with filenames `test_*.py` and functions `test_*`.
- Cover edge cases around data loading, config parsing, and checkpoint handling; favor small fixtures over large data blobs.
- For regressions, reproduce with a failing test first, then implement the fix.

## Commit & Pull Request Guidelines
- Follow the short, imperative style seen in history (`chore: …`, `patch: …`, or focused message without trailing punctuation). Reference issue/PR numbers where applicable.
- PRs should explain the change, risks, and testing done (`uv run pytest`, `uv run ruff check .`). Include CLI examples if you changed commands or configs.
- Keep diffs scoped; split unrelated changes into separate PRs. Include screenshots or logs only when UI/output changes are relevant.

## Security & Configuration Tips
- Do not commit dataset paths or secrets; use environment variables or local config files kept out of git.
- Validate file paths in new CLI options and prefer existing config loaders under `state/_cli` to avoid duplicating logic.
140 changes: 108 additions & 32 deletions README.md
@@ -104,7 +104,13 @@ in the TOML file:


```bash
state tx infer --output $HOME/state/test/ --output_dir /path/to/model/ --checkpoint /path/to/model/final.ckpt --adata /path/to/anndata/processed.h5 --pert_col gene --embed_key X_hvg
state tx infer \
--model-dir /path/to/model/ \
--checkpoint /path/to/model/final.ckpt \
--adata /path/to/anndata/processed.h5ad \
--pert-col gene \
--embed-key X_hvg \
--output /path/to/output/simulated.h5ad
```

Here, `/path/to/model/` is the folder downloaded from [HuggingFace](https://huggingface.co/arcinstitute).
@@ -129,6 +135,29 @@ This command:
- Applies log1p transformation (`sc.pp.log1p`)
- Identifies highly variable genes (`sc.pp.highly_variable_genes`)
- Stores the HVG expression matrix in `.obsm['X_hvg']`
- Stores HVG gene names in `.uns['X_hvg_var_names']` for downstream mapping
- **Inference/Prediction**: `tx infer` and `tx predict` commands preserve HVG gene names in output AnnData when using `--embed-key X_hvg`

#### Accessing HVG Gene Names

The HVG gene names associated with `adata.obsm["X_hvg"]` are stored in `adata.uns["X_hvg_var_names"]`.
This applies to:
- **Training data**: Output from `preprocess_train`
- **Inference results**: Output from `tx infer` with `--embed-key X_hvg`
- **Prediction results**: Output from `tx predict` (both prediction and ground truth AnnData)

**Note**: If HVG gene names are missing during inference, a warning will be displayed recommending re-running `preprocess_train` with the latest STATE version.

This makes it easy to construct downstream AnnData objects for tools like `pdex`:

```python
import anndata as ad
import pandas as pd

hvg_names = adata.uns.get("X_hvg_var_names")
adata_for_pdex = ad.AnnData(
    X=adata.obsm["X_hvg"],
    obs=adata.obs,
    var=pd.DataFrame(index=hvg_names),
)
```

#### Inference Data Preprocessing

@@ -138,8 +167,8 @@ Use `preprocess_infer` to create a "control template" for model inference:
state tx preprocess_infer \
--adata /path/to/real_data.h5ad \
--output /path/to/control_template.h5ad \
--control_condition "DMSO" \
--pert_col "treatment" \
--control-condition "DMSO" \
--pert-col "treatment" \
--seed 42
```

@@ -246,19 +275,26 @@ test = ["MYC", "TP53"]

### Important Notes

- **Automatic training assignment**: Any cell type not mentioned in `[zeroshot]` automatically participates in training, with perturbations not listed in `[fewshot]` going to the training set
- **Overlapping splits**: Perturbations can appear in both validation and test sets within fewshot configurations
- **Dataset naming**: Use the format `"dataset_name.cell_type"` when specifying cell types in zeroshot and fewshot sections
- **Path requirements**: Dataset paths should point to directories containing h5ad files
- **Control perturbations**: Ensure your control condition (specified via `control_pert` parameter) is available across all splits
#### Data Splitting Logic
- **Automatic training assignment**: Cell types not mentioned in `[zeroshot]` participate in training. Within each cell type, perturbations not listed in `[fewshot]` go to the training set.
- **Overlapping splits**: Perturbations can appear in both validation and test sets within fewshot configurations.
- **Dataset naming convention**: Use `"dataset_name.cell_type"` format when specifying cell types in zeroshot and fewshot sections.
- **File requirements**: Dataset paths should point to directories containing `.h5ad` files.
- **Control perturbations**: Ensure your control condition (specified via `control_pert` parameter) exists across all splits.

#### Required vs Optional Sections
- **`[datasets]`**: Required - defines dataset paths
- **`[training]`**: Required - specifies which datasets participate in training
- **`[zeroshot]`**: Optional - reserves entire cell types for evaluation
- **`[fewshot]`**: Optional - specifies perturbation-level splits within cell types

### Validation
### Configuration Validation

The configuration system will validate that:
- All referenced datasets exist at the specified paths
- Cell types mentioned in zeroshot/fewshot sections exist in the datasets
- Perturbations listed in fewshot sections are present in the corresponding cell types
- No conflicts exist between zeroshot and fewshot assignments for the same cell type
The system validates configurations and will raise errors if:
- Referenced datasets don't exist at specified paths
- Cell types mentioned in `[zeroshot]` or `[fewshot]` don't exist in the datasets
- Perturbations listed in `[fewshot]` are not present in the corresponding cell types
- There are conflicts between zeroshot and fewshot assignments for the same cell type
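
As a rough sketch, these checks might look like the following in plain Python (simplified and hypothetical — `validate_config` is not STATE's actual implementation, and the perturbation-level check is omitted):

```python
import os

def validate_config(datasets, zeroshot, fewshot, available_cell_types):
    """Collect validation errors mirroring the rules listed above."""
    errors = []
    # Rule 1: every referenced dataset directory must exist.
    for name, path in datasets.items():
        if not os.path.isdir(path):
            errors.append(f"dataset '{name}' not found at {path}")
    # Rule 2: zeroshot/fewshot keys must name known "dataset.cell_type" entries.
    for key in list(zeroshot) + list(fewshot):
        if key not in available_cell_types:
            errors.append(f"unknown cell type '{key}'")
    # Rule 4: the same cell type cannot be assigned in both sections.
    for key in set(zeroshot) & set(fewshot):
        errors.append(f"'{key}' appears in both zeroshot and fewshot")
    return errors
```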


## State Embedding Model (SE)
@@ -273,16 +309,55 @@ To run inference with a trained State checkpoint, e.g., the State trained to 16

```bash
state emb transform \
--model-folder /large_storage/ctc/userspace/aadduri/SE-600M \
--checkpoint /large_storage/ctc/userspace/aadduri/SE-600M/se600m_epoch15.ckpt \
--input /large_storage/ctc/datasets/replogle/rpe1_raw_singlecell_01.h5ad \
--output /home/aadduri/vci_pretrain/test_output.h5ad
--checkpoint /path/to/model.ckpt \
--input /path/to/input_data.h5ad \
--output /path/to/output_embeddings.h5ad \
--embed-key X_state
```

Or using a model folder (automatically finds latest checkpoint):

```bash
state emb transform \
--model-folder /path/to/model_directory \
--input /path/to/input_data.h5ad \
--output /path/to/output_embeddings.h5ad
```

Notes on the h5ad file format:
- CSR matrix format is required
- `gene_name` is required in the `var` dataframe

#### Preprocess datasets for embedding training

Create embedding profiles and prepare datasets for training:

```bash
state emb preprocess \
--profile-name my_profile \
--train-csv train_datasets.csv \
--val-csv val_datasets.csv \
--output-dir embeddings/ \
--num-threads 4
```

This creates embedding profiles and prepares CSV files with dataset mappings for training.

#### Evaluate embedding models

Evaluate trained embedding models on differential expression prediction:

```bash
state emb eval \
--checkpoint /path/to/model.ckpt \
--adata /path/to/test_data.h5ad \
--pert-col gene \
--control-pert non-targeting \
--gene-column gene_name
```

This computes gene overlap metrics, ROC curves, and precision-recall curves for evaluating embedding model performance.
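
For intuition, here is a simplified stand-in for one such metric — top-k gene overlap between predicted and observed differential expression. This is purely illustrative, not the eval command's actual implementation:

```python
def topk_overlap(pred_scores, true_scores, k):
    """Fraction of the top-k true DE genes recovered among the top-k predicted genes."""
    # Rank genes by score (e.g., absolute effect size) in each dict.
    top_pred = set(sorted(pred_scores, key=pred_scores.get, reverse=True)[:k])
    top_true = set(sorted(true_scores, key=true_scores.get, reverse=True)[:k])
    return len(top_pred & top_true) / k
```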

### Vector Database

Install the optional dependencies:
@@ -301,34 +376,35 @@ uv sync --extra vectordb

```bash
state emb transform \
--model-folder /large_storage/ctc/userspace/aadduri/SE-600M \
--input /large_storage/ctc/public/scBasecamp/GeneFull_Ex50pAS/GeneFull_Ex50pAS/Homo_sapiens/SRX27532045.h5ad \
--lancedb tmp/state_embeddings.lancedb \
--gene-column gene_symbols
--checkpoint /path/to/model.ckpt \
--input /path/to/dataset.h5ad \
--lancedb /path/to/vector_database.lancedb \
--gene-column gene_name
```

Running this command multiple times with the same lancedb appends the new data to the provided database.
Running this command multiple times with the same `--lancedb` path appends the new data to the existing database.

#### Query the database

Obtain the embeddings:
First, obtain embeddings for your query cells:

```bash
state emb transform \
--model-folder /large_storage/ctc/userspace/aadduri/SE-600M \
--input /large_storage/ctc/public/scBasecamp/GeneFull_Ex50pAS/GeneFull_Ex50pAS/Homo_sapiens/SRX27532046.h5ad \
--output tmp/SRX27532046.h5ad \
--gene-column gene_symbols
--checkpoint /path/to/model.ckpt \
--input /path/to/query_cells.h5ad \
--output /path/to/query_embeddings.h5ad \
--gene-column gene_name
```

Query the database with the embeddings:
Then query the database for similar cells:

```bash
state emb query \
--lancedb tmp/state_embeddings.lancedb \
--input tmp/SRX27532046.h5ad \
--output tmp/similar_cells.csv \
--lancedb /path/to/vector_database.lancedb \
--input /path/to/query_embeddings.h5ad \
--output /path/to/similar_cells.csv \
--k 3
```

# Singularity

64 changes: 64 additions & 0 deletions docs/migration/hvg_var_names.md
@@ -0,0 +1,64 @@
# Migration: HVG Gene Names Stored in AnnData Uns

## Summary

Recent versions of STATE store highly variable gene (HVG) names in `adata.uns["X_hvg_var_names"]`.
This makes it possible for downstream tools to map `adata.obsm["X_hvg"]` columns back to gene IDs.

If you have preprocessed data created before this change, you can backfill the HVG names with the
script below.

## Backward Compatibility

This change is fully backward compatible:

- **Existing preprocessed data**: Inference commands continue to work without modification. A
non-blocking warning is emitted recommending re-preprocessing, but execution proceeds normally.
- **Existing trained models**: Model checkpoints do not depend on this uns key. Gene names are
already captured in `var_dims.pkl` at training time.
- **Downstream code**: Code unaware of `X_hvg_var_names` simply ignores it. The obsm matrix
structure is unchanged.

### Fallback Behavior

When `X_hvg_var_names` is absent, STATE attempts to recover gene names from
`adata.var_names[adata.var.highly_variable]`. This fallback succeeds as long as the
`highly_variable` boolean column remains in `adata.var`.
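
A minimal sketch of that fallback order (illustrative only; the real logic lives in STATE's `get_hvg_var_names` helper, and this function operates on plain Python containers rather than an AnnData object):

```python
def recover_hvg_names(uns, var_names, highly_variable):
    """Prefer the explicit uns key; otherwise fall back to the boolean HVG mask."""
    names = uns.get("X_hvg_var_names")
    if names is not None:
        return list(names)
    if highly_variable is not None:
        # Keep only genes flagged as highly variable, preserving var order.
        return [name for name, hv in zip(var_names, highly_variable) if hv]
    return None  # unrecoverable (see the next section)
```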

### When Gene Names Are Unrecoverable

Gene names cannot be recovered if an h5ad file has `X_hvg` in obsm but:

1. No `X_hvg_var_names` in uns, AND
2. No `highly_variable` column in var (e.g., var was subset or modified)

This edge case was already broken prior to this change; the new feature simply makes the mapping
explicit rather than implicit.

## Backfill Script

For existing preprocessed files, run the following to add `X_hvg_var_names`:
```python
import anndata as ad
import numpy as np

adata = ad.read_h5ad("your_preprocessed_data.h5ad")

if "X_hvg" in adata.obsm and "X_hvg_var_names" not in adata.uns:
if "highly_variable" in adata.var.columns:
hvg_names = adata.var_names[adata.var.highly_variable].tolist()
adata.uns["X_hvg_var_names"] = np.array(hvg_names, dtype=object)
adata.write_h5ad("your_preprocessed_data.h5ad")
print(f"Added {len(hvg_names)} HVG names to uns")
else:
print("Cannot backfill: 'highly_variable' column not found in adata.var")
else:
print("Backfill not needed or X_hvg not present")
```

## Notes

- The uns key is stored as a NumPy array of Python strings for h5ad compatibility.
- Re-running `state tx preprocess_train` with the latest version will populate this automatically.
- The naming convention `{obsm_key}_var_names` allows for multiple obsm matrices with associated
gene names (e.g., `X_pca_var_names` if needed in the future).
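
The convention can be captured in a one-line helper (hypothetical; shown only to illustrate the key naming, not part of STATE):

```python
def var_names_key(obsm_key):
    """Map an obsm key to the uns key holding its column (gene) names."""
    return f"{obsm_key}_var_names"

# e.g. var_names_key("X_hvg") yields "X_hvg_var_names"
```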
32 changes: 32 additions & 0 deletions src/state/_cli/_tx/_infer.py
@@ -76,6 +76,11 @@ def add_arguments_infer(parser: argparse.ArgumentParser):
        action="store_true",
        help="Reduce logging verbosity.",
    )
    parser.add_argument(
        "--verbose",
        action="store_true",
        help="Show extra details about gene name mapping.",
    )
    parser.add_argument(
        "--tsv",
        type=str,
@@ -119,6 +124,8 @@ def run_tx_infer(args: argparse.Namespace):
    from tqdm import tqdm

    from ...tx.models.state_transition import StateTransitionPerturbationModel
    from ...tx.constants import HVG_VAR_NAMES_KEY
    from ...tx.utils.hvg import get_hvg_var_names

    # -----------------------
    # Helpers
@@ -422,6 +429,26 @@ def pad_adata_with_tsv(
    # -----------------------
    adata = sc.read_h5ad(args.adata)

    hvg_names = None
    hvg_names_status = "n/a"
    if args.embed_key == "X_hvg":
        hvg_names = get_hvg_var_names(adata, obsm_key="X_hvg")
        if hvg_names is None and not args.quiet:
            print(
                "Warning: adata.uns['X_hvg_var_names'] not found. "
                "Downstream analysis (e.g., pdex) may not be able to map predictions to gene names. "
                "Consider re-running preprocess_train with the latest STATE version."
            )
        if hvg_names is not None:
            hvg_names_status = "present"
        else:
            hvg_names_status = "missing"
        if args.verbose and not args.quiet:
            if hvg_names is not None:
                print(f"HVG gene names found for X_hvg: {len(hvg_names)} entries.")
            else:
                print("HVG gene names not found for X_hvg.")

    # optional TSV padding mode - pad with additional perturbation cells
    if args.tsv:
        if not args.quiet:
@@ -904,6 +931,10 @@ def group_control_indices(group_name: str) -> np.ndarray:

    elif output_space == "all":
        adata.X = sim_counts

    # Store HVG names if available
    if hvg_names is not None:
        adata.uns[HVG_VAR_NAMES_KEY] = np.array(hvg_names, dtype=object)
> **Review comment (medium severity): missing dimension validation when storing HVG names in infer.**
>
> When there's a dimension mismatch between the input `obsm["X_hvg"]` and the model output (lines 840-847), `sim_counts` is reinitialized with the model's output dimension. However, at line 936, `hvg_names` (retrieved earlier from the input, with a potentially different length) is stored unconditionally, without validating dimensions. This differs from `_predict.py`, which validates `len(hvg_uns_names) == final_preds.shape[1]` before storing. The result could be an output file where the length of `uns["X_hvg_var_names"]` doesn't match the number of `obsm["X_hvg"]` columns, causing incorrect gene mappings for downstream tools.



    if output_is_npy:
        if pred_matrix is None:
            raise ValueError("Predictions matrix is unavailable; cannot write .npy output")
@@ -927,3 +958,4 @@ def group_control_indices(group_name: str) -> np.ndarray:
        print(f"Saved: {output_path}")
    if counts_written and counts_out_target:
        print(f"Saved count predictions to adata.{counts_out_target}")
    print(f"HVG names: {hvg_names_status}")