From 59b49a2cd05c85c6606a580f45a3d5968bd96626 Mon Sep 17 00:00:00 2001 From: nick-youngblut Date: Tue, 22 Jul 2025 20:36:21 -0700 Subject: [PATCH 1/8] Enhance README and CLI for embedding commands - Updated README to clarify that existing cell records will be updated with new embeddings and provided example dataset details. - Improved CLI argument parsing by adding custom descriptions for embedding and transcriptomic commands. - Introduced `CustomFormatter` for better help message formatting. - Added `max-workers` argument for parallel processing in embedding queries. - Refined result formatting in `run_emb_query` to include additional metadata and renamed columns for clarity. - Updated `StateVectorDB` to support merging and updating entries in LanceDB. --- README.md | 19 +++++- src/state/__main__.py | 20 ++++++- src/state/_cli/_emb/__init__.py | 2 +- src/state/_cli/_emb/_query.py | 22 +++++-- src/state/_cli/_emb/_transform.py | 15 +++-- src/state/_cli/_utils.py | 5 ++ src/state/emb/inference.py | 14 ++--- src/state/emb/vectordb.py | 97 ++++++++++++++++++++++--------- 8 files changed, 139 insertions(+), 55 deletions(-) create mode 100644 src/state/_cli/_utils.py diff --git a/README.md b/README.md index a78458f8..c72002fa 100644 --- a/README.md +++ b/README.md @@ -265,16 +265,19 @@ state emb transform \ ``` Running this command multiple times with the same lancedb appends the new data to the provided database. +Existing cell records will be updated with the new embeddings. #### Query the database +> For this example, we will use the same dataset (SRX27532045), so the top hit should be the same cell. + Obtain the embeddings: ```bash state emb transform \ --model-folder /large_storage/ctc/userspace/aadduri/SE-600M \ - --input /large_storage/ctc/public/scBasecamp/GeneFull_Ex50pAS/GeneFull_Ex50pAS/Homo_sapiens/SRX27532046.h5ad \ - --output tmp/SRX27532046.h5ad \ + --input /large_storage/ctc/public/scBasecamp/GeneFull_Ex50pAS/GeneFull_Ex50pAS/Homo_sapiens/SRX27532045.h5ad \ + --output tmp/SRX27532045.h5ad \ --gene-column gene_symbols ``` @@ -283,9 +286,19 @@ Query the database with the embeddings: ```bash state emb query \ --lancedb tmp/state_embeddings.lancedb \ - --input tmp/SRX27532046.h5ad \ + --input tmp/SRX27532045.h5ad \ --output tmp/similar_cells.csv \ --k 3 +``` + +Output: + - `query_cell_id` : The cell id of the query cell + - `subject_rank` : The rank of the h (smallest distance to) + - `query_subject_distance` : The distance between the query and subject cell vectors + - `subject_cell_id` : The cell id of the hit cell + - `subject_dataset` : The dataset of the hit cell + - `embedding_key` : The embedding key of the hit cell + - `...` : Other `obs` metadata columns from the query cell # Singularity diff --git a/src/state/__main__.py b/src/state/__main__.py index e8094b32..1286ffa4 100644 --- a/src/state/__main__.py +++ b/src/state/__main__.py @@ -3,6 +3,7 @@ from hydra import compose, initialize from omegaconf import DictConfig +from ._cli._utils import CustomFormatter from ._cli import ( add_arguments_emb, add_arguments_tx, @@ -17,10 +18,23 @@ def get_args() -> tuple[ap.Namespace, list[str]]: """Parse known args and return remaining args for Hydra overrides""" - parser = ap.ArgumentParser() + desc = """description: + STATE command line interface. + For more information on how to use the CLI, use the `state --help` command.""" + parser = ap.ArgumentParser(description=desc, formatter_class=CustomFormatter) subparsers = parser.add_subparsers(required=True, dest="command") - add_arguments_emb(subparsers.add_parser("emb")) - add_arguments_tx(subparsers.add_parser("tx")) + + # emb + desc = """description: + Embedding commands. + For more information on how to use the CLI, use the `state emb --help` command.""" + add_arguments_emb(subparsers.add_parser("emb", description=desc, formatter_class=CustomFormatter)) + + # tx + desc = """description: + Transcriptomic commands. + For more information on how to use the CLI, use the `state tx --help` command.""" + add_arguments_tx(subparsers.add_parser("tx", description=desc, formatter_class=CustomFormatter)) # Use parse_known_args to get both known args and remaining args return parser.parse_args() diff --git a/src/state/_cli/_emb/__init__.py b/src/state/_cli/_emb/__init__.py index 1cda4d4d..cc4b94fd 100644 --- a/src/state/_cli/_emb/__init__.py +++ b/src/state/_cli/_emb/__init__.py @@ -8,7 +8,7 @@ def add_arguments_emb(parser: ap.ArgumentParser): - """""" + """Add embedding commands to the parser""" subparsers = parser.add_subparsers(required=True, dest="subcommand") add_arguments_fit(subparsers.add_parser("fit")) add_arguments_transform(subparsers.add_parser("transform")) diff --git a/src/state/_cli/_emb/_query.py b/src/state/_cli/_emb/_query.py index a10fe1e1..29d1cd39 100644 --- a/src/state/_cli/_emb/_query.py +++ b/src/state/_cli/_emb/_query.py @@ -1,3 +1,4 @@ +import os import argparse as ap import logging import pandas as pd @@ -15,8 +16,8 @@ def add_arguments_query(parser: ap.ArgumentParser): parser.add_argument("--exclude-distances", action="store_true", help="Exclude vector distances in results") parser.add_argument("--filter", type=str, help="Filter expression (e.g., 'cell_type==\"B cell\"')") - parser.add_argument("--batch-size", type=int, default=100, - help="Batch size for query operations") + parser.add_argument("--batch-size", type=int, default=100, help="Batch size for query operations") + parser.add_argument("--max-workers", type=int, default=os.cpu_count(), help="Maximum number of workers for parallel processing") def run_emb_query(args: ap.ArgumentParser): """ @@ -59,6 +60,7 @@ def run_emb_query(args: ap.ArgumentParser): filter=args.filter, include_distance=not args.exclude_distances, batch_size=args.batch_size, + max_workers=args.max_workers, show_progress=True ) @@ -66,11 +68,21 @@ def run_emb_query(args: ap.ArgumentParser): all_results = [] for query_idx, result_df in enumerate(results_list): result_df['query_cell_id'] = query_adata.obs.index[query_idx] - result_df['query_rank'] = range(1, len(result_df) + 1) + result_df['subject_rank'] = range(1, len(result_df) + 1) all_results.append(result_df) # Combine results final_results = pd.concat(all_results, ignore_index=True) + + # Format the results table + ## Move certain columns to the start, if they exist + to_move = ['query_cell_id', 'subject_rank', 'query_subject_distance', 'cell_id', 'dataset', 'embedding_key'] + to_move = [col for col in to_move if col in final_results.columns] + final_results = final_results[to_move + [col for col in final_results.columns if col not in to_move]] + ## Rename `cell_id` to 'subject_cell_id' + rn_dict = {'cell_id': 'subject_cell_id', 'dataset': 'subject_dataset'} + rn_dict = {k:v for k,v in rn_dict.items() if k in final_results.columns} + final_results = final_results.rename(columns=rn_dict) # Save results output_path = Path(args.output) @@ -96,11 +108,11 @@ def create_result_anndata(query_adata, results_df, k): cell_ids_array = np.array(cell_ids_pivot.values, dtype=str) # Handle distances - convert to float64 and handle missing values - if 'vector_distance' in results_df: + if 'query_subject_distance' in results_df: distances_pivot = results_df.pivot( index='query_cell_id', columns='query_rank', - values='vector_distance' + values='query_subject_distance' ) distances_array = np.array(distances_pivot.values, dtype=np.float64) else: diff --git a/src/state/_cli/_emb/_transform.py b/src/state/_cli/_emb/_transform.py index 5d4616a3..bf6aa25f 100644 --- a/src/state/_cli/_emb/_transform.py +++ b/src/state/_cli/_emb/_transform.py @@ -9,13 +9,12 @@ def add_arguments_transform(parser: ap.ArgumentParser): parser.add_argument("--output", required=False, help="Path to output embedded anndata file (h5ad)") parser.add_argument("--embed-key", default="X_state", help="Name of key to store embeddings") parser.add_argument("--gene-column", default="gene_name", help="Name of column in var dataframe to use for gene names") - parser.add_argument("--lancedb", type=str, help="Path to LanceDB database for vector storage") - parser.add_argument("--lancedb-update", action="store_true", - help="Update existing entries in LanceDB (default: append)") - parser.add_argument("--lancedb-batch-size", type=int, default=1000, - help="Batch size for LanceDB operations") - - + parser.add_argument("--dataset-name", type=str, default=None, help="Name of the dataset. If None, the input file name will be used.") + lancedb_group = parser.add_argument_group("Vector database options") + lancedb_group.add_argument("--lancedb", type=str, help="Path to LanceDB database for vector storage") + lancedb_group.add_argument("--lancedb-batch-size", type=int, default=1000, + help="Batch size for LanceDB operations") + def run_emb_transform(args: ap.ArgumentParser): """ Compute embeddings for an input anndata file using a pre-trained VCI model checkpoint. @@ -79,8 +78,8 @@ def run_emb_transform(args: ap.ArgumentParser): output_adata_path=args.output, emb_key=args.embed_key, gene_column=args.gene_column, + dataset_name=args.dataset_name, lancedb_path=args.lancedb, - update_lancedb=args.lancedb_update, lancedb_batch_size=args.lancedb_batch_size, ) diff --git a/src/state/_cli/_utils.py b/src/state/_cli/_utils.py new file mode 100644 index 00000000..6d2a7496 --- /dev/null +++ b/src/state/_cli/_utils.py @@ -0,0 +1,5 @@ +import argparse + +class CustomFormatter(argparse.ArgumentDefaultsHelpFormatter, + argparse.RawDescriptionHelpFormatter): + pass \ No newline at end of file diff --git a/src/state/emb/inference.py b/src/state/emb/inference.py index d6fc8571..20a45d22 100644 --- a/src/state/emb/inference.py +++ b/src/state/emb/inference.py @@ -150,17 +150,15 @@ def encode_adata( output_adata_path: str | None = None, emb_key: str = "X_emb", dataset_name: str | None = None, - batch_size: int = 32, lancedb_path: str | None = None, - update_lancedb: bool = False, lancedb_batch_size: int = 1000, gene_column: str = "gene_name", ): - shape_dict = self.__load_dataset_meta(input_adata_path) - adata = anndata.read_h5ad(input_adata_path) if dataset_name is None: dataset_name = Path(input_adata_path).stem - + shape_dict = self.__load_dataset_meta(input_adata_path) + adata = anndata.read_h5ad(input_adata_path) + # Convert to CSR format if needed adata = self._convert_to_csr(adata) @@ -202,7 +200,7 @@ def encode_adata( if lancedb_path is not None: from .vectordb import StateVectorDB - log.info(f"Saving embeddings to LanceDB at {lancedb_path}") + log.info(f"Saving embeddings to LanceDB at {lancedb_path} using dataset name: {dataset_name}") vector_db = StateVectorDB(lancedb_path) # Extract relevant metadata @@ -213,8 +211,8 @@ def encode_adata( embeddings=all_embeddings, metadata=metadata, embedding_key=emb_key, - dataset_name=dataset_name or Path(input_adata_path).stem, - batch_size=lancedb_batch_size + dataset_name=dataset_name, + batch_size=lancedb_batch_size, ) log.info(f"Successfully saved {len(all_embeddings)} embeddings to LanceDB") diff --git a/src/state/emb/vectordb.py b/src/state/emb/vectordb.py index 073866df..cf223e03 100644 --- a/src/state/emb/vectordb.py +++ b/src/state/emb/vectordb.py @@ -3,6 +3,8 @@ import pandas as pd from typing import Optional, List from pathlib import Path +from concurrent.futures import ThreadPoolExecutor, as_completed +from functools import partial class StateVectorDB: """Manages LanceDB operations for State embeddings.""" @@ -22,7 +24,7 @@ def create_or_update_table( metadata: pd.DataFrame, embedding_key: str = "X_state", dataset_name: Optional[str] = None, - batch_size: int = 1000 + batch_size: int = 1000, ): """Create or update the embeddings table. @@ -40,9 +42,10 @@ def create_or_update_table( batch_data = [] for j in range(i, batch_end): + cell_id = metadata.index[j] record = { "vector": embeddings[j].tolist(), - "cell_id": metadata.index[j], + "cell_id": cell_id, "embedding_key": embedding_key, "dataset": dataset_name or "unknown", **{col: metadata.iloc[j][col] for col in metadata.columns} @@ -54,7 +57,12 @@ def create_or_update_table( # Create or append to table if self.table_name in self.db.table_names(): table = self.db.open_table(self.table_name) - table.add(data) + ( + table.merge_insert(["cell_id", "dataset"]) + .when_matched_update_all() + .when_not_matched_insert_all() + .execute(data) + ) else: self.db.create_table(self.table_name, data=data) @@ -90,16 +98,17 @@ def search( if columns: query = query.select(columns + ['_distance'] if include_distance else columns) + # convert to pandas results = query.to_pandas() # deal with _distance column if '_distance' in results.columns: if include_distance: - results = results.rename(columns={'_distance': 'query_distance'}) + results = results.rename(columns={'_distance': 'query_subject_distance'}) else: results = results.drop('_distance', axis=1) elif include_distance: - results['query_distance'] = 0.0 + results['query_subject_distance'] = 0.0 # drop vector column if include_vector is False if not include_vector and 'vector' in results.columns: @@ -107,17 +116,29 @@ def search( return results + def _search_single(self, query_vector: np.ndarray, k: int, filter: str | None, + include_distance: bool, include_vector: bool): + """Helper method for parallel search.""" + return self.search( + query_vector=query_vector, + k=k, + filter=filter, + include_distance=include_distance, + include_vector=include_vector, + ) + def batch_search( self, query_vectors: np.ndarray, k: int = 10, filter: str | None = None, include_distance: bool = True, - batch_size: int = 100, + include_vector: bool = False, + max_workers: int = 4, + batch_size: int = 1000, show_progress: bool = True, - include_vector: bool = False ): - """Batch search for multiple query vectors. + """Parallel batch search for multiple query vectors using ThreadPoolExecutor. Args: query_vectors: Array of query embedding vectors @@ -125,35 +146,57 @@ def batch_search( filter: Optional filter expression include_distance: Whether to include distances include_vector: Whether to include the query vector in the results - batch_size: Number of queries to process at once + max_workers: Maximum number of worker threads + batch_size: Number of queries to submit to executor at once show_progress: Show progress bar Returns: List of DataFrames with search results """ from tqdm import tqdm - results = [] - iterator = range(0, len(query_vectors), batch_size) + # Create a partial function with fixed parameters + search_func = partial( + self._search_single, + k=k, + filter=filter, + include_distance=include_distance, + include_vector=include_vector, + ) - if show_progress: - iterator = tqdm(iterator, desc="Searching") + results = [None] * len(query_vectors) - for i in iterator: - batch_end = min(i + batch_size, len(query_vectors)) - batch_queries = query_vectors[i:batch_end] + # Process in batches to manage memory and avoid overwhelming the database + with ThreadPoolExecutor(max_workers=max_workers) as executor: + total_processed = 0 - batch_results = [] - for query_vec in batch_queries: - result = self.search( - query_vector=query_vec, - k=k, - filter=filter, - include_distance=include_distance, - include_vector=include_vector, - ) - batch_results.append(result) + if show_progress: + pbar = tqdm(total=len(query_vectors), desc="Searching") + + for batch_start in range(0, len(query_vectors), batch_size): + batch_end = min(batch_start + batch_size, len(query_vectors)) + batch_vectors = query_vectors[batch_start:batch_end] + + # Submit batch to executor + future_to_index = { + executor.submit(search_func, batch_vectors[i]): batch_start + i + for i in range(len(batch_vectors)) + } + + # Collect results for this batch + for future in as_completed(future_to_index): + index = future_to_index[future] + try: + results[index] = future.result() + except Exception as e: + print(f"Query {index} failed: {e}") + results[index] = pd.DataFrame() # Empty result on error + + total_processed += 1 + if show_progress: + pbar.update(1) - results.extend(batch_results) + if show_progress: + pbar.close() return results From 302cc6ac279f8a7abf5ebf40960f58fa318ac854 Mon Sep 17 00:00:00 2001 From: nick-youngblut Date: Tue, 22 Jul 2025 20:44:33 -0700 Subject: [PATCH 2/8] Enhance CLI argument parsing for embedding and transcriptomic commands - Added custom descriptions for embedding commands in `add_arguments_emb`. - Improved help messages for query filtering in `add_arguments_query`. - Enhanced transcriptomic command descriptions in `add_arguments_tx` for clarity. - Introduced `CustomFormatter` for better formatting of help messages across CLI commands. --- src/state/_cli/_emb/__init__.py | 20 +++++++++++++++++--- src/state/_cli/_emb/_query.py | 3 ++- src/state/_cli/_tx/__init__.py | 27 +++++++++++++++++++++++---- src/state/_cli/_tx/_train.py | 2 +- 4 files changed, 43 insertions(+), 9 deletions(-) diff --git a/src/state/_cli/_emb/__init__.py b/src/state/_cli/_emb/__init__.py index cc4b94fd..49936d27 100644 --- a/src/state/_cli/_emb/__init__.py +++ b/src/state/_cli/_emb/__init__.py @@ -3,6 +3,7 @@ from ._fit import add_arguments_fit, run_emb_fit from ._transform import add_arguments_transform, run_emb_transform from ._query import add_arguments_query, run_emb_query +from .._utils import CustomFormatter __all__ = ["run_emb_fit", "run_emb_transform", "run_emb_query", "add_arguments_emb"] @@ -10,6 +11,19 @@ def add_arguments_emb(parser: ap.ArgumentParser): """Add embedding commands to the parser""" subparsers = parser.add_subparsers(required=True, dest="subcommand") - add_arguments_fit(subparsers.add_parser("fit")) - add_arguments_transform(subparsers.add_parser("transform")) - add_arguments_query(subparsers.add_parser("query")) + + # fit + desc = """description: + Fit an embedding model to a dataset.""" + add_arguments_fit(subparsers.add_parser("fit", description=desc, formatter_class=CustomFormatter)) + + # transform + desc = """description: + Transform a dataset into an embedding. + You can also create or add embeddings to a LanceDB database.""" + add_arguments_transform(subparsers.add_parser("transform", description=desc, formatter_class=CustomFormatter)) + + # query + desc = """description: + Query a LanceDB database for similar cells.""" + add_arguments_query(subparsers.add_parser("query", description=desc, formatter_class=CustomFormatter)) diff --git a/src/state/_cli/_emb/_query.py b/src/state/_cli/_emb/_query.py index 29d1cd39..16a3f378 100644 --- a/src/state/_cli/_emb/_query.py +++ b/src/state/_cli/_emb/_query.py @@ -15,7 +15,8 @@ def add_arguments_query(parser: ap.ArgumentParser): parser.add_argument("--embed-key", default="X_state", help="Key containing embeddings in input file") parser.add_argument("--exclude-distances", action="store_true", help="Exclude vector distances in results") - parser.add_argument("--filter", type=str, help="Filter expression (e.g., 'cell_type==\"B cell\"')") + parser.add_argument("--filter", type=str, + help="Filter expression (e.g., 'cell_type==\"B cell\"', assuming a 'cell_type' column exists in the database)") parser.add_argument("--batch-size", type=int, default=100, help="Batch size for query operations") parser.add_argument("--max-workers", type=int, default=os.cpu_count(), help="Maximum number of workers for parallel processing") diff --git a/src/state/_cli/_tx/__init__.py b/src/state/_cli/_tx/__init__.py index c03ee60b..0f4447f2 100644 --- a/src/state/_cli/_tx/__init__.py +++ b/src/state/_cli/_tx/__init__.py @@ -3,13 +3,32 @@ from ._infer import add_arguments_infer, run_tx_infer from ._predict import add_arguments_predict, run_tx_predict from ._train import add_arguments_train, run_tx_train +from .._utils import CustomFormatter __all__ = ["run_tx_train", "run_tx_predict", "run_tx_infer", "add_arguments_tx"] def add_arguments_tx(parser: ap.ArgumentParser): - """""" + """Add transcriptomic commands to the parser""" subparsers = parser.add_subparsers(required=True, dest="subcommand") - add_arguments_train(subparsers.add_parser("train", add_help=False)) - add_arguments_predict(subparsers.add_parser("predict")) - add_arguments_infer(subparsers.add_parser("infer")) + + # Train + desc = """description: + Train a model using the specified configuration. + Use `hydra_overrides` to pass additional configuration overrides. + + For example, to use a different model architecture: + ``` + state tx train data.batch_size=32 model.kwargs.cell_set_len=64 + ```""" + add_arguments_train(subparsers.add_parser("train", description=desc, formatter_class=CustomFormatter)) + + # Predict + desc = """description: + Predict transcriptomic data using a trained model.""" + add_arguments_predict(subparsers.add_parser("predict", description=desc, formatter_class=CustomFormatter)) + + # Infer + desc = """description: + Infer transcriptomic data from input samples.""" + add_arguments_infer(subparsers.add_parser("infer", description=desc, formatter_class=CustomFormatter)) diff --git a/src/state/_cli/_tx/_train.py b/src/state/_cli/_tx/_train.py index 7e1b3531..71f7b4ea 100644 --- a/src/state/_cli/_tx/_train.py +++ b/src/state/_cli/_tx/_train.py @@ -7,7 +7,7 @@ def add_arguments_train(parser: ap.ArgumentParser): # Allow remaining args to be passed through to Hydra parser.add_argument("hydra_overrides", nargs="*", help="Hydra configuration overrides (e.g., data.batch_size=32)") # Add custom help handler - parser.add_argument("--help", "-h", action="store_true", help="Show configuration help with all parameters") + #parser.add_argument("--help", "-h", action="store_true", help="Show configuration help with all parameters") def run_tx_train(cfg: DictConfig): From 78cc2abd4bef04b528cad7a9ccd9dd1d5bfa2d15 Mon Sep 17 00:00:00 2001 From: Nick Youngblut Date: Wed, 23 Jul 2025 08:46:25 -0700 Subject: [PATCH 3/8] Improve CLI descriptions and fix minor issues --- src/state/__main__.py | 22 +++++++++++----------- src/state/_cli/_emb/__init__.py | 21 ++++++++++++++------- src/state/_cli/_emb/_query.py | 4 ++-- src/state/_cli/_tx/__init__.py | 26 ++++++++++++++------------ src/state/_cli/_utils.py | 9 ++++++--- 5 files changed, 47 insertions(+), 35 deletions(-) diff --git a/src/state/__main__.py b/src/state/__main__.py index 1286ffa4..62d36716 100644 --- a/src/state/__main__.py +++ b/src/state/__main__.py @@ -19,21 +19,22 @@ def get_args() -> tuple[ap.Namespace, list[str]]: """Parse known args and return remaining args for Hydra overrides""" desc = """description: - STATE command line interface. - For more information on how to use the CLI, use the `state --help` command.""" + Entry point for the STATE command line interface. + Use these commands to train models, compute embeddings, and run inference. + Run `state --help` for details on each command.""" parser = ap.ArgumentParser(description=desc, formatter_class=CustomFormatter) subparsers = parser.add_subparsers(required=True, dest="command") # emb desc = """description: - Embedding commands. - For more information on how to use the CLI, use the `state emb --help` command.""" + Commands for generating and querying STATE embeddings. + See `state emb --help` for subcommand options.""" add_arguments_emb(subparsers.add_parser("emb", description=desc, formatter_class=CustomFormatter)) # tx desc = """description: - Transcriptomic commands. - For more information on how to use the CLI, use the `state tx --help` command.""" + Train and evaluate perturbation models with Hydra configuration. + Overrides can be passed via `state tx param=value`.""" add_arguments_tx(subparsers.add_parser("tx", description=desc, formatter_class=CustomFormatter)) # Use parse_known_args to get both known args and remaining args @@ -74,21 +75,20 @@ def show_hydra_help(method: str): print() print("Usage examples:") print(" Override single parameter:") - print(f" uv run state tx train data.batch_size=64") + print(" uv run state tx train data.batch_size=64") print() print(" Override nested parameter:") - print(f" uv run state tx train model.kwargs.hidden_dim=512") + print(" uv run state tx train model.kwargs.hidden_dim=512") print() print(" Override multiple parameters:") - print(f" uv run state tx train data.batch_size=64 training.lr=0.001") + print(" uv run state tx train data.batch_size=64 training.lr=0.001") print() print(" Change config group:") - print(f" uv run state tx train data=custom_data model=custom_model") + print(" uv run state tx train data=custom_data model=custom_model") print() print("Available config groups:") # Show available config groups - import os from pathlib import Path config_dir = Path(__file__).parent / "configs" diff --git a/src/state/_cli/_emb/__init__.py b/src/state/_cli/_emb/__init__.py index 49936d27..f294321c 100644 --- a/src/state/_cli/_emb/__init__.py +++ b/src/state/_cli/_emb/__init__.py @@ -14,16 +14,23 @@ def add_arguments_emb(parser: ap.ArgumentParser): # fit desc = """description: - Fit an embedding model to a dataset.""" - add_arguments_fit(subparsers.add_parser("fit", description=desc, formatter_class=CustomFormatter)) + Train an embedding model on a reference dataset. + Provide Hydra overrides to adjust training parameters.""" + add_arguments_fit( + subparsers.add_parser("fit", description=desc, formatter_class=CustomFormatter) + ) # transform desc = """description: - Transform a dataset into an embedding. - You can also create or add embeddings to a LanceDB database.""" - add_arguments_transform(subparsers.add_parser("transform", description=desc, formatter_class=CustomFormatter)) + Encode an input dataset with a trained embedding model. + Results can be saved locally or inserted into a LanceDB database.""" + add_arguments_transform( + subparsers.add_parser("transform", description=desc, formatter_class=CustomFormatter) + ) # query desc = """description: - Query a LanceDB database for similar cells.""" - add_arguments_query(subparsers.add_parser("query", description=desc, formatter_class=CustomFormatter)) + Search a LanceDB vector store for cells with matching embeddings.""" + add_arguments_query( + subparsers.add_parser("query", description=desc, formatter_class=CustomFormatter) + ) diff --git a/src/state/_cli/_emb/_query.py b/src/state/_cli/_emb/_query.py index 16a3f378..ad66a3d0 100644 --- a/src/state/_cli/_emb/_query.py +++ b/src/state/_cli/_emb/_query.py @@ -131,5 +131,5 @@ def create_result_anndata(query_adata, results_df, k): # Create result anndata result_adata = query_adata.copy() result_adata.uns['lancedb_query_results'] = uns_data - - return result_adata \ No newline at end of file + + return result_adata diff --git a/src/state/_cli/_tx/__init__.py b/src/state/_cli/_tx/__init__.py index 0f4447f2..9b2776a7 100644 --- a/src/state/_cli/_tx/__init__.py +++ b/src/state/_cli/_tx/__init__.py @@ -14,21 +14,23 @@ def add_arguments_tx(parser: ap.ArgumentParser): # Train desc = """description: - Train a model using the specified configuration. - Use `hydra_overrides` to pass additional configuration overrides. - - For example, to use a different model architecture: - ``` - state tx train data.batch_size=32 model.kwargs.cell_set_len=64 - ```""" - add_arguments_train(subparsers.add_parser("train", description=desc, formatter_class=CustomFormatter)) + Train a perturbation model using a Hydra configuration. + Provide overrides to customize training, e.g.: + `state tx train data.batch_size=32`""" + add_arguments_train( + subparsers.add_parser("train", description=desc, formatter_class=CustomFormatter) + ) # Predict desc = """description: - Predict transcriptomic data using a trained model.""" - add_arguments_predict(subparsers.add_parser("predict", description=desc, formatter_class=CustomFormatter)) + Generate predictions from a trained model and optionally compute evaluation metrics.""" + add_arguments_predict( + subparsers.add_parser("predict", description=desc, formatter_class=CustomFormatter) + ) # Infer desc = """description: - Infer transcriptomic data from input samples.""" - add_arguments_infer(subparsers.add_parser("infer", description=desc, formatter_class=CustomFormatter)) + Run inference on new samples using a trained model.""" + add_arguments_infer( + subparsers.add_parser("infer", description=desc, formatter_class=CustomFormatter) + ) diff --git a/src/state/_cli/_utils.py b/src/state/_cli/_utils.py index 6d2a7496..0ef6f032 100644 --- a/src/state/_cli/_utils.py +++ b/src/state/_cli/_utils.py @@ -1,5 +1,8 @@ import argparse -class CustomFormatter(argparse.ArgumentDefaultsHelpFormatter, - argparse.RawDescriptionHelpFormatter): - pass \ No newline at end of file +class CustomFormatter( + argparse.ArgumentDefaultsHelpFormatter, argparse.RawDescriptionHelpFormatter +): + """Combine default and raw formatting styles.""" + + pass From 8eff9e22bf2dae645912bf00dae77149f7f4989f Mon Sep 17 00:00:00 2001 From: nick-youngblut Date: Wed, 23 Jul 2025 09:08:51 -0700 Subject: [PATCH 4/8] Update CLI descriptions for embedding commands to clarify functionality - Revised the description for the `transform` command to specify that results can be inserted into a LanceDB vector store. - Enhanced the `query` command description to indicate it searches for cells with similar embeddings in a LanceDB vector store created with the `transform` command. --- src/state/_cli/_emb/__init__.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/state/_cli/_emb/__init__.py b/src/state/_cli/_emb/__init__.py index f294321c..d1a0a52a 100644 --- a/src/state/_cli/_emb/__init__.py +++ b/src/state/_cli/_emb/__init__.py @@ -23,14 +23,14 @@ def add_arguments_emb(parser: ap.ArgumentParser): # transform desc = """description: Encode an input dataset with a trained embedding model. - Results can be saved locally or inserted into a LanceDB database.""" + Results can be saved locally and inserted into a LanceDB vector store.""" add_arguments_transform( subparsers.add_parser("transform", description=desc, formatter_class=CustomFormatter) ) # query desc = """description: - Search a LanceDB vector store for cells with matching embeddings.""" + Search a LanceDB vector store (created with `transform`) for cells with similar embeddings.""" add_arguments_query( subparsers.add_parser("query", description=desc, formatter_class=CustomFormatter) ) From 5711017bde86f05f8d09dc439a7a7f8369a33218 Mon Sep 17 00:00:00 2001 From: nick-youngblut Date: Wed, 23 Jul 2025 10:01:19 -0700 Subject: [PATCH 5/8] Add vector database summary functionality to CLI - Introduced `run_emb_vectordb` command to retrieve and display summary statistics of the LanceDB vector database. - Updated `README.md` with usage instructions and output format options for the new command. - Enhanced `StateVectorDB` class with `get_database_summary` method to compute and return comprehensive database statistics. - Added argument parsing for the new command in the CLI module. --- README.md | 22 +++++++++ src/state/__main__.py | 3 ++ src/state/_cli/__init__.py | 3 +- src/state/_cli/_emb/__init__.py | 10 +++- src/state/_cli/_emb/_vectordb.py | 79 ++++++++++++++++++++++++++++++++ src/state/emb/vectordb.py | 62 ++++++++++++++++++++++++- 6 files changed, 175 insertions(+), 4 deletions(-) create mode 100644 src/state/_cli/_emb/_vectordb.py diff --git a/README.md b/README.md index c72002fa..6d1af66f 100644 --- a/README.md +++ b/README.md @@ -300,6 +300,28 @@ Output: - `embedding_key` : The embedding key of the hit cell - `...` : Other `obs` metadata columns from the query cell +#### Summarize the vector database + +Get comprehensive statistics about your vector database: + +```bash +state emb vectordb \ + --lancedb tmp/state_embeddings.lancedb \ + --format table +``` + +Output formats: + - `table` (default): Human-readable table format with emojis + - `json`: Machine-readable JSON format + - `yaml`: YAML format + +The summary includes: + - Total number of cells and datasets + - Number of unique embedding keys + - Embedding vector dimensions + - Cell count breakdown by dataset + - List of all embedding keys + # Singularity Containerization for STATE is available via the `singularity.def` file. diff --git a/src/state/__main__.py b/src/state/__main__.py index 62d36716..0eef3e24 100644 --- a/src/state/__main__.py +++ b/src/state/__main__.py @@ -10,6 +10,7 @@ run_emb_fit, run_emb_transform, run_emb_query, + run_emb_vectordb, run_tx_infer, run_tx_predict, run_tx_train, @@ -115,6 +116,8 @@ def main(): run_emb_transform(args) case "query": run_emb_query(args) + case "vectordb": + run_emb_vectordb(args) case "tx": match args.subcommand: case "train": diff --git a/src/state/_cli/__init__.py b/src/state/_cli/__init__.py index 2de4e1d7..e341ee59 100644 --- a/src/state/_cli/__init__.py +++ b/src/state/_cli/__init__.py @@ -1,4 +1,4 @@ -from ._emb import add_arguments_emb, run_emb_fit, run_emb_transform, run_emb_query +from ._emb import add_arguments_emb, run_emb_fit, run_emb_transform, run_emb_query, run_emb_vectordb from ._tx import add_arguments_tx, run_tx_infer, run_tx_predict, run_tx_train __all__ = [ @@ -10,4 +10,5 @@ "run_emb_fit", "run_emb_query", "run_emb_transform", + "run_emb_vectordb", ] diff --git a/src/state/_cli/_emb/__init__.py b/src/state/_cli/_emb/__init__.py index d1a0a52a..5cbfa50b 100644 --- a/src/state/_cli/_emb/__init__.py +++ b/src/state/_cli/_emb/__init__.py @@ -3,9 +3,10 @@ from ._fit import add_arguments_fit, run_emb_fit from ._transform import add_arguments_transform, run_emb_transform from ._query import add_arguments_query, run_emb_query +from ._vectordb import add_arguments_vectordb, run_emb_vectordb from .._utils import CustomFormatter -__all__ = ["run_emb_fit", "run_emb_transform", "run_emb_query", "add_arguments_emb"] +__all__ = ["run_emb_fit", "run_emb_transform", "run_emb_query", "run_emb_vectordb", "add_arguments_emb"] def add_arguments_emb(parser: ap.ArgumentParser): @@ -34,3 +35,10 @@ def add_arguments_emb(parser: ap.ArgumentParser): add_arguments_query( subparsers.add_parser("query", description=desc, formatter_class=CustomFormatter) ) + + # vectordb + desc = """description: + Get summary statistics about a LanceDB vector database including datasets, cell counts, and embeddings.""" + add_arguments_vectordb( + subparsers.add_parser("vectordb", description=desc, formatter_class=CustomFormatter) + ) diff --git a/src/state/_cli/_emb/_vectordb.py b/src/state/_cli/_emb/_vectordb.py new file mode 100644 index 00000000..813dd4bf --- /dev/null +++ b/src/state/_cli/_emb/_vectordb.py @@ -0,0 +1,79 @@ +import argparse as ap +import logging +import json +import yaml +from typing import Dict, Any + + +def add_arguments_vectordb(parser: ap.ArgumentParser): + """Add arguments for state embedding vectordb CLI.""" + parser.add_argument("--lancedb", required=True, help="Path to existing LanceDB database") + parser.add_argument("--format", choices=["json", "yaml", "table"], default="table", + help="Output format for database summary") + + +def run_emb_vectordb(args: ap.ArgumentParser): + """ + Get summary statistics about a LanceDB vector database. + """ + logging.basicConfig(level=logging.INFO) + logger = logging.getLogger(__name__) + + from ...emb.vectordb import StateVectorDB + + # Connect to database + logger.info(f"Connecting to database at {args.lancedb}") + vector_db = StateVectorDB(args.lancedb) + + # Get database summary + summary = vector_db.get_database_summary() + + # Output in requested format + if args.format == "json": + print(json.dumps(summary, indent=2)) + elif args.format == "yaml": + print(yaml.dump(summary, default_flow_style=False)) + elif args.format == "table": + _print_table_summary(summary) + + logger.info("Database summary completed successfully!") + + +def _print_table_summary(summary: Dict[str, Any]) -> None: + """Print database summary in a nice table format.""" + if not summary["table_exists"]: + print("❌ Database table does not exist") + return + + if summary["num_cells"] == 0: + print("⚠️ Database table exists but is empty") + return + + # Print header + print("=" * 60) + print("📊 STATE VECTOR DATABASE SUMMARY") + print("=" * 60) + + # Basic stats + print(f"🔢 Total cells: {summary['num_cells']:,}") + print(f"📦 Total datasets: {summary['num_datasets']}") + print(f"🔑 Embedding keys: {summary['num_embedding_keys']}") + print(f"📐 Embedding dimension: {summary['embedding_dim']}") + print() + + # Datasets breakdown + if summary["datasets"]: + print("📂 DATASETS:") + for dataset in summary["datasets"]: + cell_count = summary["cells_per_dataset"].get(dataset, 0) + print(f" • {dataset}: {cell_count:,} cells") + print() + + # Embedding keys + if summary["embedding_keys"]: + print("🗝️ EMBEDDING KEYS:") + for key in summary["embedding_keys"]: + print(f" • {key}") + print() + + print("=" * 60) \ No newline at end of file diff --git a/src/state/emb/vectordb.py b/src/state/emb/vectordb.py index cf223e03..a57ce2ae 100644 --- a/src/state/emb/vectordb.py +++ b/src/state/emb/vectordb.py @@ -25,7 +25,7 @@ def create_or_update_table( embedding_key: str = "X_state", dataset_name: Optional[str] = None, batch_size: int = 1000, - ): + ) -> None: """Create or update the embeddings table. Args: @@ -210,4 +210,62 @@ def get_table_info(self): "num_rows": len(table), "columns": table.schema.names, "embedding_dim": len(table.to_pandas().iloc[0]['vector']) if len(table) > 0 else 0 - } \ No newline at end of file + } + + def get_database_summary(self) -> dict: + """Get comprehensive summary statistics about the database contents. + + Returns: + Dictionary containing database statistics including: + - num_cells: Total number of cells stored + - num_datasets: Number of unique datasets + - num_embedding_keys: Number of unique embedding keys + - datasets: List of dataset names + - embedding_keys: List of embedding key names + - cells_per_dataset: Dictionary mapping dataset to cell count + """ + if self.table_name not in self.db.table_names(): + return { + "num_cells": 0, + "num_datasets": 0, + "num_embedding_keys": 0, + "datasets": [], + "embedding_keys": [], + "cells_per_dataset": {}, + "table_exists": False + } + + table = self.db.open_table(self.table_name) + + # Get the full dataset to compute statistics + # For large tables, we might want to optimize this with SQL-like queries + df = table.to_pandas() + + if len(df) == 0: + return { + "num_cells": 0, + "num_datasets": 0, + "num_embedding_keys": 0, + "datasets": [], + "embedding_keys": [], + "cells_per_dataset": {}, + "table_exists": True + } + + # Calculate summary statistics + datasets = df['dataset'].unique().tolist() + embedding_keys = df['embedding_key'].unique().tolist() + cells_per_dataset = df['dataset'].value_counts().to_dict() + + summary = { + "num_cells": len(df), + "num_datasets": len(datasets), + "num_embedding_keys": len(embedding_keys), + "datasets": sorted(datasets), + "embedding_keys": sorted(embedding_keys), + "cells_per_dataset": cells_per_dataset, + "table_exists": True, + "embedding_dim": len(df.iloc[0]['vector']) if 'vector' in df.columns else 0 + } + + return summary \ No newline at end of file From 98f961ca83fadc901732fbb890d150459ef4ea01 Mon Sep 17 00:00:00 2001 From: nick-youngblut Date: Wed, 23 Jul 2025 10:30:04 -0700 Subject: [PATCH 6/8] Refactor CLI argument names for consistency - Updated argument names in `add_arguments_infer` and `add_arguments_preprocess_infer` to use kebab-case (e.g., `--embed-key`, `--pert-col`, `--model-dir`, `--celltype-col`, `--batch-size`). - Modified help message for the `--seed` argument in `add_arguments_preprocess_infer` to remove default value mention. - Removed commented-out help handler in `add_arguments_train` for clarity. --- src/state/_cli/_tx/_infer.py | 10 +++++----- src/state/_cli/_tx/_preprocess_infer.py | 2 +- src/state/_cli/_tx/_train.py | 2 -- 3 files changed, 6 insertions(+), 8 deletions(-) diff --git a/src/state/_cli/_tx/_infer.py b/src/state/_cli/_tx/_infer.py index 3babe7bc..5bf730f0 100644 --- a/src/state/_cli/_tx/_infer.py +++ b/src/state/_cli/_tx/_infer.py @@ -9,24 +9,24 @@ def add_arguments_infer(parser: argparse.ArgumentParser): help="Path to model checkpoint (.ckpt). If not provided, will use model_dir/checkpoints/final.ckpt", ) parser.add_argument("--adata", type=str, required=True, help="Path to input AnnData file (.h5ad)") - parser.add_argument("--embed_key", type=str, default=None, help="Key in adata.obsm for input features") + parser.add_argument("--embed-key", type=str, default=None, help="Key in adata.obsm for input features") parser.add_argument( - "--pert_col", type=str, default="drugname_drugconc", help="Column in adata.obs for perturbation labels" + "--pert-col", type=str, default="drugname_drugconc", help="Column in adata.obs for perturbation labels" ) parser.add_argument("--output", type=str, default=None, help="Path to output AnnData file (.h5ad)") parser.add_argument( - "--model_dir", + "--model-dir", type=str, required=True, help="Path to the model_dir containing the config.yaml file and the pert_onehot_map.pt file that was saved during training.", ) parser.add_argument( - "--celltype_col", type=str, default=None, help="Column in adata.obs for cell type labels (optional)" + "--celltype-col", type=str, default=None, help="Column in adata.obs for cell type labels (optional)" ) parser.add_argument( "--celltypes", type=str, default=None, help="Comma-separated list of cell types to include (optional)" ) - parser.add_argument("--batch_size", type=int, default=1000, help="Batch size for inference (default: 1000)") + parser.add_argument("--batch-size", type=int, default=1000, help="Batch size for inference") def run_tx_infer(args): diff --git a/src/state/_cli/_tx/_preprocess_infer.py b/src/state/_cli/_tx/_preprocess_infer.py index ed15b529..39985812 100644 --- a/src/state/_cli/_tx/_preprocess_infer.py +++ b/src/state/_cli/_tx/_preprocess_infer.py @@ -37,7 +37,7 @@ def add_arguments_preprocess_infer(parser: ap.ArgumentParser): "--seed", type=int, default=42, - help="Random seed for reproducibility (default: 42)", + help="Random seed for reproducibility", ) diff --git a/src/state/_cli/_tx/_train.py b/src/state/_cli/_tx/_train.py index 71f7b4ea..c49bec54 100644 --- a/src/state/_cli/_tx/_train.py +++ b/src/state/_cli/_tx/_train.py @@ -6,8 +6,6 @@ def add_arguments_train(parser: ap.ArgumentParser): # Allow remaining args to be passed through to Hydra parser.add_argument("hydra_overrides", nargs="*", help="Hydra configuration overrides (e.g., data.batch_size=32)") - # Add custom help handler - #parser.add_argument("--help", "-h", action="store_true", help="Show configuration help with all parameters") def run_tx_train(cfg: DictConfig): From d91a8af5d7e776785c7a7a829809c25cd343e9a4 Mon Sep 17 00:00:00 2001 From: nick-youngblut Date: Wed, 23 Jul 2025 10:42:12 -0700 Subject: [PATCH 7/8] Enhance CLI argument handling and logging configuration - Updated argument names in the CLI to use kebab-case for consistency (e.g., `--output-dir`, `--num-hvgs`, `--control-condition`, `--pert-col`). - Added `--log-level` argument to allow dynamic logging level configuration across commands. - Modified logging setup in various modules to utilize the new `log_level` argument for consistent logging behavior. - Updated `README.md` with corresponding changes to command usage examples. --- README.md | 20 ++++++++++---------- src/state/__main__.py | 11 ++++++----- src/state/_cli/_emb/_fit.py | 1 + src/state/_cli/_emb/_query.py | 2 +- src/state/_cli/_emb/_transform.py | 2 +- src/state/_cli/_emb/_vectordb.py | 2 +- src/state/_cli/_tx/__init__.py | 4 ++-- src/state/_cli/_tx/_infer.py | 2 +- src/state/_cli/_tx/_predict.py | 2 +- src/state/_cli/_tx/_preprocess_infer.py | 8 +++++--- src/state/_cli/_tx/_preprocess_train.py | 5 +++-- src/state/_cli/_tx/_train.py | 3 ++- 12 files changed, 34 insertions(+), 28 deletions(-) diff --git a/README.md b/README.md index eafcb2dd..32a8523d 100644 --- a/README.md +++ b/README.md @@ -87,17 +87,17 @@ The cell lines and perturbations specified in the TOML should match the values a you can use the `tx predict` command: ```bash -state tx predict --output_dir $HOME/state/test/ --checkpoint final.ckpt +state tx predict --output-dir $HOME/state/test/ --checkpoint final.ckpt ``` -It will look in the `output_dir` above, for a `checkpoints` folder. +It will look in the `output-dir` above, for a `checkpoints` folder. If you instead want to use a trained checkpoint for inference (e.g. on data not specified) in the TOML file: ```bash -state tx infer --output $HOME/state/test/ --output_dir /path/to/model/ --checkpoint /path/to/model/final.ckpt --adata /path/to/anndata/processed.h5 --pert_col gene --embed_key X_hvg +state tx infer --output $HOME/state/test/ --output-dir /path/to/model/ --checkpoint /path/to/model/final.ckpt --adata /path/to/anndata/processed.h5 --pert-col gene --embed-key X_hvg ``` Here, `/path/to/model/` is the folder downloaded from [HuggingFace](https://huggingface.co/arcinstitute). @@ -108,13 +108,13 @@ State provides two preprocessing commands to prepare data for training and infer #### Training Data Preprocessing -Use `preprocess_train` to normalize, log-transform, and select highly variable genes from your training data: +Use `preprocess-train` to normalize, log-transform, and select highly variable genes from your training data: ```bash -state tx preprocess_train \ +state tx preprocess-train \ --adata /path/to/raw_data.h5ad \ --output /path/to/preprocessed_training_data.h5ad \ - --num_hvgs 2000 + --num-hvgs 2000 ``` This command: @@ -125,14 +125,14 @@ This command: #### Inference Data Preprocessing -Use `preprocess_infer` to create a "control template" for model inference: +Use `preprocess-infer` to create a "control template" for model inference: ```bash -state tx preprocess_infer \ +state tx preprocess-infer \ --adata /path/to/real_data.h5ad \ --output /path/to/control_template.h5ad \ - --control_condition "DMSO" \ - --pert_col "treatment" \ + --control-condition "DMSO" \ + --pert-col "treatment" \ --seed 42 ``` diff --git a/src/state/__main__.py b/src/state/__main__.py index b6f3f9c3..5392ea88 100644 --- a/src/state/__main__.py +++ b/src/state/__main__.py @@ -26,6 +26,7 @@ def get_args() -> tuple[ap.Namespace, list[str]]: Use these commands to train models, compute embeddings, and run inference. Run `state --help` for details on each command.""" parser = ap.ArgumentParser(description=desc, formatter_class=CustomFormatter) + parser.add_argument("--log-level", type=str, default="INFO", choices=["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"], help="Logging level") subparsers = parser.add_subparsers(required=True, dest="command") # emb @@ -129,19 +130,19 @@ def main(): else: # Load Hydra config with overrides for sets training cfg = load_hydra_config("tx", args.hydra_overrides) - run_tx_train(cfg) + run_tx_train(cfg, args) case "predict": # For now, predict uses argparse and not hydra run_tx_predict(args) case "infer": # Run inference using argparse, similar to predict run_tx_infer(args) - case "preprocess_train": + case "preprocess-train": # Run preprocessing using argparse - run_tx_preprocess_train(args.adata, args.output, args.num_hvgs) - case "preprocess_infer": + run_tx_preprocess_train(args.adata, args.output, args.num_hvgs, args.log_level) + case "preprocess-infer": # Run inference preprocessing using argparse - run_tx_preprocess_infer(args.adata, args.output, args.control_condition, args.pert_col, args.seed) + run_tx_preprocess_infer(args.adata, args.output, args.control_condition, args.pert_col, args.seed, args.log_level) if __name__ == "__main__": diff --git a/src/state/_cli/_emb/_fit.py b/src/state/_cli/_emb/_fit.py index fddbb032..4e010458 100644 --- a/src/state/_cli/_emb/_fit.py +++ b/src/state/_cli/_emb/_fit.py @@ -21,6 +21,7 @@ def run_emb_fit(cfg, args): from ...emb.train.trainer import main as trainer_main + logging.basicConfig(level=getattr(logging, args.log_level, logging.INFO)) log = logging.getLogger(__name__) # Load the base configuration diff --git a/src/state/_cli/_emb/_query.py b/src/state/_cli/_emb/_query.py index ad66a3d0..39c26849 100644 --- a/src/state/_cli/_emb/_query.py +++ b/src/state/_cli/_emb/_query.py @@ -24,7 +24,7 @@ def run_emb_query(args: ap.ArgumentParser): """ Query a LanceDB database for similar cells. """ - logging.basicConfig(level=logging.INFO) + logging.basicConfig(level=getattr(logging, args.log_level, logging.INFO)) logger = logging.getLogger(__name__) from ...emb.vectordb import StateVectorDB diff --git a/src/state/_cli/_emb/_transform.py b/src/state/_cli/_emb/_transform.py index bf6aa25f..b04e6014 100644 --- a/src/state/_cli/_emb/_transform.py +++ b/src/state/_cli/_emb/_transform.py @@ -26,7 +26,7 @@ def run_emb_transform(args: ap.ArgumentParser): import torch from omegaconf import OmegaConf - logging.basicConfig(level=logging.INFO) + logging.basicConfig(level=getattr(logging, args.log_level, logging.INFO)) logger = logging.getLogger(__name__) from ...emb.inference import Inference diff --git a/src/state/_cli/_emb/_vectordb.py b/src/state/_cli/_emb/_vectordb.py index 813dd4bf..af159ef2 100644 --- a/src/state/_cli/_emb/_vectordb.py +++ b/src/state/_cli/_emb/_vectordb.py @@ -16,7 +16,7 @@ def run_emb_vectordb(args: ap.ArgumentParser): """ Get summary statistics about a LanceDB vector database. """ - logging.basicConfig(level=logging.INFO) + logging.basicConfig(level=getattr(logging, args.log_level, logging.INFO)) logger = logging.getLogger(__name__) from ...emb.vectordb import StateVectorDB diff --git a/src/state/_cli/_tx/__init__.py b/src/state/_cli/_tx/__init__.py index 7fb68f22..f901958c 100644 --- a/src/state/_cli/_tx/__init__.py +++ b/src/state/_cli/_tx/__init__.py @@ -40,10 +40,10 @@ def add_arguments_tx(parser: ap.ArgumentParser): # Preprocess: train desc = """description: Preprocess a dataset for training.""" - add_arguments_preprocess_train(subparsers.add_parser("preprocess_train", description=desc, formatter_class=CustomFormatter)) + add_arguments_preprocess_train(subparsers.add_parser("preprocess-train", description=desc, formatter_class=CustomFormatter)) # Preprocess: infer desc = """description: Preprocess a dataset for inference.""" - add_arguments_preprocess_infer(subparsers.add_parser("preprocess_infer", description=desc, formatter_class=CustomFormatter)) + add_arguments_preprocess_infer(subparsers.add_parser("preprocess-infer", description=desc, formatter_class=CustomFormatter)) diff --git a/src/state/_cli/_tx/_infer.py b/src/state/_cli/_tx/_infer.py index 5bf730f0..4bb185a1 100644 --- a/src/state/_cli/_tx/_infer.py +++ b/src/state/_cli/_tx/_infer.py @@ -42,7 +42,7 @@ def run_tx_infer(args): from ...tx.models.state_transition import StateTransitionPerturbationModel - logging.basicConfig(level=logging.INFO) + logging.basicConfig(level=getattr(logging, args.log_level, logging.INFO)) logger = logging.getLogger(__name__) def load_config(cfg_path: str) -> dict: diff --git a/src/state/_cli/_tx/_predict.py b/src/state/_cli/_tx/_predict.py index 2115a815..7bdd1129 100644 --- a/src/state/_cli/_tx/_predict.py +++ b/src/state/_cli/_tx/_predict.py @@ -59,7 +59,7 @@ def run_tx_predict(args: ap.ArgumentParser): from cell_load.data_modules import PerturbationDataModule from tqdm import tqdm - logging.basicConfig(level=logging.INFO) + logging.basicConfig(level=getattr(logging, args.log_level, logging.INFO)) logger = logging.getLogger(__name__) torch.multiprocessing.set_sharing_strategy("file_system") diff --git a/src/state/_cli/_tx/_preprocess_infer.py b/src/state/_cli/_tx/_preprocess_infer.py index 39985812..9f0aa44e 100644 --- a/src/state/_cli/_tx/_preprocess_infer.py +++ b/src/state/_cli/_tx/_preprocess_infer.py @@ -22,13 +22,13 @@ def add_arguments_preprocess_infer(parser: ap.ArgumentParser): help="Path to output preprocessed AnnData file (.h5ad)", ) parser.add_argument( - "--control_condition", + "--control-condition", type=str, required=True, help="Control condition identifier (e.g., \"[('DMSO_TF', 0.0, 'uM')]\")", ) parser.add_argument( - "--pert_col", + "--pert-col", type=str, required=True, help="Column name containing perturbation information (e.g., 'drugname_drugconc')", @@ -46,7 +46,8 @@ def run_tx_preprocess_infer( output_path: str, control_condition: str, pert_col: str, - seed: int = 42 + seed: int = 42, + log_level: str = "INFO" ): """ Preprocess inference data by replacing perturbed cells with control expression. @@ -62,6 +63,7 @@ def run_tx_preprocess_infer( pert_col: Column name containing perturbation information seed: Random seed for reproducibility """ + logging.basicConfig(level=getattr(logging, log_level, logging.INFO)) logger.info(f"Loading AnnData from {adata_path}") adata = ad.read_h5ad(adata_path) diff --git a/src/state/_cli/_tx/_preprocess_train.py b/src/state/_cli/_tx/_preprocess_train.py index f2c0b304..987e516f 100644 --- a/src/state/_cli/_tx/_preprocess_train.py +++ b/src/state/_cli/_tx/_preprocess_train.py @@ -22,14 +22,14 @@ def add_arguments_preprocess_train(parser: ap.ArgumentParser): help="Path to output preprocessed AnnData file (.h5ad)", ) parser.add_argument( - "--num_hvgs", + "--num-hvgs", type=int, required=True, help="Number of highly variable genes to select", ) -def run_tx_preprocess_train(adata_path: str, output_path: str, num_hvgs: int): +def run_tx_preprocess_train(adata_path: str, output_path: str, num_hvgs: int, log_level: str): """ Preprocess training data by normalizing, log-transforming, and selecting highly variable genes. @@ -38,6 +38,7 @@ def run_tx_preprocess_train(adata_path: str, output_path: str, num_hvgs: int): output_path: Path to save preprocessed AnnData file num_hvgs: Number of highly variable genes to select """ + logging.basicConfig(level=getattr(logging, log_level, logging.INFO)) logger.info(f"Loading AnnData from {adata_path}") adata = ad.read_h5ad(adata_path) diff --git a/src/state/_cli/_tx/_train.py b/src/state/_cli/_tx/_train.py index c49bec54..0ae05f5f 100644 --- a/src/state/_cli/_tx/_train.py +++ b/src/state/_cli/_tx/_train.py @@ -8,7 +8,7 @@ def add_arguments_train(parser: ap.ArgumentParser): parser.add_argument("hydra_overrides", nargs="*", help="Hydra configuration overrides (e.g., data.batch_size=32)") -def run_tx_train(cfg: DictConfig): +def run_tx_train(cfg: DictConfig, args: ap.ArgumentParser): import json import logging import os @@ -27,6 +27,7 @@ def run_tx_train(cfg: DictConfig): from ...tx.callbacks import BatchSpeedMonitorCallback from ...tx.utils import get_checkpoint_callbacks, get_lightning_module, get_loggers + logging.basicConfig(level=getattr(logging, args.log_level, logging.INFO)) logger = logging.getLogger(__name__) torch.set_float32_matmul_precision("medium") From 5a2d03bfb4ad4d2f5afe6adf0b1313e8225ada57 Mon Sep 17 00:00:00 2001 From: nick-youngblut Date: Wed, 23 Jul 2025 10:44:16 -0700 Subject: [PATCH 8/8] Update README.md for improved command usage formatting - Reformatted command examples in the README to use multi-line syntax for better readability. - Ensured consistency in the presentation of CLI commands for `tx predict` and `tx infer` sections. --- README.md | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index 32a8523d..e6a37327 100644 --- a/README.md +++ b/README.md @@ -87,7 +87,9 @@ The cell lines and perturbations specified in the TOML should match the values a you can use the `tx predict` command: ```bash -state tx predict --output-dir $HOME/state/test/ --checkpoint final.ckpt +state tx predict \ + --output-dir $HOME/state/test/ \ + --checkpoint final.ckpt ``` It will look in the `output-dir` above, for a `checkpoints` folder. @@ -97,7 +99,13 @@ in the TOML file: ```bash -state tx infer --output $HOME/state/test/ --output-dir /path/to/model/ --checkpoint /path/to/model/final.ckpt --adata /path/to/anndata/processed.h5 --pert-col gene --embed-key X_hvg +state tx infer \ + --output $HOME/state/test/ \ + --output-dir /path/to/model/ \ + --checkpoint /path/to/model/final.ckpt \ + --adata /path/to/anndata/processed.h5 \ + --pert-col gene \ + --embed-key X_hvg ``` Here, `/path/to/model/` is the folder downloaded from [HuggingFace](https://huggingface.co/arcinstitute).