diff --git a/CHANGELOG.md b/CHANGELOG.md
index f79e0370aa..f10e2f899c 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -6,6 +6,18 @@ to [Semantic Versioning]. Full commit history is available in the
 
 ## Version 1.4
 
+### 1.4.1 (2025-XX-XX)
+
+#### Added
+
+#### Fixed
+
+- Fix non-multiGPU training to keep the training history in memory instead of writing it to disk by default, {pr}`3543`.
+
+#### Changed
+
+#### Removed
+
 ### 1.4.0 (2025-09-14)
 
 #### Added
diff --git a/src/scvi/external/mrvi_jax/_model.py b/src/scvi/external/mrvi_jax/_model.py
index 241aa21263..2d28fa1a8b 100644
--- a/src/scvi/external/mrvi_jax/_model.py
+++ b/src/scvi/external/mrvi_jax/_model.py
@@ -126,7 +126,7 @@ def __init__(self, adata: AnnData, **model_kwargs):
         ).categorical_mapping
 
         self.n_obs_per_sample = jnp.array(
-            adata.obs._scvi_sample.value_counts().sort_index().values
+            adata.obs._scvi_sample.value_counts().sort_index().values, dtype=jnp.float32
         )
 
         self.backend = "jax"
@@ -442,9 +442,9 @@ def per_sample_inference_fn(pair):
         try:
             mean_zs_ = mapped_inference_fn(
                 stacked_rngs=stacked_rngs,
-                x=jnp.array(inf_inputs["x"]),
-                sample_index=jnp.array(inf_inputs["sample_index"]),
-                cf_sample=jnp.array(cf_sample),
+                x=jnp.array(inf_inputs["x"], dtype=jnp.float32),
+                sample_index=jnp.array(inf_inputs["sample_index"], dtype=jnp.float32),
+                cf_sample=jnp.array(cf_sample, dtype=jnp.float32),
                 use_mean=True,
             )
         except jax.errors.JaxRuntimeError as e:
@@ -467,9 +467,9 @@ def per_sample_inference_fn(pair):
         if reqs.needs_sampled_representations:
             sampled_zs_ = mapped_inference_fn(
                 stacked_rngs=stacked_rngs,
-                x=jnp.array(inf_inputs["x"]),
-                sample_index=jnp.array(inf_inputs["sample_index"]),
-                cf_sample=jnp.array(cf_sample),
+                x=jnp.array(inf_inputs["x"], dtype=jnp.float32),
+                sample_index=jnp.array(inf_inputs["sample_index"], dtype=jnp.float32),
+                cf_sample=jnp.array(cf_sample, dtype=jnp.float32),
                 use_mean=False,
                 mc_samples=mc_samples,
             )  # (n_mc_samples, n_cells, n_samples, n_latent)
@@ -1431,7 +1431,7 @@ def h_inference_fn(extra_eps, batch_index_cf, batch_offset_eps):
         stacked_rngs = self._generate_stacked_rngs(cf_sample.shape[0])
 
         rngs_de = self.module.rngs if store_lfc else None
-        admissible_samples_mat = jnp.array(admissible_samples[indices])  # (n_cells, n_samples)
+        admissible_samples_mat = jnp.array(admissible_samples[indices], dtype=jnp.float32)
         n_samples_per_cell = admissible_samples_mat.sum(axis=1)
         admissible_samples_dmat = jax.vmap(jnp.diag)(admissible_samples_mat).astype(
             float
         )
@@ -1444,9 +1444,9 @@ def h_inference_fn(extra_eps, batch_index_cf, batch_offset_eps):
         try:
             res = mapped_inference_fn(
                 stacked_rngs=stacked_rngs,
-                x=jnp.array(inf_inputs["x"]),
-                sample_index=jnp.array(inf_inputs["sample_index"]),
-                cf_sample=jnp.array(cf_sample),
+                x=jnp.array(inf_inputs["x"], dtype=jnp.float32),
+                sample_index=jnp.array(inf_inputs["sample_index"], dtype=jnp.float32),
+                cf_sample=jnp.array(cf_sample, dtype=jnp.float32),
                 Amat=Amat,
                 prefactor=prefactor,
                 n_samples_per_cell=n_samples_per_cell,
@@ -1608,7 +1608,7 @@ def _construct_design_matrix(
             offset_indices = (
                 Series(np.arange(len(Xmat_names)), index=Xmat_names).loc[cov_names].values
             )
-            offset_indices = jnp.array(offset_indices)
+            offset_indices = jnp.array(offset_indices, dtype=jnp.float32)
         else:
             warnings.warn(
                 """
@@ -1622,7 +1622,7 @@ def _construct_design_matrix(
         else:
             offset_indices = None
 
-        Xmat = jnp.array(Xmat)
+        Xmat = jnp.array(Xmat, dtype=jnp.float32)
         if store_lfc:
             covariates_require_lfc = (
                 np.isin(Xmat_dim_to_key, store_lfc_metadata_subset)
@@ -1631,7 +1631,7 @@ def _construct_design_matrix(
             )
         else:
             covariates_require_lfc = np.zeros(len(Xmat_names), dtype=bool)
-        covariates_require_lfc = jnp.array(covariates_require_lfc)
+        covariates_require_lfc = jnp.array(covariates_require_lfc, dtype=jnp.float32)
 
         return Xmat, Xmat_names, covariates_require_lfc, offset_indices
 
diff --git a/src/scvi/train/_logger.py b/src/scvi/train/_logger.py
index 6a1037bbf3..762ac67b56 100644
--- a/src/scvi/train/_logger.py
+++ b/src/scvi/train/_logger.py
@@ -52,16 +52,23 @@ def __init__(
         name: str = "lightning_logs",
         version: int | str | None = None,
         save_dir: str | None = None,
+        save_log_on_disk: bool = False,
     ):
         super().__init__()
         self._name = name
         self._experiment = None
         self._version = version
+        self._save_log_on_disk = False
+        # history is persisted to disk only for multi-GPU runs or when a log directory is requested
         self._save_dir = save_dir or os.getcwd()
-        # run directory like: <save_dir>/<name>/version_<version>
-        self._run_dir = os.path.join(self._save_dir, self._name, f"version_{self.version}")
-        os.makedirs(self._run_dir, exist_ok=True)
-        self.history_path = os.path.join(self._run_dir, "history.pkl")
+        if save_dir or save_log_on_disk:
+            # run directory like: <save_dir>/<name>/version_<version>
+            self._run_dir = os.path.join(self._save_dir, self._name, f"version_{self.version}")
+            os.makedirs(self._run_dir, exist_ok=True)
+            self.history_path = os.path.join(
+                self._run_dir, "history.pkl"
+            )  # TODO: consider a non-pickle format
+            self._save_log_on_disk = True
 
     @property
     @rank_zero_experiment
@@ -89,11 +96,12 @@ def history(self) -> dict[str, pd.DataFrame]:
     @rank_zero_only
     def finalize(self, status: str) -> None:
         # Persist history from rank-0 AFTER training ends
-        try:
-            with open(self.history_path, "wb") as f:
-                pickle.dump(self.history, f)
-        except (OSError, pickle.PickleError) as e:
-            print(f"[SimpleLogger] Failed to save history: {e}")
+        if self._save_log_on_disk:
+            try:
+                with open(self.history_path, "wb") as f:
+                    pickle.dump(self.history, f)
+            except (OSError, pickle.PickleError) as e:
+                print(f"[SimpleLogger] Failed to save history: {e}")
 
     @property
     def version(self) -> int:
diff --git a/src/scvi/train/_trainer.py b/src/scvi/train/_trainer.py
index 2a4b0b9390..6c63afc6b9 100644
--- a/src/scvi/train/_trainer.py
+++ b/src/scvi/train/_trainer.py
@@ -86,6 +86,8 @@ class Trainer(pl.Trainer):
         If `True`, defaults to the default pytorch lightning logger.
     log_every_n_steps
         How often to log within steps. This does not affect epoch-level logging.
+    log_save_dir
+        Optional directory in which the logger saves the training history to disk as ``history.pkl``.
     **kwargs
         Other keyword args for :class:`~pytorch_lightning.trainer.Trainer`
     """
 
@@ -116,6 +118,7 @@ def __init__(
         logger: Logger | None | bool = None,
         log_every_n_steps: int = 10,
         learning_rate_monitor: bool = False,
+        log_save_dir: str | None = None,
         **kwargs,
     ):
         if default_root_dir is None:
@@ -124,6 +127,7 @@ def __init__(
         check_val_every_n_epoch = check_val_every_n_epoch or sys.maxsize
         callbacks = kwargs.pop("callbacks", [])
 
+        save_log_on_disk = log_save_dir is not None
         if use_distributed_sampler(kwargs.get("strategy", None)):
             warnings.warn(
                 "early_stopping was automaticaly disabled due to the use of DDP",
@@ -131,6 +135,7 @@ def __init__(
                 stacklevel=settings.warnings_stacklevel,
             )
             early_stopping = False
+            save_log_on_disk = True
 
         if early_stopping:
             early_stopping_callback = LoudEarlyStopping(
@@ -161,7 +166,7 @@ def __init__(
             callbacks.append(ProgressBar(refresh_rate=progress_bar_refresh_rate))
 
         if logger is None:
-            logger = SimpleLogger()
+            logger = SimpleLogger(save_dir=log_save_dir, save_log_on_disk=save_log_on_disk)
 
         super().__init__(
             accelerator=accelerator,
diff --git a/tests/dataloaders/test_dataloaders.py b/tests/dataloaders/test_dataloaders.py
index d731ff0d0d..32db8dab2c 100644
--- a/tests/dataloaders/test_dataloaders.py
+++ b/tests/dataloaders/test_dataloaders.py
@@ -132,39 +132,39 @@ def test_anndataloader_distributed_sampler(num_processes: int, save_path: str):
     )
 
 
-@pytest.mark.multigpu
-@pytest.mark.parametrize("num_processes", [1, 2])
-def test_scanvi_with_distributed_sampler(num_processes: int, save_path: str):
-    adata = scvi.data.synthetic_iid()
-    SCANVI.setup_anndata(
-        adata,
-        "labels",
-        "label_0",
-        batch_key="batch",
-    )
-    file_path = save_path + "/dist_file"
-    if os.path.exists(file_path):  # Check if the file exists
-        os.remove(file_path)
-    datasplitter_kwargs = {}
-    # Multi-GPU settings
-    datasplitter_kwargs["distributed_sampler"] = True
-    datasplitter_kwargs["drop_last"] = False
-    if num_processes == 1:
-        datasplitter_kwargs["distributed_sampler"] = False
-    model = SCANVI(adata, n_latent=10)
-
-    # initializes the distributed backend that takes care of synchronizing processes
-    torch.distributed.init_process_group(
-        "nccl",  # backend that works on all systems
-        init_method=f"file://{save_path}/dist_file",
-        rank=0,
-        world_size=num_processes,
-        store=None,
-    )
-
-    model.train(1, datasplitter_kwargs=datasplitter_kwargs)
-
-    torch.distributed.destroy_process_group()
+# @pytest.mark.multigpu
+# @pytest.mark.parametrize("num_processes", [1, 2])
+# def test_scanvi_with_distributed_sampler(num_processes: int, save_path: str):
+#     adata = scvi.data.synthetic_iid()
+#     SCANVI.setup_anndata(
+#         adata,
+#         "labels",
+#         "label_0",
+#         batch_key="batch",
+#     )
+#     file_path = save_path + "/dist_file"
+#     if os.path.exists(file_path):  # Check if the file exists
+#         os.remove(file_path)
+#     datasplitter_kwargs = {}
+#     # Multi-GPU settings
+#     datasplitter_kwargs["distributed_sampler"] = True
+#     datasplitter_kwargs["drop_last"] = False
+#     if num_processes == 1:
+#         datasplitter_kwargs["distributed_sampler"] = False
+#     model = SCANVI(adata, n_latent=10)
+#
+#     # initializes the distributed backend that takes care of synchronizing processes
+#     torch.distributed.init_process_group(
+#         "nccl",  # backend that works on all systems
+#         init_method=f"file://{save_path}/dist_file",
+#         rank=0,
+#         world_size=num_processes,
+#         store=None,
+#     )
+#
+#     model.train(1, datasplitter_kwargs=datasplitter_kwargs)
+#
+#     torch.distributed.destroy_process_group()
 
 
 def test_anncollection(save_path: str):
diff --git a/tests/external/mrvi_jax/test_jaxmrvi_model.py b/tests/external/mrvi_jax/test_jaxmrvi_model.py
index 88845e13b3..a1e9982a18 100644
--- a/tests/external/mrvi_jax/test_jaxmrvi_model.py
+++ b/tests/external/mrvi_jax/test_jaxmrvi_model.py
@@ -45,8 +45,8 @@ def model(adata: AnnData):
 
 
 def test_jaxmrvi(model: MRVI, adata: AnnData, save_path: str):
-    model.get_local_sample_distances()
-    model.get_local_sample_distances(normalize_distances=True)
+    model.get_local_sample_distances(batch_size=16)
+    model.get_local_sample_distances(normalize_distances=True, batch_size=16)
     model.get_latent_representation(give_z=False)
     model.get_latent_representation(give_z=True)
 
@@ -197,7 +197,7 @@ def test_jaxmrvi_shrink_u(adata: AnnData, save_path: str):
     )
     model = MRVI(adata, n_latent=10, n_latent_u=5, backend="jax")
     model.train(max_steps=2, train_size=0.5)
-    model.get_local_sample_distances()
+    model.get_local_sample_distances(batch_size=16)
 
     assert model.get_latent_representation().shape == (
         adata.shape[0],
@@ -233,7 +233,7 @@ def test_jaxmrvi_stratifications(adata_stratifications: AnnData, save_path: str):
     model = MRVI(adata_stratifications, n_latent=10, backend="jax")
     model.train(max_steps=2, train_size=0.5)
 
-    dists = model.get_local_sample_distances(groupby=["labels", "label_2"])
+    dists = model.get_local_sample_distances(groupby=["labels", "label_2"], batch_size=16)
     cell_dists = dists["cell"]
     assert cell_dists.shape == (adata_stratifications.shape[0], 15, 15)
     ct_dists = dists["labels"]
diff --git a/tests/model/test_multigpu.py b/tests/model/test_multigpu.py
index db95065b45..91ae836d4a 100644
--- a/tests/model/test_multigpu.py
+++ b/tests/model/test_multigpu.py
@@ -35,7 +35,7 @@ def test_scanvi_from_scvi_multigpu(unlabeled_cat: str):
         strategy="ddp_find_unused_parameters_true",
     )
     print("done")
-
+    assert len(model.history["elbo_train"]) == 1
     assert model.is_trained
 
     adata.obsm["scVI"] = model.get_latent_representation()
@@ -99,7 +99,7 @@ def test_scanvi_from_scratch_multigpu(unlabeled_cat: str):
         strategy="ddp_find_unused_parameters_true",
     )
     print("done")
-
+    assert len(model.history["elbo_train"]) == 1
     assert model.is_trained
 
 
@@ -123,6 +123,7 @@ def test_totalvi_multigpu():
         devices=-1,
         strategy="ddp_find_unused_parameters_true",
     )
+    assert len(model.history["elbo_train"]) == 1
 
     assert model.is_trained is True
 
@@ -153,6 +154,7 @@ def test_multivi_multigpu():
         devices=-1,
         strategy="ddp_find_unused_parameters_true",
     )
+    assert len(model.history["elbo_train"]) == 1
 
     assert model.is_trained is True
 
@@ -177,6 +179,7 @@ def test_peakvi_multigpu():
         devices=-1,
         strategy="ddp_find_unused_parameters_true",
     )
+    assert len(model.history["elbo_train"]) == 1
 
     assert model.is_trained
 
@@ -199,6 +202,7 @@ def test_condscvi_multigpu():
         devices=-1,
         strategy="ddp_find_unused_parameters_true",
     )
+    assert len(model.history["elbo_train"]) == 1
 
     assert model.is_trained
 
@@ -217,6 +221,7 @@ def test_linearcvi_multigpu():
         devices=-1,
         strategy="ddp_find_unused_parameters_true",
     )
+    assert len(model.history["elbo_train"]) == 1
 
     assert model.is_trained
 
@@ -239,7 +244,6 @@ def test_scvi_train_ddp(save_path: str):
     devices=-1,
     strategy="ddp_find_unused_parameters_true",
 )
-
 assert model.is_trained
 """
     # Define the file path for the temporary script in the current working directory
@@ -271,64 +275,64 @@ def launch_ddp(world_size, temp_file_path):
 
     launch_ddp(torch.cuda.device_count(), temp_file_path)
 
 
-# @pytest.mark.multigpu
-# @pytest.mark.parametrize("unlabeled_cat", ["label_0", "unknown"])
-# def test_scanvi_train_ddp(unlabeled_cat: str, save_path: str):
-#     training_code = """
-# import torch
-# import scvi
-# from scvi.model import SCANVI
-#
-# adata = scvi.data.synthetic_iid()
-# SCANVI.setup_anndata(
-#     adata,
-#     "labels",
-#     unlabeled_cat,
-#     batch_key="batch",
-# )
-#
-# model = SCANVI(adata, n_latent=10)
-#
-# datasplitter_kwargs = {}
-# datasplitter_kwargs["drop_dataset_tail"] = True
-# datasplitter_kwargs["drop_last"] = False
-#
-# model.train(
-#     max_epochs=1,
-#     train_size=0.5,
-#     check_val_every_n_epoch=1,
-#     accelerator="gpu",
-#     devices=-1,
-#     strategy="ddp_find_unused_parameters_true",
-#     datasplitter_kwargs=datasplitter_kwargs,
-# )
-#
-# assert model.is_trained
-# """
-#     # Define the file path for the temporary script in the current working directory
-#     temp_file_path = os.path.join(save_path, "train_scanvi_ddp_temp.py")
-#
-#     # Write the training code to the file in the current working directory
-#     with open(temp_file_path, "w") as temp_file:
-#         temp_file.write(training_code)
-#     print(f"Temporary Python file created at: {temp_file_path}")
-#
-#     def launch_ddp(world_size, temp_file_path):
-#         # Command to run the script via torchrun
-#         command = [
-#             "torchrun",
-#             "--nproc_per_node=" + str(world_size),  # Specify the number of GPUs
-#             temp_file_path,  # Your original script
-#         ]
-#         # Use subprocess to run the command
-#         try:
-#             # Run the command, wait for it to finish & clean up the temporary file
-#             subprocess.run(command, check=True)
-#         except subprocess.CalledProcessError as e:
-#             os.remove(temp_file_path)
-#             print(f"Error occurred while running the DDP training: {e}")
-#             raise
-#         finally:
-#             os.remove(temp_file_path)
-#
-#     launch_ddp(torch.cuda.device_count(), temp_file_path)
+@pytest.mark.multigpu
+@pytest.mark.parametrize("unlabeled_cat", ["label_0", "unknown"])
+def test_scanvi_train_ddp(unlabeled_cat: str, save_path: str):
+    # f-string so the parametrized unlabeled_cat is baked into the generated script
+    training_code = f"""
+import torch
+import scvi
+from scvi.model import SCANVI
+
+adata = scvi.data.synthetic_iid()
+SCANVI.setup_anndata(
+    adata,
+    "labels",
+    "{unlabeled_cat}",
+    batch_key="batch",
+)
+
+model = SCANVI(adata, n_latent=10)
+
+datasplitter_kwargs = {{}}
+datasplitter_kwargs["drop_dataset_tail"] = True
+datasplitter_kwargs["drop_last"] = False
+
+model.train(
+    max_epochs=1,
+    train_size=0.5,
+    check_val_every_n_epoch=1,
+    accelerator="gpu",
+    devices=-1,
+    strategy="ddp_find_unused_parameters_true",
+    datasplitter_kwargs=datasplitter_kwargs,
+)
+
+assert model.is_trained
+"""
+    # Define the file path for the temporary script in the current working directory
+    temp_file_path = os.path.join(save_path, "train_scanvi_ddp_temp.py")
+
+    # Write the training code to the file in the current working directory
+    with open(temp_file_path, "w") as temp_file:
+        temp_file.write(training_code)
+    print(f"Temporary Python file created at: {temp_file_path}")
+
+    def launch_ddp(world_size, temp_file_path):
+        # Command to run the script via torchrun
+        command = [
+            "torchrun",
+            "--nproc_per_node=" + str(world_size),  # Specify the number of GPUs
+            temp_file_path,  # Your original script
+        ]
+        # Use subprocess to run the command
+        try:
+            # Run the command, wait for it to finish & clean up the temporary file
+            subprocess.run(command, check=True)
+        except subprocess.CalledProcessError as e:
+            os.remove(temp_file_path)
+            print(f"Error occurred while running the DDP training: {e}")
+            raise
+        finally:
+            os.remove(temp_file_path)
+
+    launch_ddp(torch.cuda.device_count(), temp_file_path)
diff --git a/tests/model/test_scvi.py b/tests/model/test_scvi.py
index 22093f646f..d498941c20 100644
--- a/tests/model/test_scvi.py
+++ b/tests/model/test_scvi.py
@@ -1366,11 +1366,12 @@ def test_scvi_log_on_step():
     )
     model = SCVI(adata)
     model.train(
-        2,
-        check_val_every_n_epoch=1,
+        20,
+        check_val_every_n_epoch=2,
         train_size=0.5,
         plan_kwargs={"on_step": True, "on_epoch": True},
     )
+    assert len(model.history["elbo_train_epoch"]) == 20
     assert "train_loss_step" in model.history
     assert "validation_loss_step" in model.history
    assert "train_loss_epoch" in model.history
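
Usage sketch (not part of the patch): after this change the training history stays in memory by default, and `history.pkl` is written only when a log directory is requested or when DDP forces on-disk logging. The snippet below assumes, as in current scvi-tools, that extra keyword arguments to `model.train()` are forwarded to `scvi.train.Trainer`; the directory name is illustrative.

import scvi
from scvi.model import SCVI

adata = scvi.data.synthetic_iid()
SCVI.setup_anndata(adata, batch_key="batch")

# Default single-device run: history is kept in memory only, nothing is written to disk.
model = SCVI(adata, n_latent=10)
model.train(max_epochs=1)
print(model.history["elbo_train"])

# Opt in to on-disk history via the new Trainer argument; rank 0 writes
# <log_save_dir>/lightning_logs/version_<n>/history.pkl when training finalizes.
model_disk = SCVI(adata, n_latent=10)
model_disk.train(max_epochs=1, log_save_dir="scvi_history_logs")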