diff --git a/nemo/collections/asr/models/aed_multitask_models.py b/nemo/collections/asr/models/aed_multitask_models.py index d20860203e7e..cbbf6dfd6982 100644 --- a/nemo/collections/asr/models/aed_multitask_models.py +++ b/nemo/collections/asr/models/aed_multitask_models.py @@ -892,18 +892,12 @@ def validation_pass(self, batch: PromptedAudioToTextMiniBatch, batch_idx, datalo def validation_step(self, batch, batch_idx, dataloader_idx=0): metrics = self.validation_pass(batch, batch_idx, dataloader_idx, eval_mode="val") - if type(self.trainer.val_dataloaders) == list and len(self.trainer.val_dataloaders) > 1: - self.validation_step_outputs[dataloader_idx].append(metrics) - else: - self.validation_step_outputs.append(metrics) + self.validation_step_outputs[dataloader_idx].append(metrics) return metrics def test_step(self, batch, batch_idx, dataloader_idx=0): metrics = self.validation_pass(batch, batch_idx, dataloader_idx, eval_mode="test") - if type(self.trainer.test_dataloaders) == list and len(self.trainer.test_dataloaders) > 1: - self.test_step_outputs[dataloader_idx].append(metrics) - else: - self.test_step_outputs.append(metrics) + self.test_step_outputs[dataloader_idx].append(metrics) return metrics def test_dataloader(self): diff --git a/nemo/collections/asr/models/classification_models.py b/nemo/collections/asr/models/classification_models.py index f15665fa20cb..e89a910e984a 100644 --- a/nemo/collections/asr/models/classification_models.py +++ b/nemo/collections/asr/models/classification_models.py @@ -1272,15 +1272,9 @@ def validation_step(self, batch, batch_idx, dataloader_idx: int = 0, tag: str = } if tag == 'val': - if isinstance(self.trainer.val_dataloaders, (list, tuple)) and len(self.trainer.val_dataloaders) > 1: - self.validation_step_outputs[dataloader_idx].append(output) - else: - self.validation_step_outputs.append(output) + self.validation_step_outputs[dataloader_idx].append(output) else: - if isinstance(self.trainer.test_dataloaders, (list, tuple)) and len(self.trainer.test_dataloaders) > 1: - self.test_step_outputs[dataloader_idx].append(output) - else: - self.test_step_outputs.append(output) + self.test_step_outputs[dataloader_idx].append(output) return output def multi_validation_epoch_end(self, outputs, dataloader_idx: int = 0, tag: str = 'val'): diff --git a/nemo/collections/asr/models/ctc_models.py b/nemo/collections/asr/models/ctc_models.py index 6510fe8a4afa..ea62305cbede 100644 --- a/nemo/collections/asr/models/ctc_models.py +++ b/nemo/collections/asr/models/ctc_models.py @@ -665,10 +665,7 @@ def validation_pass(self, batch, batch_idx, dataloader_idx=0): def validation_step(self, batch, batch_idx, dataloader_idx=0): metrics = self.validation_pass(batch, batch_idx, dataloader_idx) - if type(self.trainer.val_dataloaders) == list and len(self.trainer.val_dataloaders) > 1: - self.validation_step_outputs[dataloader_idx].append(metrics) - else: - self.validation_step_outputs.append(metrics) + self.validation_step_outputs[dataloader_idx].append(metrics) return metrics def multi_validation_epoch_end(self, outputs, dataloader_idx: int = 0): @@ -684,10 +681,7 @@ def multi_test_epoch_end(self, outputs, dataloader_idx: int = 0): def test_step(self, batch, batch_idx, dataloader_idx=0): logs = self.validation_pass(batch, batch_idx, dataloader_idx=dataloader_idx) test_logs = {name.replace("val_", "test_"): value for name, value in logs.items()} - if type(self.trainer.test_dataloaders) == list and len(self.trainer.test_dataloaders) > 1: - self.test_step_outputs[dataloader_idx].append(test_logs) - else: - self.test_step_outputs.append(test_logs) + self.test_step_outputs[dataloader_idx].append(test_logs) return test_logs def test_dataloader(self): diff --git a/nemo/collections/asr/models/hybrid_rnnt_ctc_bpe_models_prompt.py b/nemo/collections/asr/models/hybrid_rnnt_ctc_bpe_models_prompt.py index 992de84fc7a8..8742ba9d7966 100644 --- a/nemo/collections/asr/models/hybrid_rnnt_ctc_bpe_models_prompt.py +++ b/nemo/collections/asr/models/hybrid_rnnt_ctc_bpe_models_prompt.py @@ -871,20 +871,13 @@ def validation_pass(self, batch, batch_idx, dataloader_idx): def validation_step(self, batch, batch_idx, dataloader_idx=0): tensorboard_logs = self.validation_pass(batch, batch_idx, dataloader_idx) - if type(self.trainer.val_dataloaders) == list and len(self.trainer.val_dataloaders) > 1: - self.validation_step_outputs[dataloader_idx].append(tensorboard_logs) - else: - self.validation_step_outputs.append(tensorboard_logs) - + self.validation_step_outputs[dataloader_idx].append(tensorboard_logs) return tensorboard_logs def test_step(self, batch, batch_idx, dataloader_idx=0): logs = self.validation_pass(batch, batch_idx, dataloader_idx=dataloader_idx) test_logs = {name.replace("val_", "test_"): value for name, value in logs.items()} - if type(self.trainer.test_dataloaders) == list and len(self.trainer.test_dataloaders) > 1: - self.test_step_outputs[dataloader_idx].append(test_logs) - else: - self.test_step_outputs.append(test_logs) + self.test_step_outputs[dataloader_idx].append(test_logs) return test_logs def multi_validation_epoch_end(self, outputs, dataloader_idx: int = 0): diff --git a/nemo/collections/asr/models/hybrid_rnnt_ctc_models.py b/nemo/collections/asr/models/hybrid_rnnt_ctc_models.py index 5b0d1a98743a..d03c70d16742 100644 --- a/nemo/collections/asr/models/hybrid_rnnt_ctc_models.py +++ b/nemo/collections/asr/models/hybrid_rnnt_ctc_models.py @@ -632,20 +632,13 @@ def validation_pass(self, batch, batch_idx, dataloader_idx): def validation_step(self, batch, batch_idx, dataloader_idx=0): tensorboard_logs = self.validation_pass(batch, batch_idx, dataloader_idx) - if type(self.trainer.val_dataloaders) == list and len(self.trainer.val_dataloaders) > 1: - self.validation_step_outputs[dataloader_idx].append(tensorboard_logs) - else: - self.validation_step_outputs.append(tensorboard_logs) - + self.validation_step_outputs[dataloader_idx].append(tensorboard_logs) return tensorboard_logs def test_step(self, batch, batch_idx, dataloader_idx=0): logs = self.validation_pass(batch, batch_idx, dataloader_idx=dataloader_idx) test_logs = {name.replace("val_", "test_"): value for name, value in logs.items()} - if type(self.trainer.test_dataloaders) == list and len(self.trainer.test_dataloaders) > 1: - self.test_step_outputs[dataloader_idx].append(test_logs) - else: - self.test_step_outputs.append(test_logs) + self.test_step_outputs[dataloader_idx].append(test_logs) return test_logs def multi_validation_epoch_end(self, outputs, dataloader_idx: int = 0): diff --git a/nemo/collections/asr/models/label_models.py b/nemo/collections/asr/models/label_models.py index 9300f7cfc897..bb8aff0569cb 100644 --- a/nemo/collections/asr/models/label_models.py +++ b/nemo/collections/asr/models/label_models.py @@ -439,15 +439,9 @@ def evaluation_step(self, batch, batch_idx, dataloader_idx: int = 0, tag: str = f'{tag}_acc_macro_stats': stats, } if tag == 'val': - if isinstance(self.trainer.val_dataloaders, (list, tuple)) and len(self.trainer.val_dataloaders) > 1: - self.validation_step_outputs[dataloader_idx].append(output) - else: - self.validation_step_outputs.append(output) + self.validation_step_outputs[dataloader_idx].append(output) else: - if isinstance(self.trainer.test_dataloaders, (list, tuple)) and len(self.trainer.test_dataloaders) > 1: - self.test_step_outputs[dataloader_idx].append(output) - else: - self.test_step_outputs.append(output) + self.test_step_outputs[dataloader_idx].append(output) return output @@ -478,15 +472,9 @@ def pair_evaluation_step(self, batch, batch_idx, dataloader_idx: int = 0, tag: s } if tag == 'val': - if isinstance(self.trainer.val_dataloaders, (list, tuple)) and len(self.trainer.val_dataloaders) > 1: - self.validation_step_outputs[dataloader_idx].append(output) - else: - self.validation_step_outputs.append(output) + self.validation_step_outputs[dataloader_idx].append(output) else: - if isinstance(self.trainer.test_dataloaders, (list, tuple)) and len(self.trainer.test_dataloaders) > 1: - self.test_step_outputs[dataloader_idx].append(output) - else: - self.test_step_outputs.append(output) + self.test_step_outputs[dataloader_idx].append(output) return output diff --git a/nemo/collections/asr/models/rnnt_models.py b/nemo/collections/asr/models/rnnt_models.py index ef116cf2d6b5..aa29b6db62d5 100644 --- a/nemo/collections/asr/models/rnnt_models.py +++ b/nemo/collections/asr/models/rnnt_models.py @@ -903,19 +903,13 @@ def validation_pass(self, batch, batch_idx, dataloader_idx=0): def validation_step(self, batch, batch_idx, dataloader_idx=0): metrics = self.validation_pass(batch, batch_idx, dataloader_idx) - if type(self.trainer.val_dataloaders) == list and len(self.trainer.val_dataloaders) > 1: - self.validation_step_outputs[dataloader_idx].append(metrics) - else: - self.validation_step_outputs.append(metrics) + self.validation_step_outputs[dataloader_idx].append(metrics) return metrics def test_step(self, batch, batch_idx, dataloader_idx=0): logs = self.validation_pass(batch, batch_idx, dataloader_idx=dataloader_idx) test_logs = {name.replace("val_", "test_"): value for name, value in logs.items()} - if type(self.trainer.test_dataloaders) == list and len(self.trainer.test_dataloaders) > 1: - self.test_step_outputs[dataloader_idx].append(test_logs) - else: - self.test_step_outputs.append(test_logs) + self.test_step_outputs[dataloader_idx].append(test_logs) return test_logs def multi_validation_epoch_end(self, outputs, dataloader_idx: int = 0): diff --git a/nemo/collections/asr/models/slu_models.py b/nemo/collections/asr/models/slu_models.py index c599b7f4272a..8100ad9b622a 100644 --- a/nemo/collections/asr/models/slu_models.py +++ b/nemo/collections/asr/models/slu_models.py @@ -336,19 +336,13 @@ def validation_pass(self, batch, batch_idx, dataloader_idx=0): def validation_step(self, batch, batch_idx, dataloader_idx=0): metrics = self.validation_pass(batch, batch_idx, dataloader_idx) - if type(self.trainer.val_dataloaders) == list and len(self.trainer.val_dataloaders) > 1: - self.validation_step_outputs[dataloader_idx].append(metrics) - else: - self.validation_step_outputs.append(metrics) + self.validation_step_outputs[dataloader_idx].append(metrics) return metrics def test_step(self, batch, batch_idx, dataloader_idx=0): logs = self.validation_pass(batch, batch_idx, dataloader_idx=dataloader_idx) test_logs = {name.replace("val_", "test_"): value for name, value in logs.items()} - if type(self.trainer.test_dataloaders) == list and len(self.trainer.test_dataloaders) > 1: - self.test_step_outputs[dataloader_idx].append(test_logs) - else: - self.test_step_outputs.append(test_logs) + self.test_step_outputs[dataloader_idx].append(test_logs) return test_logs def test_dataloader(self): diff --git a/nemo/collections/asr/models/sortformer_diar_models.py b/nemo/collections/asr/models/sortformer_diar_models.py index e0046b929fc6..2d185ab352a8 100644 --- a/nemo/collections/asr/models/sortformer_diar_models.py +++ b/nemo/collections/asr/models/sortformer_diar_models.py @@ -967,10 +967,7 @@ def validation_step(self, batch: list, batch_idx: int, dataloader_idx: int = 0): audio_signal_length=audio_signal_length, ) val_metrics = self._get_aux_validation_evaluations(preds, targets, target_lens) - if isinstance(self.trainer.val_dataloaders, list) and len(self.trainer.val_dataloaders) > 1: - self.validation_step_outputs[dataloader_idx].append(val_metrics) - else: - self.validation_step_outputs.append(val_metrics) + self.validation_step_outputs[dataloader_idx].append(val_metrics) return val_metrics def test_step(self, batch: list, batch_idx: int, dataloader_idx: int = 0): diff --git a/nemo/collections/asr/models/ssl_models.py b/nemo/collections/asr/models/ssl_models.py index 6e149c3c17b8..3c1f3392ff49 100644 --- a/nemo/collections/asr/models/ssl_models.py +++ b/nemo/collections/asr/models/ssl_models.py @@ -596,10 +596,7 @@ def validation_pass(self, batch, batch_idx, dataloader_idx=0): def validation_step(self, batch, batch_idx, dataloader_idx=0): metrics = self.validation_pass(batch, batch_idx, dataloader_idx) - if type(self.trainer.val_dataloaders) == list and len(self.trainer.val_dataloaders) > 1: - self.validation_step_outputs[dataloader_idx].append(metrics) - else: - self.validation_step_outputs.append(metrics) + self.validation_step_outputs[dataloader_idx].append(metrics) return metrics def multi_validation_epoch_end(self, outputs, dataloader_idx: int = 0): @@ -780,18 +777,12 @@ def inference_pass(self, batch, batch_idx=0, dataloader_idx=0, mode='val', apply def validation_step(self, batch, batch_idx=0, dataloader_idx=0): metrics = self.inference_pass(batch, batch_idx, dataloader_idx, apply_mask=True) - if type(self.trainer.val_dataloaders) == list and len(self.trainer.val_dataloaders) > 1: - self.validation_step_outputs[dataloader_idx].append(metrics) - else: - self.validation_step_outputs.append(metrics) + self.validation_step_outputs[dataloader_idx].append(metrics) return metrics def test_step(self, batch, batch_idx=0, dataloader_idx=0): metrics = self.inference_pass(batch, batch_idx, dataloader_idx, mode="test", apply_mask=True) - if type(self.trainer.val_dataloaders) == list and len(self.trainer.val_dataloaders) > 1: - self.validation_step_outputs[dataloader_idx].append(metrics) - else: - self.validation_step_outputs.append(metrics) + self.test_step_outputs[dataloader_idx].append(metrics) return metrics def multi_validation_epoch_end(self, outputs: list, dataloader_idx: int = 0): diff --git a/nemo/collections/asr/models/transformer_bpe_models.py b/nemo/collections/asr/models/transformer_bpe_models.py index 4692cb662b4b..3f72aba6aee8 100644 --- a/nemo/collections/asr/models/transformer_bpe_models.py +++ b/nemo/collections/asr/models/transformer_bpe_models.py @@ -13,9 +13,7 @@ # limitations under the License. import itertools -import json import os -import tempfile from math import ceil from typing import Any, Dict, List, Optional, Union @@ -26,7 +24,6 @@ from omegaconf import DictConfig, OmegaConf, open_dict from torch.utils.data import DataLoader from torchmetrics.text import SacreBLEUScore -from tqdm.auto import tqdm from nemo.collections.asr.data import audio_to_text_dataset from nemo.collections.asr.data.audio_to_text_dali import DALIOutputs @@ -483,7 +480,7 @@ def validation_step(self, batch, batch_idx, dataloader_idx=0, eval_mode="val"): output_dict = {f'{eval_mode}_loss': transf_loss, 'translations': translations, 'ground_truths': ground_truths} - self.validation_step_outputs.append(output_dict) + self.validation_step_outputs[dataloader_idx].append(output_dict) return output_dict @@ -498,48 +495,44 @@ def multi_validation_epoch_end(self, outputs, dataloader_idx: int = 0, eval_mode if not outputs: return - if isinstance(outputs[0], dict): - outputs = [outputs] + eval_loss = getattr(self, 'val_loss').compute() + translations = list(itertools.chain(*[x['translations'] for x in outputs])) + ground_truths = list(itertools.chain(*[x['ground_truths'] for x in outputs])) - for output in outputs: - eval_loss = getattr(self, 'val_loss').compute() - translations = list(itertools.chain(*[x['translations'] for x in output])) - ground_truths = list(itertools.chain(*[x['ground_truths'] for x in output])) + # Gather translations and ground truths from all workers + tr_and_gt = [None for _ in range(self.world_size)] + # we also need to drop pairs where ground truth is an empty string + if self.world_size > 1: + dist.all_gather_object( + tr_and_gt, [(t, g) for (t, g) in zip(translations, ground_truths) if g.strip() != ''] + ) + else: + tr_and_gt[0] = [(t, g) for (t, g) in zip(translations, ground_truths) if g.strip() != ''] - # Gather translations and ground truths from all workers - tr_and_gt = [None for _ in range(self.world_size)] - # we also need to drop pairs where ground truth is an empty string - if self.world_size > 1: - dist.all_gather_object( - tr_and_gt, [(t, g) for (t, g) in zip(translations, ground_truths) if g.strip() != ''] - ) - else: - tr_and_gt[0] = [(t, g) for (t, g) in zip(translations, ground_truths) if g.strip() != ''] - - if self.global_rank == 0: - _translations = [] - _ground_truths = [] - for rank in range(0, self.world_size): - _translations += [t for (t, g) in tr_and_gt[rank]] - _ground_truths += [g for (t, g) in tr_and_gt[rank]] - - sacre_bleu = SacreBLEUScore()(_translations, [[x] for x in _ground_truths]).item() - sb_score = sacre_bleu * self.world_size - - wer_scores, wer_words = 0, 0 - for h, r in zip(_translations, _ground_truths): - wer_words += len(r.split()) - wer_scores += editdistance.eval(h.split(), r.split()) - wer_score = 1.0 * wer_scores * self.world_size / wer_words - - else: - sb_score = 0.0 - wer_score = 0.0 - - self.log(f"{eval_mode}_loss", eval_loss, sync_dist=True) - self.log(f"{eval_mode}_sacreBLEU", sb_score, sync_dist=True) - self.log(f"{eval_mode}_WER", wer_score, sync_dist=True) - self.val_loss.reset() + if self.global_rank == 0: + _translations = [] + _ground_truths = [] + for rank in range(0, self.world_size): + _translations += [t for (t, g) in tr_and_gt[rank]] + _ground_truths += [g for (t, g) in tr_and_gt[rank]] + + sacre_bleu = SacreBLEUScore()(_translations, [[x] for x in _ground_truths]).item() + sb_score = sacre_bleu * self.world_size + + wer_scores, wer_words = 0, 0 + for h, r in zip(_translations, _ground_truths): + wer_words += len(r.split()) + wer_scores += editdistance.eval(h.split(), r.split()) + wer_score = 1.0 * wer_scores * self.world_size / wer_words + + else: + sb_score = 0.0 + wer_score = 0.0 + + self.log(f"{eval_mode}_loss", eval_loss, sync_dist=True) + self.log(f"{eval_mode}_sacreBLEU", sb_score, sync_dist=True) + self.log(f"{eval_mode}_WER", wer_score, sync_dist=True) + self.val_loss.reset() def multi_test_epoch_end(self, outputs, dataloader_idx: int = 0): return self.multi_validation_epoch_end(outputs, dataloader_idx, eval_mode="test") diff --git a/nemo/collections/audio/models/audio_to_audio.py b/nemo/collections/audio/models/audio_to_audio.py index 28109f27b7f2..0d5f937fdd1c 100644 --- a/nemo/collections/audio/models/audio_to_audio.py +++ b/nemo/collections/audio/models/audio_to_audio.py @@ -62,9 +62,9 @@ def _setup_loss(self): def _get_num_dataloaders(self, tag: str = 'val'): if tag == 'val': - num_dataloaders = len(self._validation_dl) if isinstance(self._validation_dl, List) else 1 + num_dataloaders = len(self._validation_dl) if self._validation_dl else 1 elif tag == 'test': - num_dataloaders = len(self._test_dl) if isinstance(self._test_dl, List) else 1 + num_dataloaders = len(self._test_dl) if self._test_dl else 1 else: raise ValueError(f'Unexpected tag {tag}.') @@ -144,18 +144,12 @@ def on_test_start(self): def validation_step(self, batch, batch_idx, dataloader_idx: int = 0): output_dict = self.evaluation_step(batch, batch_idx, dataloader_idx, 'val') - if isinstance(self.trainer.val_dataloaders, (list, tuple)) and len(self.trainer.val_dataloaders) > 1: - self.validation_step_outputs[dataloader_idx].append(output_dict) - else: - self.validation_step_outputs.append(output_dict) + self.validation_step_outputs[dataloader_idx].append(output_dict) return output_dict def test_step(self, batch, batch_idx, dataloader_idx=0): output_dict = self.evaluation_step(batch, batch_idx, dataloader_idx, 'test') - if isinstance(self.trainer.test_dataloaders, (list, tuple)) and len(self.trainer.test_dataloaders) > 1: - self.test_step_outputs[dataloader_idx].append(output_dict) - else: - self.test_step_outputs.append(output_dict) + self.test_step_outputs[dataloader_idx].append(output_dict) return output_dict def multi_evaluation_epoch_end(self, outputs, dataloader_idx: int = 0, tag: str = 'val'): @@ -514,10 +508,7 @@ def configure_callbacks(self): log_callbacks = [] from nemo.collections.audio.parts.utils.callbacks import SpeechEnhancementLoggingCallback - if isinstance(self._validation_dl, List): - data_loaders = self._validation_dl - else: - data_loaders = [self._validation_dl] + data_loaders = self._validation_dl if self._validation_dl else [] for data_loader_idx, data_loader in enumerate(data_loaders): log_callbacks.append( diff --git a/nemo/collections/tts/g2p/models/ctc.py b/nemo/collections/tts/g2p/models/ctc.py index 3248774de571..73b0e64988e7 100644 --- a/nemo/collections/tts/g2p/models/ctc.py +++ b/nemo/collections/tts/g2p/models/ctc.py @@ -246,15 +246,9 @@ def validation_step(self, batch, batch_idx, dataloader_idx=0, split="val"): } if split == 'val': - if type(self.trainer.val_dataloaders) == list and len(self.trainer.val_dataloaders) > 1: - self.validation_step_outputs[dataloader_idx].append(loss) - else: - self.validation_step_outputs.append(loss) + self.validation_step_outputs[dataloader_idx].append(loss) elif split == 'test': - if type(self.trainer.test_dataloaders) == list and len(self.trainer.test_dataloaders) > 1: - self.test_step_outputs[dataloader_idx].append(loss) - else: - self.test_step_outputs.append(loss) + self.test_step_outputs[dataloader_idx].append(loss) return loss diff --git a/nemo/collections/tts/g2p/models/t5.py b/nemo/collections/tts/g2p/models/t5.py index 05f9ea080a80..9476c8272d84 100644 --- a/nemo/collections/tts/g2p/models/t5.py +++ b/nemo/collections/tts/g2p/models/t5.py @@ -190,15 +190,9 @@ def validation_step(self, batch, batch_idx, dataloader_idx=0, split="val"): per = word_error_rate(hypotheses=generated_str, references=labels_str, use_cer=True) output = {f"{split}_loss": val_loss, 'per': per} if split == 'val': - if isinstance(self.trainer.val_dataloaders, (list, tuple)) and len(self.trainer.val_dataloaders) > 1: - self.validation_step_outputs[dataloader_idx].append(output) - else: - self.validation_step_outputs.append(output) + self.validation_step_outputs[dataloader_idx].append(output) else: - if isinstance(self.trainer.test_dataloaders, (list, tuple)) and len(self.trainer.test_dataloaders) > 1: - self.test_step_outputs[dataloader_idx].append(output) - else: - self.test_step_outputs.append(output) + self.test_step_outputs[dataloader_idx].append(output) return output def test_step(self, batch, batch_idx, dataloader_idx=0): diff --git a/nemo/collections/tts/models/fastpitch.py b/nemo/collections/tts/models/fastpitch.py index 6c8c9526f95b..0903497e4d99 100644 --- a/nemo/collections/tts/models/fastpitch.py +++ b/nemo/collections/tts/models/fastpitch.py @@ -483,7 +483,7 @@ def training_step(self, batch, batch_idx): return loss - def validation_step(self, batch, batch_idx): + def validation_step(self, batch, batch_idx, dataloader_idx=0): attn_prior, durs, speaker, energy, reference_audio, reference_audio_len = ( None, None, @@ -563,11 +563,21 @@ def validation_step(self, batch, batch_idx): "mel_target": mels if batch_idx == 0 else None, "mel_pred": mels_pred if batch_idx == 0 else None, } - self.validation_step_outputs.append(val_outputs) + self.validation_step_outputs[dataloader_idx].append(val_outputs) return val_outputs def on_validation_epoch_end(self): - collect = lambda key: torch.stack([x[key] for x in self.validation_step_outputs]).mean() + if len(self.validation_step_outputs) != 1: + raise RuntimeError( + "FastPitchModel.on_validation_epoch_end only supports a single validation dataloader. " + "Please override multi_validation_epoch_end for multi-dataloader validation." + ) + + outputs = self.validation_step_outputs[0] + if not outputs: + return + + collect = lambda key: torch.stack([x[key] for x in outputs]).mean() val_loss = collect("val_loss") mel_loss = collect("mel_loss") dur_loss = collect("dur_loss") @@ -576,11 +586,11 @@ def on_validation_epoch_end(self): self.log("val_mel_loss", mel_loss, sync_dist=True) self.log("val_dur_loss", dur_loss, sync_dist=True) self.log("val_pitch_loss", pitch_loss, sync_dist=True) - if self.validation_step_outputs[0]["energy_loss"] is not None: + if outputs[0]["energy_loss"] is not None: energy_loss = collect("energy_loss") self.log("val_energy_loss", energy_loss, sync_dist=True) - _, _, _, _, _, spec_target, spec_predict = self.validation_step_outputs[0].values() + _, _, _, _, _, spec_target, spec_predict = outputs[0].values() if self.log_images and isinstance(self.logger, TensorBoardLogger): self.tb_logger.add_image( @@ -597,7 +607,7 @@ def on_validation_epoch_end(self): dataformats="HWC", ) self.log_train_images = True - self.validation_step_outputs.clear() # free memory) + self.validation_step_outputs[0].clear() # free memory def _setup_train_dataloader(self, cfg): phon_mode = contextlib.nullcontext() diff --git a/nemo/collections/tts/models/magpietts_preference_optimization.py b/nemo/collections/tts/models/magpietts_preference_optimization.py index d583cacadd74..41e1d9319e89 100644 --- a/nemo/collections/tts/models/magpietts_preference_optimization.py +++ b/nemo/collections/tts/models/magpietts_preference_optimization.py @@ -450,6 +450,16 @@ def validation_step(self, batch, batch_idx, dataloader_idx=0): ) def on_validation_epoch_end(self): + if len(self.validation_step_outputs) != 1: + raise RuntimeError( + "MagpieTTSModelDPO.on_validation_epoch_end only supports a single validation dataloader. " + "Please override multi_validation_epoch_end for multi-dataloader validation." + ) + + outputs = self.validation_step_outputs[0] + if not outputs: + return + def collect(key): values = [] for val_outputs in self.validation_step_outputs: @@ -992,6 +1002,16 @@ def validation_step(self, batch, batch_idx, dataloader_idx=0): ) def on_validation_epoch_end(self): + if len(self.validation_step_outputs) != 1: + raise RuntimeError( + "MagpieTTSModelOnlinePO.on_validation_epoch_end only supports a single validation dataloader. " + "Please override multi_validation_epoch_end for multi-dataloader validation." + ) + + outputs = self.validation_step_outputs[0] + if not outputs: + return + def collect(key): values = [] for val_outputs in self.validation_step_outputs: diff --git a/nemo/core/classes/modelPT.py b/nemo/core/classes/modelPT.py index 027ca47a4e82..2cc113a40bbd 100644 --- a/nemo/core/classes/modelPT.py +++ b/nemo/core/classes/modelPT.py @@ -585,7 +585,7 @@ def setup_multiple_validation_data(self, val_data_config: Union[DictConfig, Dict # Set some placeholder overriden by helper method self._val_dl_idx: int = 0 self._validation_names: Optional[List[str]] = None - self._validation_dl: Optional[torch.utils.data.DataLoader] = None + self._validation_dl: Optional[List[torch.utils.data.DataLoader]] = None # preserve config self._update_dataset_config(dataset_name='validation', config=val_data_config) @@ -597,7 +597,7 @@ def setup_multiple_validation_data(self, val_data_config: Union[DictConfig, Dict self._multi_dataset_mode = False if self._validation_names is None: - if self._validation_dl is not None and type(self._validation_dl) in [list, tuple]: + if self._validation_dl is not None: self._validation_names = ['val_{}_'.format(idx) for idx in range(len(self._validation_dl))] def setup_multiple_test_data(self, test_data_config: Union[DictConfig, Dict]): @@ -979,18 +979,18 @@ def on_validation_epoch_end(self, sync_metrics: bool = False) -> Optional[Dict[s A dictionary containing the union of all items from individual data_loaders, along with merged logs from all data loaders. """ - # Case where we dont provide data loaders - if self.validation_step_outputs is not None and len(self.validation_step_outputs) == 0: + # Case where we dont provide data loaders, or all dataloaders produced no batches. + if not self.validation_step_outputs or all(len(outputs) == 0 for outputs in self.validation_step_outputs): return {} # Case where we provide exactly 1 data loader - if isinstance(self.validation_step_outputs[0], dict): - output_dict = self.multi_validation_epoch_end(self.validation_step_outputs, dataloader_idx=0) + if len(self.validation_step_outputs) == 1: + output_dict = self.multi_validation_epoch_end(self.validation_step_outputs[0], dataloader_idx=0) if output_dict is not None and 'log' in output_dict: self.log_dict(output_dict.pop('log'), on_epoch=True, sync_dist=sync_metrics) - self.validation_step_outputs.clear() # free memory + self.validation_step_outputs[0].clear() # free memory return output_dict else: # Case where we provide more than 1 data loader @@ -998,6 +998,9 @@ def on_validation_epoch_end(self, sync_metrics: bool = False) -> Optional[Dict[s # The output is a list of list of dicts, outer list corresponds to dataloader idx for dataloader_idx, val_outputs in enumerate(self.validation_step_outputs): + if len(val_outputs) == 0: + continue + # Get prefix and dispatch call to multi epoch end dataloader_prefix = self.get_validation_dataloader_prefix(dataloader_idx) dataloader_logs = self.multi_validation_epoch_end(val_outputs, dataloader_idx=dataloader_idx) @@ -1075,18 +1078,18 @@ def on_test_epoch_end(self) -> Optional[Dict[str, Dict[str, torch.Tensor]]]: A dictionary containing the union of all items from individual data_loaders, along with merged logs from all data loaders. """ - # Case where we dont provide data loaders - if self.test_step_outputs is not None and len(self.test_step_outputs) == 0: + # Case where we dont provide data loaders, or all dataloaders produced no batches. + if not self.test_step_outputs or all(len(outputs) == 0 for outputs in self.test_step_outputs): return {} # Case where we provide exactly 1 data loader - if isinstance(self.test_step_outputs[0], dict): - output_dict = self.multi_test_epoch_end(self.test_step_outputs, dataloader_idx=0) + if len(self.test_step_outputs) == 1: + output_dict = self.multi_test_epoch_end(self.test_step_outputs[0], dataloader_idx=0) if output_dict is not None and 'log' in output_dict: self.log_dict(output_dict.pop('log'), on_epoch=True) - self.test_step_outputs.clear() # free memory + self.test_step_outputs[0].clear() # free memory return output_dict else: # Case where we provide more than 1 data loader @@ -1094,6 +1097,9 @@ def on_test_epoch_end(self) -> Optional[Dict[str, Dict[str, torch.Tensor]]]: # The output is a list of list of dicts, outer list corresponds to dataloader idx for dataloader_idx, test_outputs in enumerate(self.test_step_outputs): + if len(test_outputs) == 0: + continue + # Get prefix and dispatch call to multi epoch end dataloader_prefix = self.get_test_dataloader_prefix(dataloader_idx) dataloader_logs = self.multi_test_epoch_end(test_outputs, dataloader_idx=dataloader_idx) @@ -1692,26 +1698,20 @@ def hparams(self): @property def validation_step_outputs(self): """ - Cached outputs of validation_step. It can be a list of items (for single data loader) or a list of lists - (for multiple data loaders). + Cached outputs of validation_step. Always returns a list of lists, + where each inner list corresponds to one validation dataloader. Returns: - List of outputs of validation_step. + List of lists of outputs of validation_step. """ if self._validation_step_outputs is not None: return self._validation_step_outputs - # Initialize new output list - self._validation_step_outputs = [] - # Check len(self._validation_dl) > 1 as sometimes single dataloader can be in a - # list: [] when ds_item in config has 1 item passed in a list - if ( - self._validation_dl is not None - and isinstance(self._validation_dl, (list, tuple)) - and len(self._validation_dl) > 1 - ): - for _ in range(len(self._validation_dl)): - self._validation_step_outputs.append([]) + if isinstance(self._validation_dl, (list, tuple)) and len(self._validation_dl) > 0: + num_dl = len(self._validation_dl) + else: + num_dl = 1 + self._validation_step_outputs = [[] for _ in range(num_dl)] return self._validation_step_outputs @@ -1722,22 +1722,20 @@ def validation_step_outputs(self, value): @property def test_step_outputs(self): """ - Cached outputs of test_step. It can be a list of items (for single data loader) or a list of - lists (for multiple data loaders). + Cached outputs of test_step. Always returns a list of lists, + where each inner list corresponds to one test dataloader. Returns: - List of outputs of test_step. + List of lists of outputs of test_step. """ if self._test_step_outputs is not None: return self._test_step_outputs - # Initialize new output list - self._test_step_outputs = [] - # Check len(self._test_dl) > 1 as sometimes single dataloader can be in a list: [] - # when ds_item in config has 1 item passed in a list - if self._test_dl is not None and isinstance(self._test_dl, (list, tuple)) and len(self._test_dl) > 1: - for _ in range(len(self._test_dl)): - self._test_step_outputs.append([]) + if isinstance(self._test_dl, (list, tuple)) and len(self._test_dl) > 0: + num_dl = len(self._test_dl) + else: + num_dl = 1 + self._test_step_outputs = [[] for _ in range(num_dl)] return self._test_step_outputs diff --git a/nemo/utils/model_utils.py b/nemo/utils/model_utils.py index ebff99d6c160..150893a2f798 100644 --- a/nemo/utils/model_utils.py +++ b/nemo/utils/model_utils.py @@ -318,6 +318,8 @@ def resolve_validation_dataloaders(model: 'ModelPT'): ) model.setup_validation_data(cfg.validation_ds) + if model._validation_dl is not None and not isinstance(model._validation_dl, (list, tuple)): + model._validation_dl = [model._validation_dl] return ds_values = cfg.validation_ds[ds_key] @@ -355,6 +357,8 @@ def resolve_validation_dataloaders(model: 'ModelPT'): else: model.setup_validation_data(cfg.validation_ds) + if model._validation_dl is not None and not isinstance(model._validation_dl, (list, tuple)): + model._validation_dl = [model._validation_dl] ds_names = cfg.validation_ds.get('name', None) if ds_names is not None: if not isinstance(ds_names, str): @@ -410,6 +414,8 @@ def resolve_test_dataloaders(model: 'ModelPT'): ) model.setup_test_data(cfg.test_ds) + if model._test_dl is not None and not isinstance(model._test_dl, (list, tuple)): + model._test_dl = [model._test_dl] return ds_values = cfg.test_ds[ds_key] @@ -447,6 +453,8 @@ def resolve_test_dataloaders(model: 'ModelPT'): else: model.setup_test_data(cfg.test_ds) + if model._test_dl is not None and not isinstance(model._test_dl, (list, tuple)): + model._test_dl = [model._test_dl] ds_names = cfg.test_ds.get('name', None) if ds_names is not None: if not isinstance(ds_names, str): diff --git a/tests/collections/common/test_ema.py b/tests/collections/common/test_ema.py index fa8bc968b049..94432a8904b7 100644 --- a/tests/collections/common/test_ema.py +++ b/tests/collections/common/test_ema.py @@ -91,14 +91,14 @@ def forward(self, batch): def training_step(self, batch, batch_idx): return self(batch) - def validation_step(self, batch, batch_idx): + def validation_step(self, batch, batch_idx, dataloader_idx=0): loss = self(batch) - self.validation_step_outputs.append(loss) + self.validation_step_outputs[dataloader_idx].append(loss) return loss - def test_step(self, batch, batch_idx): + def test_step(self, batch, batch_idx, dataloader_idx=0): loss = self(batch) - self.test_step_outputs.append(loss) + self.test_step_outputs[dataloader_idx].append(loss) return loss def configure_optimizers(self): @@ -116,9 +116,8 @@ def setup_validation_data(self, val_data_config: Union[DictConfig, Dict]): def setup_test_data(self, val_data_config: Union[DictConfig, Dict]): pass - def on_validation_epoch_end(self): - self.log("val_loss", torch.stack(self.validation_step_outputs).mean()) - self.validation_step_outputs.clear() # free memory + def multi_validation_epoch_end(self, outputs, dataloader_idx=0): + self.log("val_loss", torch.stack(outputs).mean()) class TestEMAConfig: diff --git a/tests/core_ptl/check_for_ranks.py b/tests/core_ptl/check_for_ranks.py index dfbc05166c5a..dd9ccc9b4499 100644 --- a/tests/core_ptl/check_for_ranks.py +++ b/tests/core_ptl/check_for_ranks.py @@ -57,9 +57,9 @@ def predict_dataloader(self): def forward(self, batch): return batch.mean() - def validation_step(self, batch, batch_idx): + def validation_step(self, batch, batch_idx, dataloader_idx=0): loss = self(batch) - self.validation_step_outputs.append(loss) + self.validation_step_outputs[dataloader_idx].append(loss) return loss def training_step(self, batch, batch_idx): @@ -74,9 +74,8 @@ def setup_training_data(self): def setup_validation_data(self): pass - def on_validation_epoch_end(self): - self.log("val_loss", torch.stack(self.validation_step_outputs).mean()) - self.validation_step_outputs.clear() # free memory + def multi_validation_epoch_end(self, outputs, dataloader_idx=0): + self.log("val_loss", torch.stack(outputs).mean()) def instantiate_multinode_ddp_if_possible(): diff --git a/tests/core_ptl/test_ptl_stateless_timer.py b/tests/core_ptl/test_ptl_stateless_timer.py index 5cfbbda39bbf..ba33591607dc 100644 --- a/tests/core_ptl/test_ptl_stateless_timer.py +++ b/tests/core_ptl/test_ptl_stateless_timer.py @@ -59,9 +59,9 @@ def predict_dataloader(self): def forward(self, batch): return (self.l1(batch) - batch.mean(dim=1)).mean() - def validation_step(self, batch, batch_idx): + def validation_step(self, batch, batch_idx, dataloader_idx=0): loss = (self.l1(batch) - batch.mean(dim=1)).mean() - self.validation_step_outputs.append(loss) + self.validation_step_outputs[dataloader_idx].append(loss) return loss def training_step(self, batch, batch_idx): @@ -76,11 +76,8 @@ def setup_training_data(self): def setup_validation_data(self): pass - def on_validation_epoch_end(self): - if not self.validation_step_outputs: - return - self.log("val_loss", torch.stack(self.validation_step_outputs).mean(), sync_dist=True) - self.validation_step_outputs.clear() # free memory + def multi_validation_epoch_end(self, outputs, dataloader_idx=0): + self.log("val_loss", torch.stack(outputs).mean(), sync_dist=True) class TestStatelessTimer: