Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 2 additions & 8 deletions nemo/collections/asr/models/aed_multitask_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -892,18 +892,12 @@ def validation_pass(self, batch: PromptedAudioToTextMiniBatch, batch_idx, datalo

def validation_step(self, batch, batch_idx, dataloader_idx=0):
metrics = self.validation_pass(batch, batch_idx, dataloader_idx, eval_mode="val")
if type(self.trainer.val_dataloaders) == list and len(self.trainer.val_dataloaders) > 1:
self.validation_step_outputs[dataloader_idx].append(metrics)
else:
self.validation_step_outputs.append(metrics)
self.validation_step_outputs[dataloader_idx].append(metrics)
return metrics

def test_step(self, batch, batch_idx, dataloader_idx=0):
metrics = self.validation_pass(batch, batch_idx, dataloader_idx, eval_mode="test")
if type(self.trainer.test_dataloaders) == list and len(self.trainer.test_dataloaders) > 1:
self.test_step_outputs[dataloader_idx].append(metrics)
else:
self.test_step_outputs.append(metrics)
self.test_step_outputs[dataloader_idx].append(metrics)
return metrics

def test_dataloader(self):
Expand Down
10 changes: 2 additions & 8 deletions nemo/collections/asr/models/classification_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -1272,15 +1272,9 @@ def validation_step(self, batch, batch_idx, dataloader_idx: int = 0, tag: str =
}

if tag == 'val':
if isinstance(self.trainer.val_dataloaders, (list, tuple)) and len(self.trainer.val_dataloaders) > 1:
self.validation_step_outputs[dataloader_idx].append(output)
else:
self.validation_step_outputs.append(output)
self.validation_step_outputs[dataloader_idx].append(output)
else:
if isinstance(self.trainer.test_dataloaders, (list, tuple)) and len(self.trainer.test_dataloaders) > 1:
self.test_step_outputs[dataloader_idx].append(output)
else:
self.test_step_outputs.append(output)
self.test_step_outputs[dataloader_idx].append(output)
return output

def multi_validation_epoch_end(self, outputs, dataloader_idx: int = 0, tag: str = 'val'):
Expand Down
10 changes: 2 additions & 8 deletions nemo/collections/asr/models/ctc_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -665,10 +665,7 @@ def validation_pass(self, batch, batch_idx, dataloader_idx=0):

def validation_step(self, batch, batch_idx, dataloader_idx=0):
metrics = self.validation_pass(batch, batch_idx, dataloader_idx)
if type(self.trainer.val_dataloaders) == list and len(self.trainer.val_dataloaders) > 1:
self.validation_step_outputs[dataloader_idx].append(metrics)
else:
self.validation_step_outputs.append(metrics)
self.validation_step_outputs[dataloader_idx].append(metrics)
return metrics

def multi_validation_epoch_end(self, outputs, dataloader_idx: int = 0):
Expand All @@ -684,10 +681,7 @@ def multi_test_epoch_end(self, outputs, dataloader_idx: int = 0):
def test_step(self, batch, batch_idx, dataloader_idx=0):
logs = self.validation_pass(batch, batch_idx, dataloader_idx=dataloader_idx)
test_logs = {name.replace("val_", "test_"): value for name, value in logs.items()}
if type(self.trainer.test_dataloaders) == list and len(self.trainer.test_dataloaders) > 1:
self.test_step_outputs[dataloader_idx].append(test_logs)
else:
self.test_step_outputs.append(test_logs)
self.test_step_outputs[dataloader_idx].append(test_logs)
return test_logs

def test_dataloader(self):
Expand Down
11 changes: 2 additions & 9 deletions nemo/collections/asr/models/hybrid_rnnt_ctc_bpe_models_prompt.py
Original file line number Diff line number Diff line change
Expand Up @@ -871,20 +871,13 @@ def validation_pass(self, batch, batch_idx, dataloader_idx):

def validation_step(self, batch, batch_idx, dataloader_idx=0):
tensorboard_logs = self.validation_pass(batch, batch_idx, dataloader_idx)
if type(self.trainer.val_dataloaders) == list and len(self.trainer.val_dataloaders) > 1:
self.validation_step_outputs[dataloader_idx].append(tensorboard_logs)
else:
self.validation_step_outputs.append(tensorboard_logs)

self.validation_step_outputs[dataloader_idx].append(tensorboard_logs)
return tensorboard_logs

def test_step(self, batch, batch_idx, dataloader_idx=0):
logs = self.validation_pass(batch, batch_idx, dataloader_idx=dataloader_idx)
test_logs = {name.replace("val_", "test_"): value for name, value in logs.items()}
if type(self.trainer.test_dataloaders) == list and len(self.trainer.test_dataloaders) > 1:
self.test_step_outputs[dataloader_idx].append(test_logs)
else:
self.test_step_outputs.append(test_logs)
self.test_step_outputs[dataloader_idx].append(test_logs)
return test_logs

def multi_validation_epoch_end(self, outputs, dataloader_idx: int = 0):
Expand Down
11 changes: 2 additions & 9 deletions nemo/collections/asr/models/hybrid_rnnt_ctc_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -632,20 +632,13 @@ def validation_pass(self, batch, batch_idx, dataloader_idx):

def validation_step(self, batch, batch_idx, dataloader_idx=0):
tensorboard_logs = self.validation_pass(batch, batch_idx, dataloader_idx)
if type(self.trainer.val_dataloaders) == list and len(self.trainer.val_dataloaders) > 1:
self.validation_step_outputs[dataloader_idx].append(tensorboard_logs)
else:
self.validation_step_outputs.append(tensorboard_logs)

self.validation_step_outputs[dataloader_idx].append(tensorboard_logs)
return tensorboard_logs

def test_step(self, batch, batch_idx, dataloader_idx=0):
logs = self.validation_pass(batch, batch_idx, dataloader_idx=dataloader_idx)
test_logs = {name.replace("val_", "test_"): value for name, value in logs.items()}
if type(self.trainer.test_dataloaders) == list and len(self.trainer.test_dataloaders) > 1:
self.test_step_outputs[dataloader_idx].append(test_logs)
else:
self.test_step_outputs.append(test_logs)
self.test_step_outputs[dataloader_idx].append(test_logs)
return test_logs

def multi_validation_epoch_end(self, outputs, dataloader_idx: int = 0):
Expand Down
20 changes: 4 additions & 16 deletions nemo/collections/asr/models/label_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -439,15 +439,9 @@ def evaluation_step(self, batch, batch_idx, dataloader_idx: int = 0, tag: str =
f'{tag}_acc_macro_stats': stats,
}
if tag == 'val':
if isinstance(self.trainer.val_dataloaders, (list, tuple)) and len(self.trainer.val_dataloaders) > 1:
self.validation_step_outputs[dataloader_idx].append(output)
else:
self.validation_step_outputs.append(output)
self.validation_step_outputs[dataloader_idx].append(output)
else:
if isinstance(self.trainer.test_dataloaders, (list, tuple)) and len(self.trainer.test_dataloaders) > 1:
self.test_step_outputs[dataloader_idx].append(output)
else:
self.test_step_outputs.append(output)
self.test_step_outputs[dataloader_idx].append(output)

return output

Expand Down Expand Up @@ -478,15 +472,9 @@ def pair_evaluation_step(self, batch, batch_idx, dataloader_idx: int = 0, tag: s
}

if tag == 'val':
if isinstance(self.trainer.val_dataloaders, (list, tuple)) and len(self.trainer.val_dataloaders) > 1:
self.validation_step_outputs[dataloader_idx].append(output)
else:
self.validation_step_outputs.append(output)
self.validation_step_outputs[dataloader_idx].append(output)
else:
if isinstance(self.trainer.test_dataloaders, (list, tuple)) and len(self.trainer.test_dataloaders) > 1:
self.test_step_outputs[dataloader_idx].append(output)
else:
self.test_step_outputs.append(output)
self.test_step_outputs[dataloader_idx].append(output)

return output

Expand Down
10 changes: 2 additions & 8 deletions nemo/collections/asr/models/rnnt_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -903,19 +903,13 @@ def validation_pass(self, batch, batch_idx, dataloader_idx=0):

def validation_step(self, batch, batch_idx, dataloader_idx=0):
metrics = self.validation_pass(batch, batch_idx, dataloader_idx)
if type(self.trainer.val_dataloaders) == list and len(self.trainer.val_dataloaders) > 1:
self.validation_step_outputs[dataloader_idx].append(metrics)
else:
self.validation_step_outputs.append(metrics)
self.validation_step_outputs[dataloader_idx].append(metrics)
return metrics

def test_step(self, batch, batch_idx, dataloader_idx=0):
logs = self.validation_pass(batch, batch_idx, dataloader_idx=dataloader_idx)
test_logs = {name.replace("val_", "test_"): value for name, value in logs.items()}
if type(self.trainer.test_dataloaders) == list and len(self.trainer.test_dataloaders) > 1:
self.test_step_outputs[dataloader_idx].append(test_logs)
else:
self.test_step_outputs.append(test_logs)
self.test_step_outputs[dataloader_idx].append(test_logs)
return test_logs

def multi_validation_epoch_end(self, outputs, dataloader_idx: int = 0):
Expand Down
10 changes: 2 additions & 8 deletions nemo/collections/asr/models/slu_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -336,19 +336,13 @@ def validation_pass(self, batch, batch_idx, dataloader_idx=0):

def validation_step(self, batch, batch_idx, dataloader_idx=0):
metrics = self.validation_pass(batch, batch_idx, dataloader_idx)
if type(self.trainer.val_dataloaders) == list and len(self.trainer.val_dataloaders) > 1:
self.validation_step_outputs[dataloader_idx].append(metrics)
else:
self.validation_step_outputs.append(metrics)
self.validation_step_outputs[dataloader_idx].append(metrics)
return metrics

def test_step(self, batch, batch_idx, dataloader_idx=0):
logs = self.validation_pass(batch, batch_idx, dataloader_idx=dataloader_idx)
test_logs = {name.replace("val_", "test_"): value for name, value in logs.items()}
if type(self.trainer.test_dataloaders) == list and len(self.trainer.test_dataloaders) > 1:
self.test_step_outputs[dataloader_idx].append(test_logs)
else:
self.test_step_outputs.append(test_logs)
self.test_step_outputs[dataloader_idx].append(test_logs)
return test_logs

def test_dataloader(self):
Expand Down
5 changes: 1 addition & 4 deletions nemo/collections/asr/models/sortformer_diar_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -967,10 +967,7 @@ def validation_step(self, batch: list, batch_idx: int, dataloader_idx: int = 0):
audio_signal_length=audio_signal_length,
)
val_metrics = self._get_aux_validation_evaluations(preds, targets, target_lens)
if isinstance(self.trainer.val_dataloaders, list) and len(self.trainer.val_dataloaders) > 1:
self.validation_step_outputs[dataloader_idx].append(val_metrics)
else:
self.validation_step_outputs.append(val_metrics)
self.validation_step_outputs[dataloader_idx].append(val_metrics)
return val_metrics

def test_step(self, batch: list, batch_idx: int, dataloader_idx: int = 0):
Expand Down
15 changes: 3 additions & 12 deletions nemo/collections/asr/models/ssl_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -596,10 +596,7 @@ def validation_pass(self, batch, batch_idx, dataloader_idx=0):

def validation_step(self, batch, batch_idx, dataloader_idx=0):
metrics = self.validation_pass(batch, batch_idx, dataloader_idx)
if type(self.trainer.val_dataloaders) == list and len(self.trainer.val_dataloaders) > 1:
self.validation_step_outputs[dataloader_idx].append(metrics)
else:
self.validation_step_outputs.append(metrics)
self.validation_step_outputs[dataloader_idx].append(metrics)
return metrics

def multi_validation_epoch_end(self, outputs, dataloader_idx: int = 0):
Expand Down Expand Up @@ -780,18 +777,12 @@ def inference_pass(self, batch, batch_idx=0, dataloader_idx=0, mode='val', apply

def validation_step(self, batch, batch_idx=0, dataloader_idx=0):
metrics = self.inference_pass(batch, batch_idx, dataloader_idx, apply_mask=True)
if type(self.trainer.val_dataloaders) == list and len(self.trainer.val_dataloaders) > 1:
self.validation_step_outputs[dataloader_idx].append(metrics)
else:
self.validation_step_outputs.append(metrics)
self.validation_step_outputs[dataloader_idx].append(metrics)
return metrics

def test_step(self, batch, batch_idx=0, dataloader_idx=0):
metrics = self.inference_pass(batch, batch_idx, dataloader_idx, mode="test", apply_mask=True)
if type(self.trainer.val_dataloaders) == list and len(self.trainer.val_dataloaders) > 1:
self.validation_step_outputs[dataloader_idx].append(metrics)
else:
self.validation_step_outputs.append(metrics)
self.test_step_outputs[dataloader_idx].append(metrics)
return metrics

def multi_validation_epoch_end(self, outputs: list, dataloader_idx: int = 0):
Expand Down
81 changes: 37 additions & 44 deletions nemo/collections/asr/models/transformer_bpe_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,7 @@
# limitations under the License.

import itertools
import json
import os
import tempfile
from math import ceil
from typing import Any, Dict, List, Optional, Union

Expand All @@ -26,7 +24,6 @@
from omegaconf import DictConfig, OmegaConf, open_dict
from torch.utils.data import DataLoader
from torchmetrics.text import SacreBLEUScore
from tqdm.auto import tqdm

from nemo.collections.asr.data import audio_to_text_dataset
from nemo.collections.asr.data.audio_to_text_dali import DALIOutputs
Expand Down Expand Up @@ -483,7 +480,7 @@ def validation_step(self, batch, batch_idx, dataloader_idx=0, eval_mode="val"):

output_dict = {f'{eval_mode}_loss': transf_loss, 'translations': translations, 'ground_truths': ground_truths}

self.validation_step_outputs.append(output_dict)
self.validation_step_outputs[dataloader_idx].append(output_dict)

return output_dict

Expand All @@ -498,48 +495,44 @@ def multi_validation_epoch_end(self, outputs, dataloader_idx: int = 0, eval_mode
if not outputs:
return

if isinstance(outputs[0], dict):
outputs = [outputs]
eval_loss = getattr(self, 'val_loss').compute()
translations = list(itertools.chain(*[x['translations'] for x in outputs]))
ground_truths = list(itertools.chain(*[x['ground_truths'] for x in outputs]))

for output in outputs:
eval_loss = getattr(self, 'val_loss').compute()
translations = list(itertools.chain(*[x['translations'] for x in output]))
ground_truths = list(itertools.chain(*[x['ground_truths'] for x in output]))
# Gather translations and ground truths from all workers
tr_and_gt = [None for _ in range(self.world_size)]
# we also need to drop pairs where ground truth is an empty string
if self.world_size > 1:
dist.all_gather_object(
tr_and_gt, [(t, g) for (t, g) in zip(translations, ground_truths) if g.strip() != '']
)
else:
tr_and_gt[0] = [(t, g) for (t, g) in zip(translations, ground_truths) if g.strip() != '']

# Gather translations and ground truths from all workers
tr_and_gt = [None for _ in range(self.world_size)]
# we also need to drop pairs where ground truth is an empty string
if self.world_size > 1:
dist.all_gather_object(
tr_and_gt, [(t, g) for (t, g) in zip(translations, ground_truths) if g.strip() != '']
)
else:
tr_and_gt[0] = [(t, g) for (t, g) in zip(translations, ground_truths) if g.strip() != '']

if self.global_rank == 0:
_translations = []
_ground_truths = []
for rank in range(0, self.world_size):
_translations += [t for (t, g) in tr_and_gt[rank]]
_ground_truths += [g for (t, g) in tr_and_gt[rank]]

sacre_bleu = SacreBLEUScore()(_translations, [[x] for x in _ground_truths]).item()
sb_score = sacre_bleu * self.world_size

wer_scores, wer_words = 0, 0
for h, r in zip(_translations, _ground_truths):
wer_words += len(r.split())
wer_scores += editdistance.eval(h.split(), r.split())
wer_score = 1.0 * wer_scores * self.world_size / wer_words

else:
sb_score = 0.0
wer_score = 0.0

self.log(f"{eval_mode}_loss", eval_loss, sync_dist=True)
self.log(f"{eval_mode}_sacreBLEU", sb_score, sync_dist=True)
self.log(f"{eval_mode}_WER", wer_score, sync_dist=True)
self.val_loss.reset()
if self.global_rank == 0:
_translations = []
_ground_truths = []
for rank in range(0, self.world_size):
_translations += [t for (t, g) in tr_and_gt[rank]]
_ground_truths += [g for (t, g) in tr_and_gt[rank]]

sacre_bleu = SacreBLEUScore()(_translations, [[x] for x in _ground_truths]).item()
sb_score = sacre_bleu * self.world_size

wer_scores, wer_words = 0, 0
for h, r in zip(_translations, _ground_truths):
wer_words += len(r.split())
wer_scores += editdistance.eval(h.split(), r.split())
wer_score = 1.0 * wer_scores * self.world_size / wer_words

else:
sb_score = 0.0
wer_score = 0.0

self.log(f"{eval_mode}_loss", eval_loss, sync_dist=True)
self.log(f"{eval_mode}_sacreBLEU", sb_score, sync_dist=True)
self.log(f"{eval_mode}_WER", wer_score, sync_dist=True)
self.val_loss.reset()

def multi_test_epoch_end(self, outputs, dataloader_idx: int = 0):
return self.multi_validation_epoch_end(outputs, dataloader_idx, eval_mode="test")
Expand Down
19 changes: 5 additions & 14 deletions nemo/collections/audio/models/audio_to_audio.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,9 +62,9 @@ def _setup_loss(self):

def _get_num_dataloaders(self, tag: str = 'val'):
if tag == 'val':
num_dataloaders = len(self._validation_dl) if isinstance(self._validation_dl, List) else 1
num_dataloaders = len(self._validation_dl) if self._validation_dl else 1
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

technically I think you want

Suggested change
num_dataloaders = len(self._validation_dl) if self._validation_dl else 1
num_dataloaders = len(self._validation_dl) if self._validation_dl is not None else 1

and similar everywhere else - make it clear this is a None check

elif tag == 'test':
num_dataloaders = len(self._test_dl) if isinstance(self._test_dl, List) else 1
num_dataloaders = len(self._test_dl) if self._test_dl else 1
else:
Comment on lines 63 to 68
Copy link

Copilot AI Mar 6, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

_get_num_dataloaders() now returns 1 when _validation_dl is an empty list. This changes the meaning from “number of configured dataloaders” to “at least 1”, which can cause _setup_metrics() to initialize metrics for a non-existent dataloader. Also, isinstance(self._test_dl, List) uses typing.List, which raises TypeError at runtime for isinstance checks; this should be replaced with a runtime type like (list, tuple) (and likely the same empty-list handling as for validation).

Copilot uses AI. Check for mistakes.
raise ValueError(f'Unexpected tag {tag}.')

Expand Down Expand Up @@ -144,18 +144,12 @@ def on_test_start(self):

def validation_step(self, batch, batch_idx, dataloader_idx: int = 0):
output_dict = self.evaluation_step(batch, batch_idx, dataloader_idx, 'val')
if isinstance(self.trainer.val_dataloaders, (list, tuple)) and len(self.trainer.val_dataloaders) > 1:
self.validation_step_outputs[dataloader_idx].append(output_dict)
else:
self.validation_step_outputs.append(output_dict)
self.validation_step_outputs[dataloader_idx].append(output_dict)
return output_dict

def test_step(self, batch, batch_idx, dataloader_idx=0):
output_dict = self.evaluation_step(batch, batch_idx, dataloader_idx, 'test')
if isinstance(self.trainer.test_dataloaders, (list, tuple)) and len(self.trainer.test_dataloaders) > 1:
self.test_step_outputs[dataloader_idx].append(output_dict)
else:
self.test_step_outputs.append(output_dict)
self.test_step_outputs[dataloader_idx].append(output_dict)
return output_dict

def multi_evaluation_epoch_end(self, outputs, dataloader_idx: int = 0, tag: str = 'val'):
Expand Down Expand Up @@ -514,10 +508,7 @@ def configure_callbacks(self):
log_callbacks = []
from nemo.collections.audio.parts.utils.callbacks import SpeechEnhancementLoggingCallback

if isinstance(self._validation_dl, List):
data_loaders = self._validation_dl
else:
data_loaders = [self._validation_dl]
data_loaders = self._validation_dl if self._validation_dl else []

for data_loader_idx, data_loader in enumerate(data_loaders):
log_callbacks.append(
Expand Down
Loading
Loading