-
Notifications
You must be signed in to change notification settings - Fork 3.4k
[core] Unify validation_step_outputs to always return list-of-lists #15470
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
3d08f6f
edb505b
1e691bc
0194224
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -62,9 +62,9 @@ def _setup_loss(self): | |
|
|
||
| def _get_num_dataloaders(self, tag: str = 'val'): | ||
| if tag == 'val': | ||
| num_dataloaders = len(self._validation_dl) if isinstance(self._validation_dl, List) else 1 | ||
| num_dataloaders = len(self._validation_dl) if self._validation_dl else 1 | ||
| elif tag == 'test': | ||
| num_dataloaders = len(self._test_dl) if isinstance(self._test_dl, List) else 1 | ||
| num_dataloaders = len(self._test_dl) if self._test_dl else 1 | ||
| else: | ||
|
Comment on lines
63
to
68
|
||
| raise ValueError(f'Unexpected tag {tag}.') | ||
|
|
||
|
|
@@ -144,18 +144,12 @@ def on_test_start(self): | |
|
|
||
| def validation_step(self, batch, batch_idx, dataloader_idx: int = 0): | ||
| output_dict = self.evaluation_step(batch, batch_idx, dataloader_idx, 'val') | ||
| if isinstance(self.trainer.val_dataloaders, (list, tuple)) and len(self.trainer.val_dataloaders) > 1: | ||
| self.validation_step_outputs[dataloader_idx].append(output_dict) | ||
| else: | ||
| self.validation_step_outputs.append(output_dict) | ||
| self.validation_step_outputs[dataloader_idx].append(output_dict) | ||
| return output_dict | ||
|
|
||
| def test_step(self, batch, batch_idx, dataloader_idx=0): | ||
| output_dict = self.evaluation_step(batch, batch_idx, dataloader_idx, 'test') | ||
| if isinstance(self.trainer.test_dataloaders, (list, tuple)) and len(self.trainer.test_dataloaders) > 1: | ||
| self.test_step_outputs[dataloader_idx].append(output_dict) | ||
| else: | ||
| self.test_step_outputs.append(output_dict) | ||
| self.test_step_outputs[dataloader_idx].append(output_dict) | ||
| return output_dict | ||
|
|
||
| def multi_evaluation_epoch_end(self, outputs, dataloader_idx: int = 0, tag: str = 'val'): | ||
|
|
@@ -514,10 +508,7 @@ def configure_callbacks(self): | |
| log_callbacks = [] | ||
| from nemo.collections.audio.parts.utils.callbacks import SpeechEnhancementLoggingCallback | ||
|
|
||
| if isinstance(self._validation_dl, List): | ||
| data_loaders = self._validation_dl | ||
| else: | ||
| data_loaders = [self._validation_dl] | ||
| data_loaders = self._validation_dl if self._validation_dl else [] | ||
|
|
||
| for data_loader_idx, data_loader in enumerate(data_loaders): | ||
| log_callbacks.append( | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
technically I think you want
and similar everywhere else - make it clear this is a None check