
[core] Unify validation_step_outputs to always return list-of-lists #15470

Open
XuesongYang wants to merge 4 commits into NVIDIA-NeMo:main from XuesongYang:xueyang/pr-unify-multi-dataloader-modelPT

XuesongYang (Collaborator) commented Mar 6, 2026

What does this PR do ?

Unify ModelPT.validation_step_outputs (and test_step_outputs) to always return a list of lists, so a single dataloader is simply the N=1 case and subclasses no longer need to branch on the output shape. Normalize both _validation_dl and _test_dl to Optional[List[DataLoader]] via their respective resolvers.
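
The normalization described above can be sketched roughly as follows. This is a minimal illustration, not the actual resolve_validation_dataloaders / resolve_test_dataloaders code: `normalize_dataloaders` is a hypothetical helper name, and plain `object()` instances stand in for real DataLoader objects.

```python
from typing import Any, List, Optional


def normalize_dataloaders(dl: Any) -> Optional[List[Any]]:
    """Hypothetical helper: return None, or a list with one entry per dataloader."""
    if dl is None:
        return None
    if isinstance(dl, (list, tuple)):
        # Already a collection of dataloaders; keep one entry per dataloader.
        return list(dl)
    # A bare dataloader becomes the N=1 case.
    return [dl]


# object() is only a stand-in for a real DataLoader in this sketch.
single = normalize_dataloaders(object())
multi = normalize_dataloaders([object(), object()])
unset = normalize_dataloaders(None)
```

With the stored value always being either None or a list, downstream code can take len() of it to get the number of dataloaders without an isinstance check.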

Collection: Core, ASR, TTS, Audio

Changelog

  • modelPT.py: validation_step_outputs / test_step_outputs properties always return [[] for _ in range(num_dl)]; on_validation_epoch_end / on_test_epoch_end use len() == 1 instead of isinstance(..., dict) for single-vs-multi dispatch; empty-output guard updated to all(len(o) == 0 for o in ...) since [[]] is truthy; empty dataloader buckets skipped in multi-DL loop
  • model_utils.py: resolve_validation_dataloaders and resolve_test_dataloaders wrap bare DataLoader into [DataLoader] at both single-value paths, normalizing _validation_dl and _test_dl to Optional[List[DataLoader]]
  • modelPT.py (setup_multiple_validation_data): type annotation updated; isinstance guard simplified to truthiness check after normalization
  • 15 model files (ASR, TTS G2P, Audio): remove if/else branching in validation_step / test_step; always use self.validation_step_outputs[dataloader_idx].append(...)
  • transformer_bpe_models.py: remove isinstance(outputs[0], dict) normalization loop in multi_validation_epoch_end — base class now iterates dataloaders and calls it once per DL
  • audio_to_audio.py: simplify _get_num_dataloaders (both val and test) and logging callback setup after normalization
  • fastpitch.py, magpietts.py, magpietts_preference_optimization.py: add RuntimeError guard for len(validation_step_outputs) != 1; add early-return on empty outputs; use self.validation_step_outputs[0] consistently
  • ssl_models.py: fix EncDecMaskedTokenPredModel.test_step — was appending to validation_step_outputs instead of test_step_outputs
  • Test models (test_ema.py, check_for_ranks.py, test_ptl_stateless_timer.py): override multi_validation_epoch_end instead of on_validation_epoch_end; base class handles iteration, clearing, and per-DL prefix
  • New regression test test_empty_epoch_outputs_skip_multi_epoch_end: verify multi_validation/test_epoch_end is never called when all outputs are empty
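
To make the modelPT.py items above concrete, here is a minimal sketch of the list-of-lists cache, the empty-output guard, and the len() == 1 dispatch. It is an assumption-laden toy, not the NeMo implementation: `TinyModel` and `epoch_end_calls` are hypothetical names, and strings stand in for dataloaders.

```python
from typing import Any, Dict, List, Optional


class TinyModel:
    """Toy stand-in for a ModelPT subclass; not NeMo code."""

    def __init__(self, validation_dl: Optional[List[Any]] = None):
        self._validation_dl = validation_dl
        self._validation_step_outputs: Optional[List[List[Dict[str, float]]]] = None
        self.epoch_end_calls: List[int] = []  # records which dataloaders were dispatched

    @property
    def validation_step_outputs(self) -> List[List[Dict[str, float]]]:
        # Lazily build one bucket per dataloader; no dataloaders means N=1.
        if self._validation_step_outputs is None:
            num_dl = len(self._validation_dl) if self._validation_dl else 1
            self._validation_step_outputs = [[] for _ in range(num_dl)]
        return self._validation_step_outputs

    def multi_validation_epoch_end(self, outputs, dataloader_idx: int = 0) -> None:
        self.epoch_end_calls.append(dataloader_idx)

    def on_validation_epoch_end(self) -> None:
        outputs = self.validation_step_outputs
        # `if not outputs` would not catch the empty epoch, because [[]] is truthy.
        if all(len(o) == 0 for o in outputs):
            return
        if len(outputs) == 1:
            self.multi_validation_epoch_end(outputs[0], dataloader_idx=0)
        else:
            for idx, bucket in enumerate(outputs):
                if len(bucket) == 0:
                    continue  # skip dataloaders that produced nothing
                self.multi_validation_epoch_end(bucket, dataloader_idx=idx)
        for bucket in outputs:
            bucket.clear()  # free the cache for the next epoch


model = TinyModel(validation_dl=["dl_a", "dl_b"])
model.on_validation_epoch_end()  # every bucket is empty, so nothing is dispatched
calls_after_empty_epoch = list(model.epoch_end_calls)

model.validation_step_outputs[1].append({"val_loss": 0.5})
model.on_validation_epoch_end()  # only dataloader 1 has outputs
calls_after_second_epoch = list(model.epoch_end_calls)
```

The design point is that the guard inspects the inner lists rather than the outer one, since the outer list is never empty under the new invariant.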

Usage

No API changes for single-dataloader models — dataloader_idx=0 is the default. Subclasses should use the [dataloader_idx] indexing pattern:

# In validation_step:
def validation_step(self, batch, batch_idx, dataloader_idx=0):
    metrics = self.compute_metrics(batch)
    self.validation_step_outputs[dataloader_idx].append(metrics)
    return metrics

# In multi_validation_epoch_end (preferred over on_validation_epoch_end):
def multi_validation_epoch_end(self, outputs, dataloader_idx=0):
    avg_loss = torch.stack([x['val_loss'] for x in outputs]).mean()
    self.log("val_loss", avg_loss)
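
For models that assume exactly one validation dataloader (the changelog names fastpitch.py, magpietts.py, and magpietts_preference_optimization.py), the guard plus early return might look something like the sketch below. The helper names `collect_single_dl_outputs` and `mean_val_loss` are hypothetical, and plain floats replace tensors so the example runs without torch.

```python
from typing import Any, Dict, List, Optional


def collect_single_dl_outputs(
    validation_step_outputs: List[List[Dict[str, Any]]],
) -> List[Dict[str, Any]]:
    """Hypothetical helper: enforce the single-dataloader assumption."""
    if len(validation_step_outputs) != 1:
        raise RuntimeError(
            f"Expected exactly 1 validation dataloader, got {len(validation_step_outputs)}."
        )
    return validation_step_outputs[0]


def mean_val_loss(validation_step_outputs: List[List[Dict[str, Any]]]) -> Optional[float]:
    outputs = collect_single_dl_outputs(validation_step_outputs)
    if not outputs:
        return None  # early return on an empty epoch
    return sum(x["val_loss"] for x in outputs) / len(outputs)


empty_epoch = mean_val_loss([[]])
two_batches = mean_val_loss([[{"val_loss": 1.0}, {"val_loss": 3.0}]])
```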

Copilot AI review requested due to automatic review settings March 6, 2026 02:11
github-actions bot added the core, TTS, ASR, and audio labels Mar 6, 2026
Contributor

Copilot AI left a comment

Pull request overview

This PR standardizes ModelPT.validation_step_outputs / test_step_outputs to use a consistent “list-of-lists” shape, simplifying subclass logic by removing single-vs-multi-dataloader branching and improving epoch-end dispatch/guards.

Changes:

  • Updated ModelPT epoch-end logic to dispatch based on len(outputs) (single vs multi dataloader) and to skip/guard empty per-dataloader outputs.
  • Normalized validation dataloader storage to List[DataLoader] in resolve_validation_dataloaders() and refactored many model validation_step/test_step implementations to always append via [dataloader_idx].
  • Updated unit tests and added a regression test to ensure multi_validation_epoch_end / multi_test_epoch_end are not called when all outputs are empty.

Reviewed changes

Copilot reviewed 22 out of 22 changed files in this pull request and generated 3 comments.

| File | Description |
| --- | --- |
| tests/core_ptl/test_ptl_stateless_timer.py | Updates test model hooks to the new list-of-lists output shape and adds an empty-epoch regression test. |
| tests/core_ptl/check_for_ranks.py | Switches test model to append outputs via validation_step_outputs[dataloader_idx] and uses multi_validation_epoch_end. |
| tests/collections/common/test_ema.py | Updates validation/test steps to append via [dataloader_idx] and uses multi_validation_epoch_end. |
| nemo/utils/model_utils.py | Wraps single validation dataloaders into a list to normalize _validation_dl shape. |
| nemo/core/classes/modelPT.py | Implements the unified list-of-lists output cache, updates epoch-end dispatch and empty-output guards. |
| nemo/collections/tts/models/magpietts_preference_optimization.py | Removes single-vs-multi branching in validation output accumulation; adjusts epoch-end logic for the new shape. |
| nemo/collections/tts/models/magpietts.py | Updates validation accumulation and epoch-end collection to use validation_step_outputs[0] consistently. |
| nemo/collections/tts/models/fastpitch.py | Updates validation accumulation and epoch-end processing to use the new output structure. |
| nemo/collections/tts/g2p/models/t5.py | Removes branching on dataloader count; always appends via [dataloader_idx]. |
| nemo/collections/tts/g2p/models/ctc.py | Removes branching on dataloader count; always appends via [dataloader_idx]. |
| nemo/collections/audio/models/audio_to_audio.py | Removes branching on dataloader count; simplifies callback setup in line with _validation_dl normalization. |
| nemo/collections/asr/models/transformer_bpe_models.py | Simplifies multi-epoch-end logic to assume per-dataloader outputs (base class iterates dataloaders). |
| nemo/collections/asr/models/ssl_models.py | Removes branching on dataloader count and fixes test_step to append to test_step_outputs. |
| nemo/collections/asr/models/sortformer_diar_models.py | Removes branching on dataloader count; always appends via [dataloader_idx]. |
| nemo/collections/asr/models/slu_models.py | Removes branching on dataloader count; always appends via [dataloader_idx]. |
| nemo/collections/asr/models/rnnt_models.py | Removes branching on dataloader count; always appends via [dataloader_idx]. |
| nemo/collections/asr/models/label_models.py | Removes branching on dataloader count; always appends via [dataloader_idx]. |
| nemo/collections/asr/models/hybrid_rnnt_ctc_models.py | Removes branching on dataloader count; always appends via [dataloader_idx]. |
| nemo/collections/asr/models/hybrid_rnnt_ctc_bpe_models_prompt.py | Removes branching on dataloader count; always appends via [dataloader_idx]. |
| nemo/collections/asr/models/ctc_models.py | Removes branching on dataloader count; always appends via [dataloader_idx]. |
| nemo/collections/asr/models/classification_models.py | Removes branching on dataloader count; always appends via [dataloader_idx]. |
| nemo/collections/asr/models/aed_multitask_models.py | Removes branching on dataloader count; always appends via [dataloader_idx]. |

Comment on lines 1710 to 1712
num_dl = len(self._validation_dl) if self._validation_dl else 1
self._validation_step_outputs = [[] for _ in range(num_dl)]

Copilot AI Mar 6, 2026

validation_step_outputs computes num_dl using len(self._validation_dl) whenever _validation_dl is truthy. If a caller uses setup_validation_data() directly (common across models) then _validation_dl is typically a single DataLoader, so len(DataLoader) equals the number of batches. That will incorrectly create one output bucket per batch and can force on_validation_epoch_end() into the multi-dataloader branch (which expects _validation_names to be set), leading to crashes. Consider computing num_dl from the number of dataloaders only (e.g., len(_validation_dl) only when _validation_dl is a list/tuple of dataloaders; otherwise treat as 1, and keep the empty-list case consistent with the new [[]] semantics).
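
The pitfall raised in this comment can be reproduced without PyTorch. In the sketch below, `FakeLoader` is a stand-in that only mimics the fact that len() of a DataLoader reports a batch count, and `count_dataloaders` is a hypothetical helper showing one way to count dataloaders rather than batches; neither is NeMo code.

```python
from typing import Any


class FakeLoader:
    """Stand-in for a DataLoader: len() reports the number of batches."""

    def __init__(self, num_batches: int):
        self.num_batches = num_batches

    def __len__(self) -> int:
        return self.num_batches


def count_dataloaders(dl: Any) -> int:
    """Hypothetical helper: count dataloaders, never batches."""
    if isinstance(dl, (list, tuple)):
        # An empty container is treated as the N=1 case, matching [[]] semantics.
        return max(len(dl), 1)
    # None or a bare loader both map to a single bucket.
    return 1


bare = FakeLoader(num_batches=250)
naive = len(bare) if bare else 1  # counts batches, not dataloaders
safe = count_dataloaders(bare)
```

Here `naive` would allocate 250 output buckets for a single dataloader, while `safe` allocates one. Whether this matters in practice depends on whether any code path can still leave a bare DataLoader in _validation_dl after the resolver change.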

Comment on lines +1731 to +1732
num_dl = len(self._test_dl) if self._test_dl else 1
self._test_step_outputs = [[] for _ in range(num_dl)]
Copilot AI Mar 6, 2026

test_step_outputs uses len(self._test_dl) when _test_dl is a list/tuple. Since ModelPT.test_dataloader() sets _test_dl = [] when unset, this property can return [] (not a list-of-lists), which conflicts with the PR’s stated invariant and differs from validation_step_outputs (which returns [[]] for an empty list). Consider using the same logic as validation (treat empty list as the N=1 case, and only take len() when the container is non-empty).

Comment on lines 63 to 68
 def _get_num_dataloaders(self, tag: str = 'val'):
     if tag == 'val':
-        num_dataloaders = len(self._validation_dl) if isinstance(self._validation_dl, List) else 1
+        num_dataloaders = len(self._validation_dl) if self._validation_dl else 1
     elif tag == 'test':
-        num_dataloaders = len(self._test_dl) if isinstance(self._test_dl, List) else 1
+        num_dataloaders = len(self._test_dl) if self._test_dl else 1
     else:
Copilot AI Mar 6, 2026

_get_num_dataloaders() now returns 1 when _validation_dl is an empty list. This changes the meaning from “number of configured dataloaders” to “at least 1”, which can cause _setup_metrics() to initialize metrics for a non-existent dataloader. Also, isinstance(self._test_dl, List) uses typing.List, which raises TypeError at runtime for isinstance checks; this should be replaced with a runtime type like (list, tuple) (and likely the same empty-list handling as for validation).
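
The typing.List point is worth checking before acting on it. As far as standard CPython behaviour goes, isinstance accepts the bare typing.List alias, and it is the subscripted form (for example List[int]) that raises TypeError. A minimal check, independent of NeMo:

```python
from typing import List

# The bare alias is accepted by isinstance and behaves like the builtin list.
bare_ok = isinstance([], List)

# A subscripted alias is rejected at runtime.
try:
    isinstance([], List[int])
    subscripted_raises = False
except TypeError:
    subscripted_raises = True

# A plain runtime tuple of types works and also covers tuples of dataloaders.
runtime_ok = isinstance([], (list, tuple))
```

Using (list, tuple) remains a reasonable suggestion on readability grounds, but the removed isinstance(..., List) lines would not by themselves have raised.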

validation_step_outputs and test_step_outputs now always return a list
of lists (one inner list per dataloader), eliminating if/else branching
in every subclass that handles single-vs-multi dataloader shapes.

- validation_step_outputs property: returns [[] for _ in range(num_dl)]
- on_validation/test_epoch_end: len()==1 dispatch, all(len(o)==0 ...)
  empty guard, skip empty DL buckets in multi-DL loop
- Normalize _validation_dl to Optional[List[DataLoader]] in resolver
- 15 model files: self.validation_step_outputs[dataloader_idx].append()
- TTS models: RuntimeError guard for single-DL assumption
- Test models: override multi_validation_epoch_end, not on_*_epoch_end
- Bug fix: ssl_models test_step appended to wrong outputs list
- New test: empty outputs skip multi_epoch_end

Signed-off-by: Xuesong Yang <1646669+XuesongYang@users.noreply.github.com>
Made-with: Cursor

Same wrapping pattern as _validation_dl: wrap bare DataLoader into
[DataLoader] at both single-value paths in resolve_test_dataloaders.
Simplify isinstance guards in test_step_outputs and _get_num_dataloaders.

Signed-off-by: Xuesong Yang <1646669+XuesongYang@users.noreply.github.com>
Signed-off-by: Xuesong Yang <1646669+XuesongYang@users.noreply.github.com>
Signed-off-by: Xuesong Yang <1646669+XuesongYang@users.noreply.github.com>
@XuesongYang XuesongYang force-pushed the xueyang/pr-unify-multi-dataloader-modelPT branch from 53da4f2 to 75fc50d Compare March 9, 2026 18:42