4 changes: 2 additions & 2 deletions pyproject.toml
@@ -1,6 +1,6 @@
 [project]
 name = "arc-state"
-version = "0.9.9"
+version = "0.9.10"
 description = "State is a machine learning model that predicts cellular perturbation response across diverse contexts."
 readme = "README.md"
 authors = [
@@ -13,7 +13,7 @@ authors = [
 requires-python = ">=3.10,<3.13"
 dependencies = [
     "anndata>=0.11.4",
-    "cell-load>=0.7.5",
+    "cell-load>=0.7.6",
     "numpy>=2.2.6",
     "pandas>=2.2.3",
     "pyyaml>=6.0.2",
216 changes: 154 additions & 62 deletions src/state/_cli/_tx/_infer.py
@@ -13,6 +13,9 @@ def add_arguments_infer(parser: argparse.ArgumentParser):
     parser.add_argument(
         "--pert_col", type=str, default="drugname_drugconc", help="Column in adata.obs for perturbation labels"
     )
+    parser.add_argument(
+        "--batch_col", type=str, default="batch_var", help="Column in adata.obs for batch labels"
+    )
     parser.add_argument("--output", type=str, default=None, help="Path to output AnnData file (.h5ad)")
     parser.add_argument(
         "--model_dir",
@@ -27,6 +30,7 @@
         "--celltypes", type=str, default=None, help="Comma-separated list of cell types to include (optional)"
     )
     parser.add_argument("--batch_size", type=int, default=1000, help="Batch size for inference (default: 1000)")
+    parser.add_argument("--seed", type=int, default=None, help="Random seed for reproducible cell shuffling within perturbation groups")
 
 
 def run_tx_infer(args):
@@ -38,13 +42,19 @@ def run_tx_infer(args):
     import scanpy as sc
     import torch
     import yaml
+    from lightning import seed_everything
     from tqdm import tqdm
 
     from ...tx.models.pert_sets import PertSetsPerturbationModel
 
     logging.basicConfig(level=logging.INFO)
     logger = logging.getLogger(__name__)
 
+    # Set random seed if provided
+    if args.seed is not None:
+        seed_everything(args.seed)
+        logger.info(f"Set random seed to {args.seed}")
+
     def load_config(cfg_path: str) -> dict:
         """Load config from the YAML file that was dumped during training."""
         if not os.path.exists(cfg_path):
@@ -73,7 +83,7 @@ def load_config(cfg_path: str) -> dict:
 
     # Load model
     logger.info(f"Loading model from checkpoint: {checkpoint_path}")
-    model = PertSetsPerturbationModel.load_from_checkpoint(checkpoint_path)
+    model = PertSetsPerturbationModel.load_from_checkpoint(checkpoint_path, strict=False, **cfg['model']['kwargs'])
     model.eval()
     cell_sentence_len = model.cell_sentence_len
     device = next(model.parameters()).device
@@ -122,6 +132,44 @@ def load_config(cfg_path: str) -> dict:
     logger.info(f"AnnData has {len(unique_pert_names)} unique perturbations")
     logger.info(f"First 10 perturbations in AnnData: {unique_pert_names[:10]}")
 
+    # Load batch mapping from pickle file
+    batch_onehot_map_path = os.path.join(args.model_dir, "batch_onehot_map.pkl")
+    batch_onehot_map = pickle.load(open(batch_onehot_map_path, 'rb'))
+
+    # Create batch name to index mapping
+    batch_names = adata.obs[args.batch_col].values
+    batch_name_to_idx = {name: idx for idx, name in enumerate(sorted(batch_onehot_map.keys()))}
+
+    # Prepare batch indices tensor
+    batch_indices_tensor = torch.zeros(len(batch_names), dtype=torch.long, device="cpu")
+
+    logger.info(f"Data module has {len(batch_onehot_map)} batches in mapping")
+    logger.info(f"First 10 batches in data module: {list(batch_onehot_map.keys())[:10]}")
+
+    unique_batch_names = sorted(set(batch_names))
+    logger.info(f"AnnData has {len(unique_batch_names)} unique batches")
+    logger.info(f"First 10 batches in AnnData: {unique_batch_names[:10]}")
+
+    # Check overlap for batches
+    batch_overlap = set(unique_batch_names) & set(batch_onehot_map.keys())
+    logger.info(f"Overlap between AnnData and data module batches: {len(batch_overlap)} batches")
+    if len(batch_overlap) < len(unique_batch_names):
+        missing_batches = set(unique_batch_names) - set(batch_onehot_map.keys())
+        logger.warning(f"Missing batches: {list(missing_batches)[:10]}")
+
+    # Fill batch indices
+    batch_matched_count = 0
+    default_batch_idx = 0  # Use first batch as default
+    for idx, name in enumerate(batch_names):
+        if name in batch_name_to_idx:
+            batch_indices_tensor[idx] = batch_name_to_idx[name]
+            batch_matched_count += 1
+        else:
+            # Use first available batch as fallback
+            batch_indices_tensor[idx] = default_batch_idx
+
+    logger.info(f"Matched {batch_matched_count} out of {len(batch_names)} batches")
+
     # Check overlap
     overlap = set(unique_pert_names) & set(pert_onehot_map.keys())
     logger.info(f"Overlap between AnnData and data module: {len(overlap)} perturbations")
@@ -151,80 +199,124 @@ def load_config(cfg_path: str) -> dict:
 
     logger.info(f"Matched {matched_count} out of {len(pert_names)} perturbations")
 
-    # Process in batches with progress bar
-    # Use cell_sentence_len as batch size since model expects this
+    # Group cells by perturbation to ensure batches contain cells from same perturbation
+    import pandas as pd
+
+    # Create a DataFrame for easier grouping
+    df = pd.DataFrame({"pert_name": pert_names, "index": range(len(pert_names))})
+    grouped = df.groupby("pert_name")
 
     n_samples = len(pert_names)
-    batch_size = cell_sentence_len  # Model requires this exact batch size
-    n_batches = (n_samples + batch_size - 1) // batch_size  # Ceiling division
+
+    batch_size = args.batch_size  # Use user-specified batch size
 
     logger.info(
-        f"Running inference on {n_samples} samples in {n_batches} batches of size {batch_size} (model's cell_sentence_len)..."
+        f"Running inference on {n_samples} samples grouped by perturbation with batch size {batch_size}..."
     )
 
     all_preds = []
+    processed_samples = 0
+    batch_idx = 0
 
     with torch.no_grad():
         progress_bar = tqdm(total=n_samples, desc="Processing samples", unit="samples")
 
-        for batch_idx in range(n_batches):
-            start_idx = batch_idx * batch_size
-            end_idx = min(start_idx + batch_size, n_samples)
-            current_batch_size = end_idx - start_idx
-
-            # Get batch data
-            X_batch = torch.tensor(X[start_idx:end_idx], dtype=torch.float32).to(device)
-            pert_batch = pert_tensor[start_idx:end_idx].to(device)
-            pert_names_batch = pert_names[start_idx:end_idx].tolist()
-
-            # Pad the batch to cell_sentence_len if it's the last incomplete batch
-            if current_batch_size < cell_sentence_len:
-                # Pad with zeros for embeddings
-                padding_size = cell_sentence_len - current_batch_size
-                X_pad = torch.zeros((padding_size, X_batch.shape[1]), device=device)
-                X_batch = torch.cat([X_batch, X_pad], dim=0)
-
-                # Pad perturbation tensor with control perturbation
-                pert_pad = torch.zeros((padding_size, pert_batch.shape[1]), device=device)
-                if control_pert in pert_onehot_map:
-                    pert_pad[:] = pert_onehot_map[control_pert].to(device)
+        # Process each perturbation group
+        for pert_name, group in grouped:
+            indices = group["index"].values
+
+            # Randomize the order of cells within this perturbation group
+            if args.seed is not None:
+                np.random.shuffle(indices)
+
+            group_size = len(indices)
+
+            # Process this perturbation group in batches
+            for i in range(0, group_size, batch_size):
+                batch_indices = indices[i:i + batch_size]
+                current_batch_size = len(batch_indices)
+
+                # Check if this is an incomplete batch that needs sampling with replacement
+                if current_batch_size < batch_size:
+                    # Sample with replacement to fill out the batch
+                    additional_samples_needed = batch_size - current_batch_size
+                    replacement_indices = np.random.choice(batch_indices, size=additional_samples_needed, replace=True)
+
+                    # Combine original indices with replacement indices
+                    extended_batch_indices = np.concatenate([batch_indices, replacement_indices])
+
+                    # Get batch data for the extended batch
+                    X_batch = torch.tensor(X[extended_batch_indices], dtype=torch.float32).to(device)
+                    pert_batch = pert_tensor[extended_batch_indices].to(device)
+                    batch_idx_batch = batch_indices_tensor[extended_batch_indices].to(device)
+                    pert_names_batch = [pert_names[idx] for idx in extended_batch_indices]
+
+                    # Prepare batch
+                    batch = {
+                        "ctrl_cell_emb": X_batch,
+                        "pert_emb": pert_batch,
+                        "pert_name": pert_names_batch,
+                        "batch": batch_idx_batch.unsqueeze(0),  # Shape: (1, batch_size)
+                    }
+
+                    # Run inference on batch
+                    batch_preds = model.predict_step(batch, batch_idx=batch_idx, padded=False)
+
+                    # Extract predictions from the dictionary returned by predict_step
+                    if "pert_cell_counts_preds" in batch_preds and batch_preds["pert_cell_counts_preds"] is not None:
+                        # Use gene space predictions (from decoder)
+                        pred_tensor = batch_preds["pert_cell_counts_preds"]
+                    else:
+                        # Use latent space predictions
+                        pred_tensor = batch_preds["preds"]
+
+                    # Only keep predictions for the original samples (not the replacement samples)
+                    original_preds = pred_tensor[:current_batch_size]
+
+                    # Store predictions with their original indices to maintain order
+                    batch_preds_with_indices = [(batch_indices[j], original_preds[j].cpu().numpy()) for j in range(current_batch_size)]
+                    all_preds.extend(batch_preds_with_indices)
+
                 else:
-                    pert_pad[:, 0] = 1  # Default to first perturbation
-                pert_batch = torch.cat([pert_batch, pert_pad], dim=0)
-
-                # Extend perturbation names
-                pert_names_batch.extend([control_pert] * padding_size)
-
-            # Prepare batch - use same format as working code
-            batch = {
-                "ctrl_cell_emb": X_batch,
-                "pert_emb": pert_batch,  # Keep as 2D tensor
-                "pert_name": pert_names_batch,
-                "batch": torch.zeros((1, cell_sentence_len), device=device),  # Use (1, cell_sentence_len)
-            }
-
-            # Run inference on batch using padded=False like in working code
-            batch_preds = model.predict_step(batch, batch_idx=batch_idx, padded=False)
-
-            # Extract predictions from the dictionary returned by predict_step
-            # Use gene decoder output if available, otherwise use latent predictions
-            if "pert_cell_counts_preds" in batch_preds and batch_preds["pert_cell_counts_preds"] is not None:
-                # Use gene space predictions (from decoder)
-                pred_tensor = batch_preds["pert_cell_counts_preds"]
-            else:
-                # Use latent space predictions
-                pred_tensor = batch_preds["preds"]
-
-            # Only keep predictions for the actual samples (not padding)
-            actual_preds = pred_tensor[:current_batch_size]
-            all_preds.append(actual_preds.cpu().numpy())
-
-            # Update progress bar
-            progress_bar.update(current_batch_size)
+                    # Full batch - process normally
+                    # Get batch data
+                    X_batch = torch.tensor(X[batch_indices], dtype=torch.float32).to(device)
+                    pert_batch = pert_tensor[batch_indices].to(device)
+                    batch_idx_batch = batch_indices_tensor[batch_indices].to(device)
+                    pert_names_batch = [pert_names[idx] for idx in batch_indices]
+
+                    # Prepare batch
+                    batch = {
+                        "ctrl_cell_emb": X_batch,
+                        "pert_emb": pert_batch,
+                        "pert_name": pert_names_batch,
+                        "batch": batch_idx_batch.unsqueeze(0),  # Shape: (1, current_batch_size)
+                    }
+
+                    # Run inference on batch
+                    batch_preds = model.predict_step(batch, batch_idx=batch_idx, padded=False)
+
+                    # Extract predictions from the dictionary returned by predict_step
+                    if "pert_cell_counts_preds" in batch_preds and batch_preds["pert_cell_counts_preds"] is not None:
+                        # Use gene space predictions (from decoder)
+                        pred_tensor = batch_preds["pert_cell_counts_preds"]
+                    else:
+                        # Use latent space predictions
+                        pred_tensor = batch_preds["preds"]
+
+                    # Store predictions with their original indices to maintain order
+                    batch_preds_with_indices = [(batch_indices[j], pred_tensor[j].cpu().numpy()) for j in range(current_batch_size)]
+                    all_preds.extend(batch_preds_with_indices)
+
+                # Update progress bar
+                progress_bar.update(current_batch_size)
+                processed_samples += current_batch_size
+                batch_idx += 1
 
     progress_bar.close()
 
-    # Concatenate all predictions
-    preds_np = np.concatenate(all_preds, axis=0)
+    # Sort predictions by original index to maintain input order
+    all_preds.sort(key=lambda x: x[0])
+    preds_np = np.array([pred for _, pred in all_preds])
 
     # Save predictions to AnnData
    adata.X = preds_np
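The heart of this change is easier to see outside the diff: cells are grouped by perturbation, optionally shuffled within each group, an incomplete trailing batch of a group is filled by resampling that group's own cells with replacement (instead of the old zero-padding to cell_sentence_len), and predictions are re-sorted into the original row order at the end. A minimal, self-contained sketch of that scheme follows; the data is toy data and the prediction line is a placeholder standing in for model.predict_step:

import numpy as np
import pandas as pd

# Toy stand-ins for the real inputs: 10 cells, 2 perturbations, batch size 4.
pert_names = np.array(["pertA"] * 6 + ["pertB"] * 4)
batch_size = 4
rng_seeded = True  # analogous to passing --seed

df = pd.DataFrame({"pert_name": pert_names, "index": range(len(pert_names))})
all_preds = []

for pert_name, group in df.groupby("pert_name"):
    indices = group["index"].values
    if rng_seeded:
        np.random.shuffle(indices)  # reproducible once the global seed is set

    for i in range(0, len(indices), batch_size):
        batch_indices = indices[i : i + batch_size]
        n_real = len(batch_indices)

        # Incomplete batch: top up by resampling the same cells with replacement,
        # then discard the duplicated predictions afterwards.
        if n_real < batch_size:
            extra = np.random.choice(batch_indices, size=batch_size - n_real, replace=True)
            batch_indices = np.concatenate([batch_indices, extra])

        preds = batch_indices * 10.0  # placeholder for model.predict_step(...)

        # Keep only predictions for the real cells, tagged with original indices.
        all_preds.extend(zip(batch_indices[:n_real], preds[:n_real]))

# Restore the input row order before writing back to adata.X.
all_preds.sort(key=lambda x: x[0])
preds_np = np.array([p for _, p in all_preds])
print(preds_np)

Because every batch within a group shares one perturbation label, the model never sees mixed-perturbation sentences, and the index-tagged sort guarantees adata.X keeps its original cell order.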
1 change: 1 addition & 0 deletions src/state/configs/data/perturbation.yaml
@@ -19,5 +19,6 @@ kwargs:
   store_raw_basal: false
   int_counts: false
   barcode: true
+  zeroshot_controls: false
   output_dir: null
   debug: true
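One note on the new --seed flag: seed_everything from lightning seeds Python's random module, NumPy, and torch in a single call, so the np.random.shuffle and np.random.choice calls in the inference loop above become deterministic across runs. A small sketch of the reproducibility this buys, assuming lightning is installed; the draw helper is illustrative only:

import numpy as np
from lightning import seed_everything

def draw():
    # Mimics the per-group shuffle in the inference loop.
    idx = np.arange(10)
    np.random.shuffle(idx)
    return idx.tolist()

seed_everything(42)
first = draw()
seed_everything(42)
second = draw()
assert first == second  # identical shuffles, hence identical batch composition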