12 changes: 12 additions & 0 deletions CHANGELOG.md
@@ -6,6 +6,18 @@ to [Semantic Versioning]. Full commit history is available in the

## Version 1.4

### 1.4.1 (2025-XX-XX)

#### Added

#### Fixed

- Fix non-multi-GPU training to keep the training history in memory rather than writing it to disk by default, {pr}`3543`.

#### Changed

#### Removed

### 1.4.0 (2025-09-14)

#### Added
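For context, a minimal sketch of the behaviour the changelog entry describes, assuming the standard scvi-tools training API (the model, dataset, and metric key used here are illustrative, not part of the diff):

```python
import scvi

adata = scvi.data.synthetic_iid()
scvi.model.SCVI.setup_anndata(adata, batch_key="batch")

model = scvi.model.SCVI(adata)
model.train(max_epochs=1)

# With this fix, a single-GPU/CPU run keeps the training history in memory:
# model.history is a dict of DataFrames, and no history.pkl is written unless
# disk logging is explicitly requested (see the Trainer changes below).
print(model.history["elbo_train"].tail())
```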
28 changes: 14 additions & 14 deletions src/scvi/external/mrvi_jax/_model.py
@@ -126,7 +126,7 @@ def __init__(self, adata: AnnData, **model_kwargs):
).categorical_mapping

self.n_obs_per_sample = jnp.array(
adata.obs._scvi_sample.value_counts().sort_index().values
adata.obs._scvi_sample.value_counts().sort_index().values, dtype=jnp.float32
)
self.backend = "jax"

@@ -442,9 +442,9 @@ def per_sample_inference_fn(pair):
try:
mean_zs_ = mapped_inference_fn(
stacked_rngs=stacked_rngs,
x=jnp.array(inf_inputs["x"]),
sample_index=jnp.array(inf_inputs["sample_index"]),
cf_sample=jnp.array(cf_sample),
x=jnp.array(inf_inputs["x"], dtype=jnp.float32),
sample_index=jnp.array(inf_inputs["sample_index"], dtype=jnp.float32),
cf_sample=jnp.array(cf_sample, dtype=jnp.float32),
use_mean=True,
)
except jax.errors.JaxRuntimeError as e:
@@ -467,9 +467,9 @@ def per_sample_inference_fn(pair):
if reqs.needs_sampled_representations:
sampled_zs_ = mapped_inference_fn(
stacked_rngs=stacked_rngs,
x=jnp.array(inf_inputs["x"]),
sample_index=jnp.array(inf_inputs["sample_index"]),
cf_sample=jnp.array(cf_sample),
x=jnp.array(inf_inputs["x"], dtype=jnp.float32),
sample_index=jnp.array(inf_inputs["sample_index"], dtype=jnp.float32),
cf_sample=jnp.array(cf_sample, dtype=jnp.float32),
use_mean=False,
mc_samples=mc_samples,
) # (n_mc_samples, n_cells, n_samples, n_latent)
@@ -1431,7 +1431,7 @@ def h_inference_fn(extra_eps, batch_index_cf, batch_offset_eps):
stacked_rngs = self._generate_stacked_rngs(cf_sample.shape[0])

rngs_de = self.module.rngs if store_lfc else None
admissible_samples_mat = jnp.array(admissible_samples[indices]) # (n_cells, n_samples)
admissible_samples_mat = jnp.array(admissible_samples[indices], dtype=jnp.float32)
n_samples_per_cell = admissible_samples_mat.sum(axis=1)
admissible_samples_dmat = jax.vmap(jnp.diag)(admissible_samples_mat).astype(
float
@@ -1444,9 +1444,9 @@ def h_inference_fn(extra_eps, batch_index_cf, batch_offset_eps):
try:
res = mapped_inference_fn(
stacked_rngs=stacked_rngs,
x=jnp.array(inf_inputs["x"]),
sample_index=jnp.array(inf_inputs["sample_index"]),
cf_sample=jnp.array(cf_sample),
x=jnp.array(inf_inputs["x"], dtype=jnp.float32),
sample_index=jnp.array(inf_inputs["sample_index"], dtype=jnp.float32),
cf_sample=jnp.array(cf_sample, dtype=jnp.float32),
Amat=Amat,
prefactor=prefactor,
n_samples_per_cell=n_samples_per_cell,
@@ -1608,7 +1608,7 @@ def _construct_design_matrix(
offset_indices = (
Series(np.arange(len(Xmat_names)), index=Xmat_names).loc[cov_names].values
)
offset_indices = jnp.array(offset_indices)
offset_indices = jnp.array(offset_indices, dtype=jnp.float32)
else:
warnings.warn(
"""
@@ -1622,7 +1622,7 @@
else:
offset_indices = None

Xmat = jnp.array(Xmat)
Xmat = jnp.array(Xmat, dtype=jnp.float32)
if store_lfc:
covariates_require_lfc = (
np.isin(Xmat_dim_to_key, store_lfc_metadata_subset)
@@ -1631,7 +1631,7 @@
)
else:
covariates_require_lfc = np.zeros(len(Xmat_names), dtype=bool)
covariates_require_lfc = jnp.array(covariates_require_lfc)
covariates_require_lfc = jnp.array(covariates_require_lfc, dtype=jnp.float32)

return Xmat, Xmat_names, covariates_require_lfc, offset_indices

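The dtype changes above pin every `jnp.array` conversion to `float32`. As a quick illustration of why the explicit dtype matters (not part of the PR, just a note on JAX's default dtype inference):

```python
import jax.numpy as jnp
import numpy as np

counts = np.array([7, 3, 5])          # e.g. value_counts() output
mask = np.array([True, False, True])  # e.g. an admissible-samples mask

print(jnp.array(counts).dtype)                      # int32 (inferred)
print(jnp.array(mask).dtype)                        # bool  (inferred)
print(jnp.array(counts, dtype=jnp.float32).dtype)   # float32 (forced)
print(jnp.array(mask, dtype=jnp.float32).dtype)     # float32 (forced)
```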
26 changes: 17 additions & 9 deletions src/scvi/train/_logger.py
@@ -52,16 +52,23 @@ def __init__(
name: str = "lightning_logs",
version: int | str | None = None,
save_dir: str | None = None,
save_log_on_disk: bool | None = False,
):
super().__init__()
self._name = name
self._experiment = None
self._version = version
self._save_log_on_disk = False
# In a multi-GPU run, or when a log directory is explicitly provided, the model history is saved to disk.
self._save_dir = save_dir or os.getcwd()
# run directory like: <save_dir>/<name>/version_<N>
self._run_dir = os.path.join(self._save_dir, self._name, f"version_{self.version}")
os.makedirs(self._run_dir, exist_ok=True)
self.history_path = os.path.join(self._run_dir, "history.pkl")
if save_dir or save_log_on_disk:
# run directory like: <save_dir>/<name>/version_<N>
self._run_dir = os.path.join(self._save_dir, self._name, f"version_{self.version}")
os.makedirs(self._run_dir, exist_ok=True)
self.history_path = os.path.join(
self._run_dir, "history.pkl"
) # TODO: should we use pkl
self._save_log_on_disk = True

@property
@rank_zero_experiment
@@ -89,11 +96,12 @@ def history(self) -> dict[str, pd.DataFrame]:
@rank_zero_only
def finalize(self, status: str) -> None:
# Persist history from rank-0 AFTER training ends
try:
with open(self.history_path, "wb") as f:
pickle.dump(self.history, f)
except (OSError, pickle.PickleError) as e:
print(f"[SimpleLogger] Failed to save history: {e}")
if self._save_log_on_disk:
try:
with open(self.history_path, "wb") as f:
pickle.dump(self.history, f)
except (OSError, pickle.PickleError) as e:
print(f"[SimpleLogger] Failed to save history: {e}")

@property
def version(self) -> int:
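A sketch of how the new opt-in persistence could be used directly, assuming `SimpleLogger` is importable from `scvi.train` as in current scvi-tools (paths follow the diff above; treat this as an illustration rather than a tested recipe):

```python
import pickle

from scvi.train import SimpleLogger

# Disk persistence is opt-in: pass a save_dir and/or set save_log_on_disk.
logger = SimpleLogger(save_dir="./scvi_logs", save_log_on_disk=True)

# ... pass `logger` to the Trainer and run training ...

# After finalize(), rank 0 has written <save_dir>/<name>/version_<N>/history.pkl.
with open(logger.history_path, "rb") as f:
    history = pickle.load(f)  # dict[str, pd.DataFrame], same content as logger.history
```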
7 changes: 6 additions & 1 deletion src/scvi/train/_trainer.py
@@ -86,6 +86,8 @@ class Trainer(pl.Trainer):
If `True`, defaults to the default pytorch lightning logger.
log_every_n_steps
How often to log within steps. This does not affect epoch-level logging.
log_save_dir
Optional directory in which to save the logger history as a pickle file.
**kwargs
Other keyword args for :class:`~pytorch_lightning.trainer.Trainer`
"""
@@ -116,6 +118,7 @@ def __init__(
logger: Logger | None | bool = None,
log_every_n_steps: int = 10,
learning_rate_monitor: bool = False,
log_save_dir: str | None = None,
**kwargs,
):
if default_root_dir is None:
@@ -124,13 +127,15 @@
check_val_every_n_epoch = check_val_every_n_epoch or sys.maxsize
callbacks = kwargs.pop("callbacks", [])

save_log_on_disk = bool(log_save_dir)
if use_distributed_sampler(kwargs.get("strategy", None)):
warnings.warn(
"early_stopping was automaticaly disabled due to the use of DDP",
UserWarning,
stacklevel=settings.warnings_stacklevel,
)
early_stopping = False
save_log_on_disk = True

if early_stopping:
early_stopping_callback = LoudEarlyStopping(
@@ -161,7 +166,7 @@ def __init__(
callbacks.append(ProgressBar(refresh_rate=progress_bar_refresh_rate))

if logger is None:
logger = SimpleLogger()
logger = SimpleLogger(save_dir=log_save_dir, save_log_on_disk=save_log_on_disk)

super().__init__(
accelerator=accelerator,
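From the caller's side, the new keyword is a plain `Trainer` argument; a minimal sketch of the two paths the diff distinguishes (illustrative only):

```python
from scvi.train import Trainer

# Default: no log_save_dir, so the default SimpleLogger keeps the history in memory only.
trainer = Trainer(max_epochs=1)

# Opt-in: with log_save_dir, SimpleLogger is built with save_log_on_disk=True and
# pickles the history under <log_save_dir>/lightning_logs/version_<N>/.
trainer = Trainer(max_epochs=1, log_save_dir="./scvi_logs")
```

Per the diff, a DDP strategy also forces `save_log_on_disk = True`, so a multi-GPU run using the default logger always writes a rank-0 `history.pkl` even when `log_save_dir` is not given.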
66 changes: 33 additions & 33 deletions tests/dataloaders/test_dataloaders.py
@@ -132,39 +132,39 @@ def test_anndataloader_distributed_sampler(num_processes: int, save_path: str):
)


@pytest.mark.multigpu
@pytest.mark.parametrize("num_processes", [1, 2])
def test_scanvi_with_distributed_sampler(num_processes: int, save_path: str):
adata = scvi.data.synthetic_iid()
SCANVI.setup_anndata(
adata,
"labels",
"label_0",
batch_key="batch",
)
file_path = save_path + "/dist_file"
if os.path.exists(file_path): # Check if the file exists
os.remove(file_path)
datasplitter_kwargs = {}
# Multi-GPU settings
datasplitter_kwargs["distributed_sampler"] = True
datasplitter_kwargs["drop_last"] = False
if num_processes == 1:
datasplitter_kwargs["distributed_sampler"] = False
model = SCANVI(adata, n_latent=10)

# initializes the distributed backend that takes care of synchronizing processes
torch.distributed.init_process_group(
"nccl", # backend that works on all systems
init_method=f"file://{save_path}/dist_file",
rank=0,
world_size=num_processes,
store=None,
)

model.train(1, datasplitter_kwargs=datasplitter_kwargs)

torch.distributed.destroy_process_group()
# @pytest.mark.multigpu
# @pytest.mark.parametrize("num_processes", [1, 2])
# def test_scanvi_with_distributed_sampler(num_processes: int, save_path: str):
# adata = scvi.data.synthetic_iid()
# SCANVI.setup_anndata(
# adata,
# "labels",
# "label_0",
# batch_key="batch",
# )
# file_path = save_path + "/dist_file"
# if os.path.exists(file_path): # Check if the file exists
# os.remove(file_path)
# datasplitter_kwargs = {}
# # Multi-GPU settings
# datasplitter_kwargs["distributed_sampler"] = True
# datasplitter_kwargs["drop_last"] = False
# if num_processes == 1:
# datasplitter_kwargs["distributed_sampler"] = False
# model = SCANVI(adata, n_latent=10)
#
# # initializes the distributed backend that takes care of synchronizing processes
# torch.distributed.init_process_group(
# "nccl", # backend that works on all systems
# init_method=f"file://{save_path}/dist_file",
# rank=0,
# world_size=num_processes,
# store=None,
# )
#
# model.train(1, datasplitter_kwargs=datasplitter_kwargs)
#
# torch.distributed.destroy_process_group()


def test_anncollection(save_path: str):
8 changes: 4 additions & 4 deletions tests/external/mrvi_jax/test_jaxmrvi_model.py
@@ -45,8 +45,8 @@ def model(adata: AnnData):


def test_jaxmrvi(model: MRVI, adata: AnnData, save_path: str):
model.get_local_sample_distances()
model.get_local_sample_distances(normalize_distances=True)
model.get_local_sample_distances(batch_size=16)
model.get_local_sample_distances(normalize_distances=True, batch_size=16)
model.get_latent_representation(give_z=False)
model.get_latent_representation(give_z=True)

@@ -197,7 +197,7 @@ def test_jaxmrvi_shrink_u(adata: AnnData, save_path: str):
)
model = MRVI(adata, n_latent=10, n_latent_u=5, backend="jax")
model.train(max_steps=2, train_size=0.5)
model.get_local_sample_distances()
model.get_local_sample_distances(batch_size=16)

assert model.get_latent_representation().shape == (
adata.shape[0],
@@ -233,7 +233,7 @@ def test_jaxmrvi_stratifications(adata_stratifications: AnnData, save_path: str)
model = MRVI(adata_stratifications, n_latent=10, backend="jax")
model.train(max_steps=2, train_size=0.5)

dists = model.get_local_sample_distances(groupby=["labels", "label_2"])
dists = model.get_local_sample_distances(groupby=["labels", "label_2"], batch_size=16)
cell_dists = dists["cell"]
assert cell_dists.shape == (adata_stratifications.shape[0], 15, 15)
ct_dists = dists["labels"]