diff --git a/docs/source/tts/magpietts-longform.rst b/docs/source/tts/magpietts-longform.rst index a6f93b8f3282..33aef42a5abe 100644 --- a/docs/source/tts/magpietts-longform.rst +++ b/docs/source/tts/magpietts-longform.rst @@ -68,7 +68,7 @@ The input text is split into individual sentences using punctuation markers (``. Step 2: State Initialization ---------------------------- -A ``LongformChunkState`` object is created to track information across sentence chunks: +A ``ChunkState`` object is created to track information across sentence chunks: - **History text tokens**: Text from previous chunks for context - **History encoder context**: Encoder outputs that provide continuity @@ -112,7 +112,7 @@ Key Components 1. **Sentence Splitting** (``split_by_sentence``): Intelligently splits text on sentence boundaries while handling abbreviations (e.g., "Dr.", "Mr."). -2. **Chunk State** (``LongformChunkState``): Maintains context across chunks: +2. **Chunk State** (``ChunkState``): Maintains context across chunks: - ``history_text``: Text tokens from previous chunks - ``history_context_tensor``: Encoder outputs for continuity @@ -211,24 +211,24 @@ Configuration Dataclasses ######################### -``LongformConfig`` ------------------- +``ChunkedInferenceConfig`` +-------------------------- Immutable tuning parameters (set in model): .. literalinclude:: ../../../nemo/collections/tts/models/magpietts.py :language: python - :pyobject: LongformConfig + :pyobject: ChunkedInferenceConfig -``LongformChunkState`` ----------------------- +``ChunkState`` +-------------- Mutable state passed between chunk iterations: .. 
literalinclude:: ../../../nemo/collections/tts/models/magpietts.py :language: python - :pyobject: LongformChunkState + :pyobject: ChunkState Best Practices diff --git a/docs/source/tts/magpietts-po.rst b/docs/source/tts/magpietts-po.rst index 41436f32f3b7..8f987d784b8f 100644 --- a/docs/source/tts/magpietts-po.rst +++ b/docs/source/tts/magpietts-po.rst @@ -96,8 +96,8 @@ The final step is fine-tuning the base model on the preference pairs using the D max_epochs=10 \ exp_manager.exp_dir=/path/to/dpo_experiment \ exp_manager.checkpoint_callback_params.always_save_nemo=false \ - model.train_ds.dataset._target_="nemo.collections.tts.data.text_to_speech_dataset.MagpieTTSDatasetDPO" \ - model.validation_ds.dataset._target_="nemo.collections.tts.data.text_to_speech_dataset.MagpieTTSDatasetDPO" \ + model.train_ds.datasets._target_="nemo.collections.tts.data.text_to_speech_dataset.MagpieTTSDatasetDPO" \ + model.validation_ds.datasets._target_="nemo.collections.tts.data.text_to_speech_dataset.MagpieTTSDatasetDPO" \ +train_ds_meta.dpopreftrain.manifest_path="/path/to/manifests/" \ +train_ds_meta.dpopreftrain.audio_dir="/" \ +train_ds_meta.dpopreftrain.feature_dir="/" \ diff --git a/examples/tts/conf/magpietts/magpietts.yaml b/examples/tts/conf/magpietts/magpietts.yaml index 4c45f38fb4b3..6d0b9a7cd3b7 100644 --- a/examples/tts/conf/magpietts/magpietts.yaml +++ b/examples/tts/conf/magpietts/magpietts.yaml @@ -80,7 +80,7 @@ model: # pretrained_model: "google/byt5-small" train_ds: - dataset: + datasets: _target_: nemo.collections.tts.data.text_to_speech_dataset.MagpieTTSDataset dataset_meta: ${train_ds_meta} weighted_sampling_steps_per_epoch: ${weighted_sampling_steps_per_epoch} @@ -93,8 +93,11 @@ model: drop_last: true pin_memory: true + # Non-lhotse validation uses a single dataloader. All dataset_meta entries are mixed + # together, so validation metrics are logged jointly. 
For per-dataset validation + # metrics, use the lhotse config (magpietts_lhotse.yaml) with separate datasets entries. validation_ds: - dataset: + datasets: _target_: nemo.collections.tts.data.text_to_speech_dataset.MagpieTTSDataset dataset_meta: ${val_ds_meta} min_duration: 0.2 diff --git a/examples/tts/conf/magpietts/magpietts_lhotse.yaml b/examples/tts/conf/magpietts/magpietts_lhotse.yaml index b90b198ae7be..5f2bf32f33bf 100644 --- a/examples/tts/conf/magpietts/magpietts_lhotse.yaml +++ b/examples/tts/conf/magpietts/magpietts_lhotse.yaml @@ -1,6 +1,7 @@ name: Magpie-TTS quadratic_duration: 20 # both training and validation datasets can apply same quadratic_duration. + model: use_lhotse: true model_type: "decoder_ce" # decoder_context_tts or decoder_ce @@ -16,7 +17,7 @@ model: alignment_loss_scale: 0.002 embedding_dim: 768 codecmodel_path: ??? - cfg_unconditional_prob: 0.1 + cfg_unconditional_prob: 0.1 # enable classifier-free guidance during traing by dropping out conditionals with this probability # Alignment encoder parameters, to binarize the prior # This is used for attention-constrained training and inference @@ -70,57 +71,60 @@ model: train_ds: use_lhotse: ${model.use_lhotse} volume_norm: true - - dataset: - min_duration: 0.2 - min_context_speaker_similarity: 0.6 - max_cer: 0.03 - batch_duration : ??? # in seconds. Adjust based on your GPU memory. - quadratic_duration: ${quadratic_duration} - use_bucketing: true - num_buckets: 20 - bucket_buffer_size: 20_000 - shuffle_buffer_size: 20_000 - num_cuts_for_bins_estimate: 20_000 - shard_seed: "trng" - drop_last: true - shuffle: true - num_workers: 6 - pin_memory: true - - input_cfg: - - type: lhotse_shar - shar_path: ??? - weight: 1.0 - tags: - tokenizer_names: ["english_phoneme"] + min_duration: 0.2 + min_context_speaker_similarity: 0.6 + max_cer: 0.03 + batch_duration: ??? # in seconds. Adjust based on your GPU memory. 
+ quadratic_duration: ${quadratic_duration} + use_bucketing: true + num_buckets: 20 + bucket_buffer_size: 20_000 + shuffle_buffer_size: 20_000 + num_cuts_for_bins_estimate: 20_000 + shard_seed: "trng" + drop_last: true + shuffle: true + num_workers: 6 + pin_memory: true + + input_cfg: + - type: lhotse_shar + shar_path: ??? + weight: 1.0 + tags: + tokenizer_names: ["english_phoneme"] validation_ds: + # the entries under 'datasets' are a list of separate dataloaders. + # The structure is: + # - name: '' + # + # They inherit all settings from validation_ds, but can individually override them. use_lhotse: ${model.use_lhotse} volume_norm: true - - dataset: - min_duration: 0.2 - min_context_speaker_similarity: 0.6 - max_cer: 0.03 - batch_duration: ??? # recommend to use smaller batch_duration for validation dataset than training dataset. - quadratic_duration: ${quadratic_duration} - use_bucketing: false - force_finite: true - force_map_dataset: true - seed: 42 - shard_seed: "randomized" - drop_last: false - shuffle: false - num_workers: 2 - pin_memory: true - - input_cfg: - - type: lhotse_shar - shar_path: ??? - weight: 1.0 - tags: - tokenizer_names: ["english_phoneme"] + min_duration: 0.2 + min_context_speaker_similarity: 0.6 + max_cer: 0.03 + batch_duration: ??? # recommend to use smaller batch_duration for validation dataset than training dataset. + quadratic_duration: ${quadratic_duration} + use_bucketing: false + force_finite: true + force_map_dataset: true + seed: 42 + shard_seed: "randomized" + drop_last: false + shuffle: false + num_workers: 2 + pin_memory: true + + datasets: + - name: "val_set_0" # rename to your dataset name, add more as needed + input_cfg: + - type: lhotse_shar + shar_path: ??? + weight: 1.0 + tags: + tokenizer_names: ["english_phoneme"] encoder: n_layers: 6 @@ -185,10 +189,9 @@ trainer: precision: 32 max_steps: ??? 
accumulate_grad_batches: 1 - enable_checkpointing: False # Provided by exp_manager - logger: false # Provided by exp_manager + enable_checkpointing: false # Provided by exp_manager + logger: false # Provided by exp_manager log_every_n_steps: 100 - check_val_every_n_epoch: 1 limit_train_batches: 1_000 val_check_interval: 1_000 num_sanity_val_steps: 0 diff --git a/examples/tts/conf/magpietts/magpietts_lhotse_moe.yaml b/examples/tts/conf/magpietts/magpietts_lhotse_moe.yaml index 009eb02f2e46..6cb735a563a4 100644 --- a/examples/tts/conf/magpietts/magpietts_lhotse_moe.yaml +++ b/examples/tts/conf/magpietts/magpietts_lhotse_moe.yaml @@ -74,57 +74,60 @@ model: train_ds: use_lhotse: ${model.use_lhotse} volume_norm: true - - dataset: - min_duration: 0.2 - min_context_speaker_similarity: 0.6 - max_cer: 0.03 - batch_duration : ??? # in seconds. Adjust based on your GPU memory. - quadratic_duration: ${quadratic_duration} - use_bucketing: true - num_buckets: 20 - bucket_buffer_size: 20_000 - shuffle_buffer_size: 20_000 - num_cuts_for_bins_estimate: 20_000 - shard_seed: "trng" - drop_last: true - shuffle: true - num_workers: 6 - pin_memory: true - - input_cfg: - - type: lhotse_shar - shar_path: ??? - weight: 1.0 - tags: - tokenizer_names: ["english_phoneme"] + min_duration: 0.2 + min_context_speaker_similarity: 0.6 + max_cer: 0.03 + batch_duration: ??? # in seconds. Adjust based on your GPU memory. + quadratic_duration: ${quadratic_duration} + use_bucketing: true + num_buckets: 20 + bucket_buffer_size: 20_000 + shuffle_buffer_size: 20_000 + num_cuts_for_bins_estimate: 20_000 + shard_seed: "trng" + drop_last: true + shuffle: true + num_workers: 6 + pin_memory: true + + input_cfg: + - type: lhotse_shar + shar_path: ??? + weight: 1.0 + tags: + tokenizer_names: ["english_phoneme"] validation_ds: + # the entries under 'datasets' are a list of separate dataloaders. 
+ # The structure is: + # - name: '' + # + # They inherit all settings from validation_ds, but can individually override them. use_lhotse: ${model.use_lhotse} volume_norm: true - - dataset: - min_duration: 0.2 - min_context_speaker_similarity: 0.6 - max_cer: 0.03 - batch_duration: ??? # recommend to use smaller batch_duration for validation dataset than training dataset. - quadratic_duration: ${quadratic_duration} - use_bucketing: false - force_finite: true - force_map_dataset: true - seed: 42 - shard_seed: "randomized" - drop_last: false - shuffle: false - num_workers: 2 - pin_memory: true - - input_cfg: - - type: lhotse_shar - shar_path: ??? - weight: 1.0 - tags: - tokenizer_names: ["english_phoneme"] + min_duration: 0.2 + min_context_speaker_similarity: 0.6 + max_cer: 0.03 + batch_duration: ??? # recommend to use smaller batch_duration for validation dataset than training dataset. + quadratic_duration: ${quadratic_duration} + use_bucketing: false + force_finite: true + force_map_dataset: true + seed: 42 + shard_seed: "randomized" + drop_last: false + shuffle: false + num_workers: 2 + pin_memory: true + + datasets: + - name: "val_set_0" # rename to your dataset name, add more as needed + input_cfg: + - type: lhotse_shar + shar_path: ??? 
+ weight: 1.0 + tags: + tokenizer_names: ["english_phoneme"] encoder: n_layers: 6 diff --git a/examples/tts/conf/magpietts/magpietts_po_inference.yaml b/examples/tts/conf/magpietts/magpietts_po_inference.yaml index 735e750a899e..27bfee33656d 100644 --- a/examples/tts/conf/magpietts/magpietts_po_inference.yaml +++ b/examples/tts/conf/magpietts/magpietts_po_inference.yaml @@ -88,7 +88,7 @@ model: # pretrained_model: "google/byt5-small" test_ds: - dataset: + datasets: _target_: nemo.collections.tts.data.text_to_speech_dataset.MagpieTTSDataset dataset_meta: ${test_ds_meta} min_duration: 0.2 diff --git a/nemo/collections/tts/losses/moe_loss.py b/nemo/collections/tts/losses/moe_loss.py index 40e16697bafb..63267cfa1eeb 100644 --- a/nemo/collections/tts/losses/moe_loss.py +++ b/nemo/collections/tts/losses/moe_loss.py @@ -16,7 +16,7 @@ import torch.nn.functional as F from nemo.core.classes import Loss, typecheck -from nemo.core.neural_types.elements import LossType, ProbsType +from nemo.core.neural_types.elements import LogitsType, LossType, ProbsType from nemo.core.neural_types.neural_type import NeuralType @@ -122,7 +122,7 @@ def __init__(self, loss_scale: float = 0.001): @property def input_types(self): return { - "router_logits": NeuralType(('B', 'T', 'D'), ProbsType()), # D = num_experts + "router_logits": NeuralType(('B', 'T', 'D'), LogitsType()), # D = num_experts "x_mask": NeuralType(('B', 'T'), ProbsType(), optional=True), } @@ -194,7 +194,7 @@ def __init__( @property def input_types(self): return { - "router_logits": NeuralType(('B', 'T', 'D'), ProbsType()), + "router_logits": NeuralType(('B', 'T', 'D'), LogitsType()), "router_probs": NeuralType(('B', 'T', 'D'), ProbsType()), "x_mask": NeuralType(('B', 'T'), ProbsType(), optional=True), } diff --git a/nemo/collections/tts/models/magpietts.py b/nemo/collections/tts/models/magpietts.py index 368b17a81eec..4d34471af5a1 100644 --- a/nemo/collections/tts/models/magpietts.py +++ 
b/nemo/collections/tts/models/magpietts.py @@ -17,8 +17,10 @@ import random import re import time + from dataclasses import dataclass, field, fields from functools import partial +from pathlib import Path from typing import Any, Dict, List, Optional, Tuple, Union import numpy as np @@ -26,9 +28,10 @@ import torch import wandb from hydra.utils import instantiate +from lhotse.serialization import load_yaml +from lightning.pytorch import Trainer from lightning.pytorch.loggers import TensorBoardLogger, WandbLogger - -from omegaconf import DictConfig, OmegaConf, open_dict +from omegaconf import DictConfig, ListConfig, OmegaConf, open_dict from torch import nn from torch.utils.data import get_worker_info @@ -51,6 +54,7 @@ binarize_attention_parallel, get_mask_from_lengths, plot_alignment_to_numpy, + plot_expert_usage_heatmap_to_numpy, ) from nemo.collections.tts.parts.utils.tts_dataset_utils import ( chunk_text_for_inference, @@ -630,6 +634,11 @@ def __init__(self, cfg: DictConfig, trainer: 'Trainer' = None): f"Each expert has d_ffn={cfg.decoder.d_ffn}. " f"Loss scales: router_load_balancing={router_load_balancing_loss_coeff}, router_z={router_z_loss_coeff}" ) + # Training-side accumulator for layer-wise expert usage heatmap. + # Accumulated every training_step, rendered + reset at each validation interval. + self._moe_num_experts = num_experts + self._moe_train_layer_usage_accum: Optional[torch.Tensor] = None # (n_layers, num_experts) + self._moe_train_accum_steps: int = 0 # Define cfg parameters into self parameters self.prior_end_step = self.cfg.prior_end_step @@ -814,6 +823,24 @@ def num_baked_speakers(self) -> int: return 0 return self.baked_context_embedding.num_embeddings + @property + def validation_step_outputs(self): + """Always use list-of-lists structure for uniform single/multi-dataloader handling. + + Overrides ModelPT which uses a flat list for single dataloader and list-of-lists + for multiple dataloaders. 
This override always returns list-of-lists so that + validation_step, on_validation_epoch_end, etc. don't need conditional branching. + """ + if self._validation_step_outputs is not None: + return self._validation_step_outputs + num_dl = len(self._validation_dl) if self._validation_dl is not None else 1 + self._validation_step_outputs = [[] for _ in range(num_dl)] + return self._validation_step_outputs + + @validation_step_outputs.setter + def validation_step_outputs(self, value): + self._validation_step_outputs = value + def _normalize_speaker_indices( self, speaker_indices: Optional[Union[int, List[int], torch.Tensor]], @@ -1789,128 +1816,223 @@ def sample_codes_from_logits( all_preds = torch.stack(all_preds, dim=2) # (B, num_codebooks, frame_stacking_factor) return all_preds - def log_attention_probs(self, attention_prob_matrix, audio_codes_lens, text_lens, prefix="", dec_context_size=0): - # attention_prob_matrix List of (B, C, audio_timesteps, text_timesteps) - wandb_images_log = {} + def _prepare_attention_images( + self, + attention_prob_matrix: List[torch.Tensor], + audio_codes_lens: torch.Tensor, + text_lens: torch.Tensor, + dec_context_size: int = 0, + max_examples: int = 3, + ) -> List[np.ndarray]: + """ + Convert attention probability matrices to numpy images for logging. + + Args: + attention_prob_matrix: List of attention tensors, each (B, H, audio_timesteps, text_timesteps). + audio_codes_lens: Audio sequence lengths per example. + text_lens: Text sequence lengths per example. + dec_context_size: Number of context audio frames to skip in attention visualization. + max_examples: Maximum number of examples to generate images for. + Returns: + List of numpy arrays in HWC format, one per example. 
+ """ with torch.no_grad(): + # Concatenate attention heads and average attention_prob_matrix = torch.cat(attention_prob_matrix, dim=1) # (B, C, audio_timesteps, text_timesteps) attention_prob_matrix_mean = attention_prob_matrix.mean(dim=1) # (B, audio_timesteps, text_timesteps) - for logger in self.loggers: - is_wandb = isinstance(logger, WandbLogger) - is_tb = isinstance(logger, TensorBoardLogger) - if not is_wandb and not is_tb: - raise ValueError( - f"Invalid logger type for image logging: {type(logger)}. Only `WandbLogger` and `TensorBoardLogger` are supported." - ) + images = [] + num_examples = min(max_examples, attention_prob_matrix_mean.size(0)) + for idx in range(num_examples): + # Slice attention matrix to valid region (excluding context frames) + audio_len = int(audio_codes_lens[idx]) + text_len = int(text_lens[idx]) + item_attn_matrix = attention_prob_matrix_mean[idx][ + dec_context_size : dec_context_size + audio_len, :text_len + ] + item_attn_matrix = item_attn_matrix.detach().cpu().numpy() + img_np = plot_alignment_to_numpy(item_attn_matrix.T) + images.append(img_np) - wandb_images_log[f"Image/{prefix}/attention_matrix"] = list() - for idx in range(min(3, attention_prob_matrix_mean.size(0))): - item_attn_matrix = attention_prob_matrix_mean[idx][ - dec_context_size : dec_context_size + audio_codes_lens[idx], : text_lens[idx] - ] - item_attn_matrix = item_attn_matrix.detach().cpu().numpy() - img_np = plot_alignment_to_numpy(item_attn_matrix.T) + return images - if is_wandb: - wandb_images_log[f"Image/{prefix}/attention_matrix"].append( - wandb.Image(img_np, caption=f"Example_{idx}") - ) + def _prepare_audio_examples( + self, + logits: torch.Tensor, + target_audio_codes: torch.Tensor, + audio_codes_lens: torch.Tensor, + context_audio_codes: Optional[torch.Tensor] = None, + context_audio_codes_lens: Optional[torch.Tensor] = None, + max_examples: int = 3, + ) -> Dict[str, List[Optional[np.ndarray]]]: + """ + Decode audio codes to waveforms and convert to 
numpy arrays for logging. - if is_tb: - logger.experiment.add_image( - f'{prefix}/attention_matrix/Example_{idx}', - img_np, - global_step=self.global_step, - dataformats="HWC", - ) + Args: + logits: Model output logits to convert to predicted audio. + target_audio_codes: Ground truth audio codes. + audio_codes_lens: Lengths of target audio codes. + context_audio_codes: Optional context audio codes for voice cloning. + context_audio_codes_lens: Lengths of context audio codes. + max_examples: Maximum number of examples to process. - return wandb_images_log + Returns: + Dict with keys 'pred_audios', 'target_audios', 'context_audios', + each containing a list of numpy arrays (or None for context if unavailable). + """ + with torch.no_grad(): + # Decode predictions: convert logits to codes, remove EOS token, then decode to audio + pred_audio_codes = self.logits_to_audio_codes(logits, audio_codes_lens) + pred_audio_codes, pred_audio_codes_lens = self.remove_eos_token( + codes=pred_audio_codes, codes_len=audio_codes_lens + ) + pred_audio, pred_audio_lens, _ = self.codes_to_audio(pred_audio_codes, pred_audio_codes_lens) + + # Decode targets: remove EOS token, then decode to audio + target_audio_codes, target_audio_codes_lens = self.remove_eos_token( + codes=target_audio_codes, codes_len=audio_codes_lens + ) + target_audio, target_audio_lens, _ = self.codes_to_audio(target_audio_codes, target_audio_codes_lens) + + # Decode context audio if available (shape check ensures it's not a dummy tensor used in text context) + # This does not handle the case in which a batch has a mixture of text and audio context examples + context_audio, context_audio_lens = None, None + if context_audio_codes is not None and context_audio_codes.shape[2] > 3: + context_audio_codes, context_audio_codes_lens = self.remove_special_tokens( + codes=context_audio_codes, codes_len=context_audio_codes_lens + ) + context_audio, context_audio_lens, _ = self.codes_to_audio( + context_audio_codes, 
context_audio_codes_lens + ) + + pred_audios = [] + target_audios = [] + context_audios = [] + + num_examples = min(max_examples, pred_audio.size(0)) + for idx in range(num_examples): + # Convert to numpy and trim to actual length + pred_audio_np = pred_audio[idx, : pred_audio_lens[idx]].float().cpu().numpy() + target_audio_np = target_audio[idx, : target_audio_lens[idx]].float().cpu().numpy() + + pred_audios.append(pred_audio_np) + target_audios.append(target_audio_np) + + if context_audio is not None: + context_audio_np = context_audio[idx, : context_audio_lens[idx]].float().cpu().numpy() + context_audios.append(context_audio_np) + else: + context_audios.append(None) - def log_val_audio_example( + return { + 'pred_audios': pred_audios, + 'target_audios': target_audios, + 'context_audios': context_audios, + } + + def _collect_wandb_media_and_log_tb( self, - logits, - target_audio_codes, - audio_codes_lens, - context_audio_codes=None, - context_audio_codes_lens=None, - ): - wandb_audio_log = {} + *, + dataset_prefix: str, + pred_audios: List[np.ndarray], + target_audios: List[np.ndarray], + context_audios: List[Optional[np.ndarray]], + attention_data: Dict[str, List[np.ndarray]], + global_step: int, + ) -> Dict[str, Any]: + """ + Collect WandB media entries and log audio/attention to TensorBoard. - pred_audio_codes = self.logits_to_audio_codes(logits, audio_codes_lens) - pred_audio_codes, audio_codes_lens_pred = self.remove_eos_token( - codes=pred_audio_codes, codes_len=audio_codes_lens - ) - pred_audio, pred_audio_lens, _ = self.codes_to_audio(pred_audio_codes, audio_codes_lens_pred) + TensorBoard logging happens directly within this method. + WandB media is returned as a dict to be merged with other WandB media + (e.g., MoE heatmaps) into a single wandb.log() call by the caller, + ensuring all media shares the same WandB step index. 
- target_audio_codes, audio_codes_lens_target = self.remove_eos_token( - codes=target_audio_codes, codes_len=audio_codes_lens - ) - target_audio, target_audio_lens, _ = self.codes_to_audio(target_audio_codes, audio_codes_lens_target) + Args: + dataset_prefix: Prefix for log keys (e.g., 'val', 'val_set_0'). + pred_audios: List of predicted audio waveforms as numpy arrays. + target_audios: List of target audio waveforms as numpy arrays. + context_audios: List of context audio waveforms (or None per entry if unavailable). + attention_data: Dict mapping attention names to lists of numpy images. + global_step: Current training step for logging. - context_audio, context_audio_lens = None, None - if context_audio_codes is not None and context_audio_codes.shape[2] > 3: - context_audio_codes, context_audio_codes_lens = self.remove_special_tokens( - codes=context_audio_codes, codes_len=context_audio_codes_lens - ) - # > 3 ensures, it is a valid context audio tensor (and not dummy tensor used in text context) - # This does not handle the case in which a batch has a mixture of text and audio context examples - context_audio, context_audio_lens, _ = self.codes_to_audio(context_audio_codes, context_audio_codes_lens) + Returns: + Dict of WandB-ready media entries (audio + attention images). + Empty dict if no WandB logger is configured. + """ + wandb_media: Dict[str, Any] = {} for logger in self.loggers: is_wandb = isinstance(logger, WandbLogger) is_tb = isinstance(logger, TensorBoardLogger) if not is_wandb and not is_tb: raise ValueError( - f"Invalid logger type for audio logging: {type(logger)}. Only `WandbLogger` and `TensorBoardLogger` are supported." + f"Unsupported logger type: {type(logger)}. " + f"Only WandbLogger and TensorBoardLogger are supported for media logging." 
) - for idx in range(min(3, pred_audio.size(0))): - pred_audio_np = pred_audio[idx].float().detach().cpu().numpy() - target_audio_np = target_audio[idx].float().detach().cpu().numpy() - pred_audio_np = pred_audio_np[: pred_audio_lens[idx]] - target_audio_np = target_audio_np[: target_audio_lens[idx]] - context_audio_np = None - if context_audio is not None: - context_audio_np = context_audio[idx].float().detach().cpu().numpy() - context_audio_np = context_audio_np[: context_audio_lens[idx]] - + for idx, (pred_audio_np, target_audio_np, context_audio_np) in enumerate( + zip(pred_audios, target_audios, context_audios) + ): if is_wandb: - wandb_audio_log[f"Audio/Example_{idx}"] = list() + audio_list = [] if context_audio_np is not None and context_audio_np.shape[0] > 0: - wandb_audio_log[f"Audio/Example_{idx}"].append( + audio_list.append( wandb.Audio(context_audio_np, sample_rate=self.output_sample_rate, caption="context") ) - wandb_audio_log[f"Audio/Example_{idx}"].append( + audio_list.append( wandb.Audio(pred_audio_np, sample_rate=self.output_sample_rate, caption="prediction") ) - wandb_audio_log[f"Audio/Example_{idx}"].append( + audio_list.append( wandb.Audio(target_audio_np, sample_rate=self.output_sample_rate, caption="target") ) + wandb_media[f"Audio:{dataset_prefix}/Example_{idx:02d}"] = audio_list if is_tb: if context_audio_np is not None and context_audio_np.shape[0] > 0: logger.experiment.add_audio( - f'Example_{idx}/context', + f'{dataset_prefix}/Example_{idx}/context', context_audio_np, - global_step=self.global_step, + global_step=global_step, sample_rate=self.output_sample_rate, ) logger.experiment.add_audio( - f'Example_{idx}/prediction', + f'{dataset_prefix}/Example_{idx}/prediction', pred_audio_np, - global_step=self.global_step, + global_step=global_step, sample_rate=self.output_sample_rate, ) logger.experiment.add_audio( - f'Example_{idx}/target', + f'{dataset_prefix}/Example_{idx}/target', target_audio_np, - global_step=self.global_step, + 
global_step=global_step, sample_rate=self.output_sample_rate, ) - return wandb_audio_log + # Log attention images + for attn_key, images in attention_data.items(): + # Determine log prefix: 'overall' uses dataset_prefix directly, others are nested + if attn_key == 'overall': + prefix = dataset_prefix + else: + prefix = f"{dataset_prefix}/{attn_key}" + + if is_wandb: + wandb_media[f"Image:{prefix}/attention_matrix"] = [ + wandb.Image(img_np, caption=f"Example_{idx:02d}") for idx, img_np in enumerate(images) + ] + + if is_tb: + for idx, img_np in enumerate(images): + logger.experiment.add_image( + f'{prefix}/attention_matrix/Example_{idx:02d}', + img_np, + global_step=global_step, + dataformats="HWC", + ) + + return wandb_media def scale_prior(self, prior, global_step): if prior is None: @@ -2676,15 +2798,21 @@ def process_batch(self, batch): x_mask=merged_mask, ) - # Compute expert usage statistics (averaged across all layers, batches, and valid tokens) - # This shows which experts are being used most frequently + # Compute expert usage statistics with torch.no_grad(): - # Use shared utility function for computing expert usage - expert_usage = compute_expert_usage(merged_probs, merged_mask) # (num_experts,) + num_experts = stacked_probs.size(-1) + n_moe_layers = stacked_probs.size(0) + + # Per-layer expert usage: (n_layers, num_experts) + layer_expert_usage = torch.stack( + [compute_expert_usage(stacked_probs[i], audio_codes_mask) for i in range(n_moe_layers)] + ) + + # Global expert usage: mean across layers (for scalar logging) + expert_usage = layer_expert_usage.mean(dim=0) # (num_experts,) # Compute how often each expert is selected in top-k # For padded positions, expert_indices=-1, so they don't match any valid expert (0 to num_experts-1) - num_experts = merged_probs.size(-1) expert_selection_counts = torch.zeros(num_experts, device=merged_probs.device) for expert_idx in range(num_experts): expert_selection_counts[expert_idx] = (merged_indices == 
expert_idx).float().sum() @@ -2694,17 +2822,12 @@ def process_batch(self, batch): valid_selections = (merged_indices != -1).sum().float().clamp_min(1.0) expert_selection_freq = expert_selection_counts / valid_selections - # Compute load balance metrics - batch_expert_usage_variance = expert_usage.var() - batch_expert_usage_max = expert_usage.max() - batch_expert_usage_min = expert_usage.min() - moe_expert_usage_stats = { - 'expert_usage': expert_usage.cpu(), # (num_experts,) - 'expert_selection_freq': expert_selection_freq.cpu(), # (num_experts,) - 'batch_expert_usage_variance': batch_expert_usage_variance.item(), - 'batch_expert_usage_max': batch_expert_usage_max.item(), - 'batch_expert_usage_min': batch_expert_usage_min.item(), + 'expert_usage': expert_usage.detach(), # (num_experts,) + 'layer_expert_usage': layer_expert_usage.detach(), # (n_layers, num_experts) + 'expert_selection_freq': expert_selection_freq.detach(), # (num_experts,) + 'batch_expert_usage_variance': expert_usage.var().detach(), + 'ideal_usage': 1.0 / num_experts, } # Add MoE loss to total loss (only in training mode) @@ -2739,34 +2862,54 @@ def training_step(self, batch, batch_idx): batch_output = self.process_batch(batch) loss = batch_output['loss'] codebook_loss = batch_output['codebook_loss'] - self.log('train/codebook_loss', codebook_loss, prog_bar=True, sync_dist=True) + self.log('Loss:train/codebook_loss', codebook_loss, prog_bar=True, sync_dist=True) if self.cfg_unconditional_prob == 0.0: # Only log alignment loss when not using cfg to avoid sync issues when # alignment loss is None on some ranks alignment_loss = batch_output['alignment_loss'] if alignment_loss is not None: - self.log('train/alignment_loss', alignment_loss, prog_bar=True, sync_dist=True) - self.log('train/loss', loss, prog_bar=True, sync_dist=True) + self.log('Loss:train/alignment_loss', alignment_loss, prog_bar=True, sync_dist=True) + self.log('Loss:train/loss', loss, prog_bar=True, sync_dist=True) 
local_transformer_loss = batch_output['local_transformer_loss'] if local_transformer_loss is not None: - self.log('train/local_transformer_loss', local_transformer_loss, prog_bar=True, sync_dist=True) + self.log('Loss:train/local_transformer_loss', local_transformer_loss, prog_bar=True, sync_dist=True) - # Log MoE losses if MoE is enabled + # Log MoE losses and expert usage if MoE is enabled moe_load_balancing_loss = batch_output.get('moe_load_balancing_loss', None) moe_router_z_loss = batch_output.get('moe_router_z_loss', None) - if moe_load_balancing_loss is not None: - self.log('train/moe_load_balancing_loss', moe_load_balancing_loss, prog_bar=True, sync_dist=True) - if moe_router_z_loss is not None: - self.log('train/moe_router_z_loss', moe_router_z_loss, prog_bar=True, sync_dist=True) + moe_expert_usage_stats = batch_output.get('moe_expert_usage_stats', None) + if moe_load_balancing_loss is not None and self.moe_auxiliary_loss.load_balancing_loss.loss_scale > 0: + self.log('Loss:train/moe_load_balancing_loss', moe_load_balancing_loss, prog_bar=True, sync_dist=True) + if moe_router_z_loss is not None and self.moe_auxiliary_loss.router_z_loss.loss_scale > 0: + self.log('Loss:train/moe_router_z_loss', moe_router_z_loss, prog_bar=True, sync_dist=True) + if moe_expert_usage_stats is not None: + expert_usage = moe_expert_usage_stats['expert_usage'] + layer_expert_usage = moe_expert_usage_stats['layer_expert_usage'] + + self.log( + 'Loss:train/moe_expert_usage_variance', + moe_expert_usage_stats['batch_expert_usage_variance'], + sync_dist=True, + ) + + # Per-expert usage scalars + for eidx in range(len(expert_usage)): + self.log(f'MoE:train/Expert_{eidx:02d}_usage', expert_usage[eidx], sync_dist=True) + + # Accumulate layer-wise usage for training heatmap + if self._moe_train_layer_usage_accum is None: + self._moe_train_layer_usage_accum = torch.zeros_like(layer_expert_usage) + self._moe_train_layer_usage_accum += layer_expert_usage.detach() + 
self._moe_train_accum_steps += 1 # Log batch info batch_size, text_token_max_len = batch["text"].shape text_token_total_num = batch["text_lens"].sum() batch_info_dict = { - "train/batch_size": batch_size, - "train/text_token_max_len": text_token_max_len, - "train/text_token_total_num_in_batch": text_token_total_num.item(), - "train/text_token_pad_ratio_percent_in_batch": 100 + "BatchInfo:train/batch_size": batch_size, + "BatchInfo:train/text_token_max_len": text_token_max_len, + "BatchInfo:train/text_token_total_num_in_batch": text_token_total_num.item(), + "BatchInfo:train/text_token_pad_ratio_percent_in_batch": 100 * (1 - text_token_total_num / (batch_size * text_token_max_len)), } @@ -2775,9 +2918,9 @@ def training_step(self, batch, batch_idx): audio_codes_total_num = batch["audio_codes_lens"].sum() batch_info_dict.update( { - "train/audio_codes_max_len": audio_codes_max_len, - "train/audio_codes_total_num_in_batch": audio_codes_total_num.item(), - "train/audio_codes_pad_ratio_percent_in_batch": 100 + "BatchInfo:train/audio_codes_max_len": audio_codes_max_len, + "BatchInfo:train/audio_codes_total_num_in_batch": audio_codes_total_num.item(), + "BatchInfo:train/audio_codes_pad_ratio_percent_in_batch": 100 * (1 - audio_codes_total_num / (batch_size * audio_codes_max_len)), } ) @@ -2786,9 +2929,9 @@ def training_step(self, batch, batch_idx): audio_samples_total_num = batch["audio_lens"].sum() batch_info_dict.update( { - "train/audio_samples_max_len": audio_samples_max_len, - "train/audio_samples_total_num_in_batch": audio_samples_total_num.item(), - "train/audio_samples_pad_ratio_percent_in_batch": 100 + "BatchInfo:train/audio_samples_max_len": audio_samples_max_len, + "BatchInfo:train/audio_samples_total_num_in_batch": audio_samples_total_num.item(), + "BatchInfo:train/audio_samples_pad_ratio_percent_in_batch": 100 * (1 - audio_samples_total_num / (batch_size * audio_samples_max_len)), } ) @@ -2797,14 +2940,30 @@ def training_step(self, batch, batch_idx): return 
loss - def validation_step(self, batch, batch_idx): + def validation_step(self, batch, batch_idx, dataloader_idx=0): + """ + Validation step with support for multiple dataloaders. + + Args: + batch: Input batch + batch_idx: Batch index + dataloader_idx: Index of the dataloader (0 for single dataloader) + """ batch_output = self.process_batch(batch) # self.process_batch returns a dict. We currently only log "logits" which come from the parallel prediction # head. If we use local_transformer, then the local_transformer returns "local_transformer_logits" + loss = batch_output['loss'] codebook_loss = batch_output['codebook_loss'] alignment_loss = batch_output['alignment_loss'] aligner_encoder_loss = batch_output['aligner_encoder_loss'] + local_transformer_loss = batch_output['local_transformer_loss'] + + # Extract MoE losses and expert usage statistics if MoE is enabled + moe_load_balancing_loss = batch_output.get('moe_load_balancing_loss', None) + moe_router_z_loss = batch_output.get('moe_router_z_loss', None) + moe_expert_usage_stats = batch_output.get('moe_expert_usage_stats', None) + logits = batch_output['logits'] audio_codes_target = batch_output['audio_codes_target'] audio_codes_lens_target = batch_output['audio_codes_lens_target'] @@ -2814,152 +2973,100 @@ def validation_step(self, batch, batch_idx): text_lens = batch_output['text_lens'] dec_context_size = batch_output['dec_context_size'] - # Extract MoE losses and expert usage statistics if MoE is enabled - moe_load_balancing_loss = batch_output.get('moe_load_balancing_loss', None) - moe_router_z_loss = batch_output.get('moe_router_z_loss', None) - moe_expert_usage_stats = batch_output.get('moe_expert_usage_stats', None) - - if alignment_loss is None: - alignment_loss = torch.tensor(0.0, device=loss.device) - if aligner_encoder_loss is None: - aligner_encoder_loss = torch.tensor(0.0, device=loss.device) - if moe_load_balancing_loss is None: - moe_load_balancing_loss = torch.tensor(0.0, device=loss.device) - if 
moe_router_z_loss is None: - moe_router_z_loss = torch.tensor(0.0, device=loss.device) - - if batch_idx == 0 and self.global_rank == 0: - # Log MoE expert usage statistics to WandB (first batch only for visualization) - if self.use_moe and moe_expert_usage_stats is not None: - wandb_moe_first_batch_log = {} - - # Log per-expert usage as bar chart - expert_usage = moe_expert_usage_stats['expert_usage'].numpy() - expert_selection_freq = moe_expert_usage_stats['expert_selection_freq'].numpy() - - for logger in self.loggers: - if isinstance(logger, WandbLogger): - # Create bar chart for expert usage (routing probabilities) - expert_names = [f"Expert_{i}" for i in range(len(expert_usage))] - usage_data = [[name, usage] for name, usage in zip(expert_names, expert_usage)] - wandb_moe_first_batch_log['val/expert_usage_distribution'] = wandb.plot.bar( - wandb.Table(data=usage_data, columns=["Expert", "Usage"]), - "Expert", - "Usage", - title="Expert Usage (Routing Probabilities)", - ) - - # Create bar chart for expert selection frequency (top-k selections) - selection_data = [[name, freq] for name, freq in zip(expert_names, expert_selection_freq)] - wandb_moe_first_batch_log['val/expert_selection_frequency'] = wandb.plot.bar( - wandb.Table(data=selection_data, columns=["Expert", "Frequency"]), - "Expert", - "Frequency", - title=f"Expert Selection Frequency (Top-{self.decoder.top_k_experts})", - ) - - # Log scalar metrics for numerical tracking - wandb_moe_first_batch_log['val/batch_expert_usage_variance'] = moe_expert_usage_stats[ - 'batch_expert_usage_variance' - ] - wandb_moe_first_batch_log['val/batch_expert_usage_max'] = moe_expert_usage_stats[ - 'batch_expert_usage_max' - ] - wandb_moe_first_batch_log['val/batch_expert_usage_min'] = moe_expert_usage_stats[ - 'batch_expert_usage_min' - ] - - # Log individual expert usage percentages as scalars - for idx, usage in enumerate(expert_usage): - wandb_moe_first_batch_log[f'val/expert_{idx}_usage'] = float(usage) + val_output 
= { + 'val_loss': loss, + 'val_codebook_loss': codebook_loss, + } - logger.experiment.log(wandb_moe_first_batch_log) + # Only add optional losses if they were computed (not None) + if alignment_loss is not None: + val_output['val_alignment_loss'] = alignment_loss + if local_transformer_loss is not None: + val_output['val_local_transformer_loss'] = local_transformer_loss + if aligner_encoder_loss is not None: + val_output['val_aligner_encoder_loss'] = aligner_encoder_loss + if moe_load_balancing_loss is not None: + val_output['val_moe_load_balancing_loss'] = moe_load_balancing_loss + if moe_router_z_loss is not None: + val_output['val_moe_router_z_loss'] = moe_router_z_loss + if moe_expert_usage_stats is not None: + val_output['val_moe_expert_usage_stats'] = moe_expert_usage_stats + # Prepare media data for logging (only first batch of each dataloader, rank 0 only). if batch_idx == 0 and self.global_rank == 0: - # Prepare dictionary for aggregated wandb logging - wandb_log_dict = {} + dataset_prefix = self.get_validation_dataloader_prefix(dataloader_idx) + + # Prepare audio examples (decode via vocoder, convert to numpy) + audio_data = self._prepare_audio_examples( + logits=logits, + target_audio_codes=audio_codes_target, + audio_codes_lens=audio_codes_lens_target, + context_audio_codes=context_audio_codes, + context_audio_codes_lens=context_audio_codes_lens, + max_examples=3, + ) - # Get audio data for logging - wandb_log_dict.update( - self.log_val_audio_example( - logits, audio_codes_target, audio_codes_lens_target, context_audio_codes, context_audio_codes_lens - ) + # Prepare attention images (only when cross-attention is available) + attention_data = {} + has_cross_attn = ( + self.model_type != 'decoder_pretrain_synthesizer' + and len(attn_info[self.transcript_decoder_layers[0]].get('cross_attn_probabilities', [])) > 1 ) - # Get attention image data for logging - if len(attn_info[self.transcript_decoder_layers[0]]['cross_attn_probabilities']) > 1: - # 
cross_attn_probabilities only returned when not using flash attention + if has_cross_attn: + # Overall attention: average across CTC prior layers cross_attention_probs = [ attn['cross_attn_probabilities'][0] for layer_idx, attn in enumerate(attn_info) if layer_idx in self.ctc_prior_layer_ids ] - wandb_log_dict.update( - self.log_attention_probs( - cross_attention_probs, - audio_codes_lens_target, - text_lens, - prefix="val", - dec_context_size=dec_context_size, - ) + attention_data['overall'] = self._prepare_attention_images( + cross_attention_probs, + audio_codes_lens_target, + text_lens, + dec_context_size=dec_context_size, + max_examples=3, ) + # Per-layer attention visualization for layer_idx in self.transcript_decoder_layers: - cross_attention_probs = [attn_info[layer_idx]['cross_attn_probabilities'][0]] - wandb_log_dict.update( - self.log_attention_probs( - cross_attention_probs, - audio_codes_lens_target, - text_lens, - prefix=f"val/layer_{layer_idx}", - dec_context_size=dec_context_size, - ) + layer_cross_attention_probs = [attn_info[layer_idx]['cross_attn_probabilities'][0]] + attention_data[f'layer_{layer_idx:02d}'] = self._prepare_attention_images( + layer_cross_attention_probs, + audio_codes_lens_target, + text_lens, + dec_context_size=dec_context_size, + max_examples=3, ) + # Aligner encoder attention (if available) if batch_output['aligner_attn_soft'] is not None: - wandb_log_dict.update( - self.log_attention_probs( - [batch_output['aligner_attn_soft']], - audio_codes_lens_target, - text_lens, - prefix="val/aligner_encoder_attn", - ) + attention_data['aligner_encoder_attn'] = self._prepare_attention_images( + [batch_output['aligner_attn_soft']], + audio_codes_lens_target, + text_lens, + dec_context_size=0, + max_examples=3, ) if batch_output['aligner_attn_hard'] is not None: - wandb_log_dict.update( - self.log_attention_probs( - [batch_output['aligner_attn_hard'].unsqueeze(1)], - audio_codes_lens_target, - text_lens, - 
prefix="val/aligner_encoder_attn_hard", - ) + attention_data['aligner_encoder_attn_hard'] = self._prepare_attention_images( + [batch_output['aligner_attn_hard'].unsqueeze(1)], + audio_codes_lens_target, + text_lens, + dec_context_size=0, + max_examples=3, ) - # Perform single wandb log call if wandb is active and there is data - for logger in self.loggers: - if isinstance(logger, WandbLogger) and wandb_log_dict: - logger.experiment.log(wandb_log_dict) - - local_transformer_loss = batch_output['local_transformer_loss'] - val_output = { - 'val_loss': loss, - 'val_codebook_loss': codebook_loss, - 'val_alignment_loss': alignment_loss, - 'val_local_transformer_loss': local_transformer_loss, - 'val_aligner_encoder_loss': aligner_encoder_loss, - 'val_moe_load_balancing_loss': moe_load_balancing_loss, - 'val_moe_router_z_loss': moe_router_z_loss, - } - - # Store expert usage stats for aggregation at epoch end - if moe_expert_usage_stats is not None: - # Store batch-level variance for aggregation at epoch end - val_output['val_batch_expert_usage_variance'] = torch.tensor( - moe_expert_usage_stats['batch_expert_usage_variance'], device=loss.device - ) + val_output['media_data'] = { + 'dataset_prefix': dataset_prefix, + 'pred_audios': audio_data['pred_audios'], + 'target_audios': audio_data['target_audios'], + 'context_audios': audio_data['context_audios'], + 'attention_data': attention_data, + } - self.validation_step_outputs.append(val_output) + self.validation_step_outputs[dataloader_idx].append(val_output) return val_output @@ -3567,19 +3674,204 @@ def test_step(self, batch, batch_idx): audio_path = os.path.join(audio_dir, f'predicted_audioRank{self.global_rank}_{item_idx}.wav') sf.write(audio_path, predicted_audio_np, self.output_sample_rate) + def multi_validation_epoch_end( + self, outputs: List[Dict[str, torch.Tensor]], dataloader_idx: int = 0 + ) -> Tuple[Dict[str, torch.Tensor], Optional[Dict[str, torch.Tensor]]]: + """ + Called for each validation dataloader at the 
end of validation epoch. + Computes metrics for this specific dataloader. + + Args: + outputs: List of outputs from validation_step for this specific dataloader + dataloader_idx: Index of the current dataloader + + Returns: + A tuple of (log_dict, moe_expert_data): + - log_dict: scalar metrics suitable for self.log() + - moe_expert_data: per-expert usage/selection_freq tensors of shape (num_experts,), or None + """ + + def collect_required_metric(outputs, key, dim=None): + values = [x[key] for x in outputs if key in x and x[key] is not None] + if len(values) == 0: + raise ValueError( + f"No valid values found for required metric '{key}' in validation outputs " + f"for dataloader {dataloader_idx}. This indicates an issue with validation." + ) + return torch.stack(values).mean(dim=dim) + + def collect_optional_metric(outputs, key, dim=None): + """Collect optional metric - returns None if not found.""" + values = [x[key] for x in outputs if key in x and x[key] is not None] + if len(values) == 0: + return None + return torch.stack(values).mean(dim=dim) + + if len(outputs) == 0: + raise ValueError( + f"No validation outputs for dataloader {dataloader_idx}. " + f"This indicates an issue with the validation dataloader or validation step." 
+ ) + + # Compute required metrics + val_loss = collect_required_metric(outputs, 'val_loss') + val_codebook_loss = collect_required_metric(outputs, 'val_codebook_loss') + + log_dict = { + 'loss': val_loss, + 'codebook_loss': val_codebook_loss, + } + + # Compute optional metrics + VAL_OPTIONAL_METRICS = [ + 'val_alignment_loss', + 'val_aligner_encoder_loss', + 'val_local_transformer_loss', + 'val_moe_load_balancing_loss', + 'val_moe_router_z_loss', + ] + for metric_key in VAL_OPTIONAL_METRICS: + metric_value = collect_optional_metric(outputs, metric_key) + if metric_value is not None: + log_dict[metric_key.removeprefix('val_')] = metric_value + + # Exclude MoE metrics whose loss scale is disabled + if self.use_moe: + if self.moe_auxiliary_loss.load_balancing_loss.loss_scale <= 0: + log_dict.pop('moe_load_balancing_loss', None) + if self.moe_auxiliary_loss.router_z_loss.loss_scale <= 0: + log_dict.pop('moe_router_z_loss', None) + + # Collect per-expert usage vectors + val_moe_expert_usage_stats = [ + x.get('val_moe_expert_usage_stats') for x in outputs if x.get('val_moe_expert_usage_stats') is not None + ] + moe_expert_data = None + if len(val_moe_expert_usage_stats) > 0: + val_moe_expert_usage = collect_required_metric(val_moe_expert_usage_stats, 'expert_usage', dim=0) + val_moe_expert_selection_freq = collect_required_metric( + val_moe_expert_usage_stats, 'expert_selection_freq', dim=0 + ) + val_layer_expert_usage = collect_required_metric(val_moe_expert_usage_stats, 'layer_expert_usage', dim=0) + ideal_usage = val_moe_expert_usage_stats[0]['ideal_usage'] + moe_expert_data = { + 'moe_expert_usage': val_moe_expert_usage, + 'moe_expert_selection_freq': val_moe_expert_selection_freq, + 'layer_expert_usage': val_layer_expert_usage, + 'ideal_usage': ideal_usage, + } + + return log_dict, moe_expert_data + def on_validation_epoch_end(self): - collect = lambda key: torch.stack([x[key] for x in self.validation_step_outputs]).mean() - val_loss = collect("val_loss") - 
val_codebook_loss = collect("val_codebook_loss") - val_alignment_loss = collect("val_alignment_loss") - val_aligner_encoder_loss = collect("val_aligner_encoder_loss") - - # log val_loss in the same group as the other val metrics. - self.log("val/loss", val_loss, prog_bar=True, sync_dist=True) - # ensure val_loss is available for epoch-level checkpointing and filename generation without cluttering wandb logs. + """ + Computes and logs metrics across all validation dataloaders. + + Three-phase structure: + 1. Compute — aggregates metrics and collects media/heatmap data from all dataloaders. + 2. WandB media — logs all non-scalar media (audio, attention images, MoE heatmaps). + 3. Scalars — logs loss metrics and per-expert usage scalars. + """ + if len(self.validation_step_outputs) == 0: + return {} + + num_dataloaders = len(self.validation_step_outputs) + + # --- Phase 1: Compute all metrics + collect media data --- + all_moe_expert_data: List[Tuple[str, Dict[str, torch.Tensor]]] = [] + all_media_data: List[Dict[str, Any]] = [] + per_dl_logs: List[Tuple[str, Dict[str, torch.Tensor]]] = [] + aggregated_metrics: Dict[str, List[torch.Tensor]] = {} + + for dataloader_idx, val_outputs in enumerate(self.validation_step_outputs): + if len(val_outputs) == 0: + raise ValueError( + f"Validation dataloader {dataloader_idx} produced no outputs. " + f"Check that the dataset is not empty and validation_step is working correctly." 
+ ) + + dataloader_logs, moe_expert_data = self.multi_validation_epoch_end( + val_outputs, dataloader_idx=dataloader_idx + ) + + dataloader_prefix = self.get_validation_dataloader_prefix(dataloader_idx) + per_dl_logs.append((dataloader_prefix, dataloader_logs)) + + if moe_expert_data is not None: + all_moe_expert_data.append((dataloader_prefix, moe_expert_data)) + + if len(val_outputs) > 0 and 'media_data' in val_outputs[0]: + all_media_data.append(val_outputs[0]['media_data']) + + for metric_name, metric_value in dataloader_logs.items(): + aggregated_metrics.setdefault(metric_name, []).append(metric_value) + + for idx in range(num_dataloaders): + self.validation_step_outputs[idx].clear() + + # Validate required metrics were collected + for required_metric in ['loss', 'codebook_loss']: + if required_metric not in aggregated_metrics or len(aggregated_metrics[required_metric]) == 0: + raise ValueError(f"No {required_metric} collected from any dataloader.") + + # --- Phase 2: Single WandB media log (rank 0 only) --- + if self.global_rank == 0: + global_step = int(self.global_step) + wandb_media: Dict[str, Any] = {} + + for media_data in all_media_data: + media_entries = self._collect_wandb_media_and_log_tb(**media_data, global_step=global_step) + wandb_media.update(media_entries) + + # heatmaps show layer×expert routing structure + if all_moe_expert_data: + for dataset_name, moe_data in all_moe_expert_data: + heatmap_np = plot_expert_usage_heatmap_to_numpy( + layer_expert_usage=moe_data['layer_expert_usage'].float().cpu().numpy(), + ideal_usage=moe_data['ideal_usage'], + title=f"MoE Expert Usage — {dataset_name} (step {int(self.global_step)})", + ) + wandb_media[f"MoE:{dataset_name}/expert_usage_heatmap"] = wandb.Image(heatmap_np) + + if self._moe_train_layer_usage_accum is not None and self._moe_train_accum_steps > 0: + avg_layer_usage = self._moe_train_layer_usage_accum / self._moe_train_accum_steps + heatmap_np = plot_expert_usage_heatmap_to_numpy( + 
layer_expert_usage=avg_layer_usage.float().cpu().numpy(), + ideal_usage=1.0 / self._moe_num_experts, + title=f"MoE Expert Usage — train ({self._moe_train_accum_steps} steps avg, step {int(self.global_step)})", + ) + wandb_media["MoE:train/expert_usage_heatmap"] = wandb.Image(heatmap_np) + + self._moe_train_layer_usage_accum.zero_() + self._moe_train_accum_steps = 0 + + if wandb_media: + for logger in self.loggers: + if isinstance(logger, WandbLogger): + logger.experiment.log(wandb_media, commit=False) + + # --- Phase 3: Scalar metrics --- + for dataloader_prefix, dataloader_logs in per_dl_logs: + for metric_name, metric_value in dataloader_logs.items(): + self.log( + f"Loss:{dataloader_prefix}/{metric_name}", + metric_value, + prog_bar=(num_dataloaders == 1), + sync_dist=True, + ) + + checkpoint_loss = aggregated_metrics['loss'][0] + if num_dataloaders > 1: + for metric_name, metric_values in aggregated_metrics.items(): + if "loss" in metric_name: + avg_value = torch.stack(metric_values).mean() + self.log(f"Loss:val_avg/{metric_name}", avg_value, prog_bar=True, sync_dist=True) + if metric_name == 'loss': + checkpoint_loss = avg_value + self.log( "val_loss", - val_loss, + checkpoint_loss, prog_bar=False, sync_dist=True, on_step=False, @@ -3587,50 +3879,29 @@ def on_validation_epoch_end(self): logger=False, enable_graph=False, ) - self.log("val/codebook_loss", val_codebook_loss, prog_bar=True, sync_dist=True) - self.log("val/alignment_loss", val_alignment_loss, prog_bar=True, sync_dist=True) - self.log("val/aligner_encoder_loss", val_aligner_encoder_loss, prog_bar=True, sync_dist=True) - - if self.local_transformer_type != LocalTransformerType.NO_LT: - val_local_transformer_loss = collect("val_local_transformer_loss") - self.log("val/local_transformer_loss", val_local_transformer_loss, prog_bar=True, sync_dist=True) - # Log MoE losses and expert usage if MoE is enabled - if self.use_moe: - val_moe_load_balancing_loss = collect("val_moe_load_balancing_loss") - 
val_moe_router_z_loss = collect("val_moe_router_z_loss") - - # Log MoE losses - self.log("val/moe_load_balancing_loss", val_moe_load_balancing_loss, prog_bar=True, sync_dist=True) - self.log("val/moe_router_z_loss", val_moe_router_z_loss, prog_bar=True, sync_dist=True) - - # Log expert usage variance (averaged across all validation batches) - if any('val_batch_expert_usage_variance' in x for x in self.validation_step_outputs): - # This is the MEAN of batch-level variances across the epoch - val_epoch_mean_expert_usage_variance = collect("val_batch_expert_usage_variance") - self.log( - "val/expert_usage_variance_epoch_mean", - val_epoch_mean_expert_usage_variance, - prog_bar=False, - sync_dist=True, - ) + if all_moe_expert_data: + for dataset_name, moe_data in all_moe_expert_data: + expert_usage = moe_data['moe_expert_usage'] + expert_sel_freq = moe_data['moe_expert_selection_freq'] - # Log interpretation hints - # Ideal variance for N experts: 0 (perfectly balanced) - # High variance (>0.01) indicates imbalanced expert usage - num_experts = self.cfg.decoder.get('num_experts', 8) - ideal_usage = 1.0 / num_experts - logging.info( - f"MoE Expert Usage (Epoch Mean) - Ideal: {ideal_usage:.4f} per expert, " - f"Variance: {val_epoch_mean_expert_usage_variance:.6f} " - f"({'Balanced' if val_epoch_mean_expert_usage_variance < 0.01 else 'Imbalanced'})" - ) + for eidx in range(len(expert_usage)): + self.log(f'MoE:{dataset_name}/Expert_{eidx:02d}_usage', expert_usage[eidx], sync_dist=True) + self.log( + f'MoE:{dataset_name}/Expert_{eidx:02d}_selection_freq', expert_sel_freq[eidx], sync_dist=True + ) - self.validation_step_outputs.clear() # free memory + return {} def get_dataset(self, dataset_cfg, dataset_type): + if 'datasets' not in dataset_cfg or not isinstance(dataset_cfg.datasets, (dict, DictConfig)): + raise ValueError( + "Expected 'datasets' key (dict) in dataset config with _target_, dataset_meta, etc. 
" + f"Got keys: {list(dataset_cfg.keys())}" + ) + dataset = instantiate( - dataset_cfg.dataset, + dataset_cfg.datasets, sample_rate=self.sample_rate, bos_id=self.bos_id, eos_id=self.eos_id, @@ -3653,6 +3924,114 @@ def get_dataset(self, dataset_cfg, dataset_type): ) # This will be used in worker_init_fn for instantiating tokenizer return dataset + def setup_multiple_validation_data(self, val_data_config: Union[DictConfig, Dict]): + """ + Setup validation data with support for multiple datasets. + Overrides parent class to handle both non-lhotse and lhotse dataloaders. + + Non-lhotse config (datasets is a dict -- single dataloader, multiplicity via dataset_meta):: + + validation_ds: + datasets: + _target_: nemo.collections.tts.data.text_to_speech_dataset.MagpieTTSDataset + dataset_meta: ... + min_duration: 0.2 + max_duration: 20.0 + dataloader_params: ... + + Note: Non-lhotse creates a single dataloader even when dataset_meta contains + multiple entries (e.g., ``{en: ..., es: ...}``). All datasets are mixed + in one dataloader, so validation metrics are logged jointly (e.g., + prefix ``"en+es"``) rather than per-dataset. For per-dataset validation + metrics, use the lhotse config with separate datasets list entries. + + Lhotse config (datasets is a list -- multiple dataloaders):: + + validation_ds: + use_lhotse: true + # ... shared settings ... + datasets: + - name: "val_set_0" + input_cfg: [...] or path to an external YAML file + - name: "val_set_1" + input_cfg: [...] or path to an external YAML file + """ + # Set placeholders that may be overridden + self._val_dl_idx: int = 0 + self._validation_names: Optional[List[str]] = None + self._validation_dl: Optional[torch.utils.data.DataLoader] = None + + # Preserve config + self._update_dataset_config(dataset_name='validation', config=val_data_config) + + if 'datasets' not in val_data_config: + raise ValueError( + "validation_ds config must contain a 'datasets' key. 
" + "For non-lhotse: a dict with _target_, dataset_meta, etc. " + "For lhotse: a list of dataset configurations. " + "See magpietts.yaml or magpietts_lhotse.yaml for examples." + ) + + datasets_value = val_data_config.datasets + + # Non-lhotse: datasets is a dict (single dataloader, multiplicity via dataset_meta) + if isinstance(datasets_value, (dict, DictConfig)): + dataset_meta = datasets_value.get('dataset_meta', {}) + if dataset_meta: + val_name = '+'.join(dataset_meta.keys()) + else: + val_name = 'val_set_0' + logging.info(f"Setting up single non-lhotse validation dataloader: '{val_name}'") + self._validation_names = [val_name] + self._validation_dl = [self._setup_test_dataloader(val_data_config)] + return + + # Lhotse: datasets is a path to an external YAML file (supports local paths and remote URLs like s3://) or a list + if isinstance(datasets_value, (str, Path)): + logging.info(f"Loading validation datasets from external file: {datasets_value}") + datasets_list = OmegaConf.create(load_yaml(datasets_value)) + elif isinstance(datasets_value, (list, ListConfig)): + datasets_list = datasets_value + else: + raise ValueError( + f"Lhotse 'datasets' in `validation_ds` must be a non-empty list of dataset configurations. 
" + f"Got: {type(datasets_value).__name__}" + ) + + if len(datasets_list) == 0: + raise ValueError("Lhotse 'datasets' in `validation_ds` must be a non-empty list.") + + logging.info(f"Setting up {len(datasets_list)} validation dataset(s)") + + dataloaders = [] + dataset_names = [] + + # Extract shared config (everything except 'datasets' key) + shared_config = OmegaConf.create(val_data_config) + shared_config.pop('datasets', None) + + for idx, dataset_config in enumerate(datasets_list): + merged_config = OmegaConf.merge(shared_config, dataset_config) + + if isinstance(dataset_config, (dict, DictConfig)) and 'name' in dataset_config: + dataset_name = dataset_config['name'] + else: + dataset_name = f"val_set_{idx}" + + dataset_names.append(dataset_name) + + # Remove 'name' field from config as it's not needed for dataloader setup + temp_config = OmegaConf.create(merged_config) + temp_config.pop('name', None) + + dataloader = self._setup_test_dataloader(temp_config) + dataloaders.append(dataloader) + logging.info(f" - Validation dataset {idx}: '{dataset_name}'") + + self._validation_names = dataset_names + self._validation_dl = dataloaders + logging.info(f"Successfully setup {len(dataloaders)} validation dataloader(s)") + def get_lhotse_dataloader(self, dataset_cfg, mode='train') -> torch.utils.data.DataLoader: # TODO @xueyang: better to distinguish cfg. self.cfg is the model cfg, while cfg here is train_ds cfg. Also # cfg is a classifier-free guidance. @@ -3675,7 +4054,7 @@ def get_lhotse_dataloader(self, dataset_cfg, mode='train') -> torch.utils.data.D text_context_remapping_prob=self.text_context_remapping_prob, ) data_loader = get_lhotse_dataloader_from_config( - config=dataset_cfg.dataset, + config=dataset_cfg, global_rank=self.global_rank, world_size=self.world_size, dataset=dataset, @@ -3690,9 +4069,9 @@ def setup_training_data(self, dataset_cfg): # specify target sampling rate the same as codec model's because lhotse config defaults 16_000. 
if not isinstance(dataset_cfg, DictConfig): dataset_cfg = OmegaConf.create(dataset_cfg) - OmegaConf.set_struct(dataset_cfg.dataset, False) - dataset_cfg.dataset.update({"sample_rate": self.sample_rate}) - OmegaConf.set_struct(dataset_cfg.dataset, True) + OmegaConf.set_struct(dataset_cfg, False) + dataset_cfg.update({"sample_rate": self.sample_rate}) + OmegaConf.set_struct(dataset_cfg, True) self._train_dl = self.get_lhotse_dataloader(dataset_cfg, mode='train') else: @@ -3720,9 +4099,9 @@ def _setup_test_dataloader(self, dataset_cfg) -> torch.utils.data.DataLoader: # specify target sampling rate the same as codec model's because lhotse config defaults 16_000. if not isinstance(dataset_cfg, DictConfig): dataset_cfg = OmegaConf.create(dataset_cfg) - OmegaConf.set_struct(dataset_cfg.dataset, False) - dataset_cfg.dataset.update({"sample_rate": self.sample_rate}) - OmegaConf.set_struct(dataset_cfg.dataset, True) + OmegaConf.set_struct(dataset_cfg, False) + dataset_cfg.update({"sample_rate": self.sample_rate}) + OmegaConf.set_struct(dataset_cfg, True) data_loader = self.get_lhotse_dataloader(dataset_cfg, mode='test') else: dataset = self.get_dataset(dataset_cfg, dataset_type='test') @@ -3742,7 +4121,9 @@ def _setup_test_dataloader(self, dataset_cfg) -> torch.utils.data.DataLoader: return data_loader def setup_validation_data(self, dataset_cfg): - self._validation_dl = self._setup_test_dataloader(dataset_cfg) + """Required by ModelPT (abstract). 
Use setup_multiple_validation_data instead.""" + self._validation_names = ['val_set_0'] + self._validation_dl = [self._setup_test_dataloader(dataset_cfg)] def setup_test_data(self, dataset_cfg): self._test_dl = self._setup_test_dataloader(dataset_cfg) diff --git a/nemo/collections/tts/models/magpietts_preference_optimization.py b/nemo/collections/tts/models/magpietts_preference_optimization.py index f943fc566286..d583cacadd74 100644 --- a/nemo/collections/tts/models/magpietts_preference_optimization.py +++ b/nemo/collections/tts/models/magpietts_preference_optimization.py @@ -432,7 +432,7 @@ def training_step(self, batch, batch_idx): self.log('train_sft_loss', dpo_outputs['sft_loss'], prog_bar=True, sync_dist=True) return dpo_outputs['loss'] - def validation_step(self, batch, batch_idx): + def validation_step(self, batch, batch_idx, dataloader_idx=0): dpo_outputs = self.process_batch_dpo(batch) val_loss = dpo_outputs['loss'] @@ -440,7 +440,7 @@ def validation_step(self, batch, batch_idx): val_sft_loss = dpo_outputs['sft_loss'] val_alignment_loss = dpo_outputs['alignment_loss'] - self.validation_step_outputs.append( + self.validation_step_outputs[dataloader_idx].append( { 'val_loss': val_loss, 'val_pref_loss': val_pref_loss, @@ -452,11 +452,12 @@ def validation_step(self, batch, batch_idx): def on_validation_epoch_end(self): def collect(key): values = [] - for x in self.validation_step_outputs: - if x[key] is not None: - values.append(x[key]) - else: - values.append(torch.tensor(0.0, device=self.device)) + for val_outputs in self.validation_step_outputs: + for x in val_outputs: + if x[key] is not None: + values.append(x[key]) + else: + values.append(torch.tensor(0.0, device=self.device)) stacked_values = torch.stack(values) return stacked_values.mean() @@ -469,7 +470,8 @@ def collect(key): self.log("val_sft_loss", val_sft_loss, prog_bar=True, sync_dist=True) if val_alignment_loss is not None: self.log("val_alignment_loss", val_alignment_loss, prog_bar=True, 
sync_dist=True) - self.validation_step_outputs.clear() + for val_outputs in self.validation_step_outputs: + val_outputs.clear() class MagpieTTSModelOnlinePO(MagpieTTSModel): @@ -972,14 +974,14 @@ def training_step(self, batch, batch_idx): self.log('train_std_reward', po_outputs['std_reward'], prog_bar=True, sync_dist=True) return po_outputs['loss'] - def validation_step(self, batch, batch_idx): + def validation_step(self, batch, batch_idx, dataloader_idx=0): po_outputs = self.process_batch_online_po(batch, 1, mode='val') batch_metrics = po_outputs['batch_metrics'] mean_reward = po_outputs['mean_reward'] val_loss = po_outputs['loss'] val_kl_loss = po_outputs['kl_loss'] - self.validation_step_outputs.append( + self.validation_step_outputs[dataloader_idx].append( { 'mean_reward': mean_reward, 'std_reward': po_outputs['std_reward'], @@ -992,11 +994,12 @@ def validation_step(self, batch, batch_idx): def on_validation_epoch_end(self): def collect(key): values = [] - for x in self.validation_step_outputs: - if x[key] is not None: - values.append(x[key]) - else: - values.append(torch.tensor(0.0, device=self.device)) + for val_outputs in self.validation_step_outputs: + for x in val_outputs: + if x[key] is not None: + values.append(x[key]) + else: + values.append(torch.tensor(0.0, device=self.device)) stacked_values = torch.stack(values) return stacked_values.mean() @@ -1011,20 +1014,22 @@ def collect(key): self.log("val_std_reward", std_reward, prog_bar=True, sync_dist=True) mean_metrics = {} - for val_output in self.validation_step_outputs: - batch_metrics = val_output['batch_metrics'] - for item_metrics in batch_metrics: - for key, value in item_metrics.items(): - if "transcript" not in key: - if key not in mean_metrics: - mean_metrics[key] = [] - mean_metrics[key].append(value) + for val_outputs in self.validation_step_outputs: + for val_output in val_outputs: + batch_metrics = val_output['batch_metrics'] + for item_metrics in batch_metrics: + for key, value in 
item_metrics.items(): + if "transcript" not in key: + if key not in mean_metrics: + mean_metrics[key] = [] + mean_metrics[key].append(value) for key, values in mean_metrics.items(): mean_metrics[key] = np.mean(values) self.log(f"val_{key}", mean_metrics[key], prog_bar=True, sync_dist=True) - self.validation_step_outputs.clear() + for val_outputs in self.validation_step_outputs: + val_outputs.clear() # Utility functions diff --git a/nemo/collections/tts/parts/utils/helpers.py b/nemo/collections/tts/parts/utils/helpers.py index f6ff793de617..1b1855cf356d 100644 --- a/nemo/collections/tts/parts/utils/helpers.py +++ b/nemo/collections/tts/parts/utils/helpers.py @@ -484,6 +484,75 @@ def save_figure_to_numpy(fig): return img_array +def plot_expert_usage_heatmap_to_numpy( + layer_expert_usage: np.ndarray, + ideal_usage: float, + title: str, +) -> np.ndarray: + """ + Render a layer-wise expert usage heatmap and return it as a numpy image. + + Rows = decoder layers (bottom = layer 0), columns = experts. + Cell values are delta from ideal usage (usage - ideal), so 0 = perfectly balanced, + positive = overused, negative = underused. + + Args: + layer_expert_usage: shape (n_layers, num_experts), per-layer expert usage fractions. + ideal_usage: expected uniform value (1 / num_experts). + title: figure title. + + Returns: + numpy array in RGBA HWC format suitable for wandb.Image(). 
+ """ + from matplotlib.colors import TwoSlopeNorm + + n_layers, num_experts = layer_expert_usage.shape + delta = layer_expert_usage - ideal_usage + + row_labels = [f"L{i}" for i in range(n_layers)] + col_labels = [f"E{i}" for i in range(num_experts)] + + abs_max = max(np.abs(delta).max(), 1e-9) + norm = TwoSlopeNorm(vcenter=0.0, vmin=-abs_max, vmax=abs_max) + + dpi = 150 + fig, ax = plt.subplots( + figsize=(max(6, num_experts * 0.55), max(2.4, n_layers * 0.4)), + dpi=dpi, + ) + im = ax.imshow(delta, aspect='auto', cmap='RdBu_r', norm=norm, origin='lower', interpolation='nearest') + + ax.set_xticks(range(num_experts)) + ax.set_xticklabels(col_labels, fontsize=7) + ax.set_yticks(range(n_layers)) + ax.set_yticklabels(row_labels, fontsize=7) + ax.set_xlabel("Experts", fontsize=8) + ax.set_ylabel("Layers", fontsize=8) + ax.set_title(f"{title}\nideal = {ideal_usage:.4f}", fontsize=9, pad=8) + + for row in range(n_layers): + for col in range(num_experts): + ax.text( + col, + row, + f"{delta[row, col]:+.3f}", + ha='center', + va='center', + fontsize=5, + color='white' if abs(delta[row, col]) > abs_max * 0.6 else 'black', + ) + + cbar = fig.colorbar(im, ax=ax, fraction=0.02, pad=0.04) + cbar.ax.tick_params(labelsize=6) + cbar.set_label("Δ from ideal", fontsize=7) + + fig.tight_layout() + fig.canvas.draw() + data = save_figure_to_numpy(fig) + plt.close(fig) + return data + + def regulate_len( durations, enc_out, diff --git a/scripts/magpietts/README_magpie_po.md b/scripts/magpietts/README_magpie_po.md index 897287aaf1d5..d2adb960c488 100644 --- a/scripts/magpietts/README_magpie_po.md +++ b/scripts/magpietts/README_magpie_po.md @@ -78,8 +78,8 @@ batch_size=4 \ max_epochs=10 \ exp_manager.exp_dir= \ exp_manager.checkpoint_callback_params.always_save_nemo=false \ -model.train_ds.dataset._target_="nemo.collections.tts.data.text_to_speech_dataset.MagpieTTSDatasetDPO" \ -model.validation_ds.dataset._target_="nemo.collections.tts.data.text_to_speech_dataset.MagpieTTSDatasetDPO" \ 
+model.train_ds.datasets._target_="nemo.collections.tts.data.text_to_speech_dataset.MagpieTTSDatasetDPO" \ +model.validation_ds.datasets._target_="nemo.collections.tts.data.text_to_speech_dataset.MagpieTTSDatasetDPO" \ +train_ds_meta.dpopreftrain.manifest_path="/MagpieTTS-PO-Infer/version_0/manifests/" \ +train_ds_meta.dpopreftrain.audio_dir="/" \ +train_ds_meta.dpopreftrain.feature_dir="/" \ @@ -105,8 +105,8 @@ Note the following overrides in the above command: ``` +mode="dpo_train" \ -model.train_ds.dataset._target_="nemo.collections.tts.data.text_to_speech_dataset.MagpieTTSDatasetDPO" \ -model.validation_ds.dataset._target_="nemo.collections.tts.data.text_to_speech_dataset.MagpieTTSDatasetDPO" \ +model.train_ds.datasets._target_="nemo.collections.tts.data.text_to_speech_dataset.MagpieTTSDatasetDPO" \ +model.validation_ds.datasets._target_="nemo.collections.tts.data.text_to_speech_dataset.MagpieTTSDatasetDPO" \ ``` Again, our manifest contain absolute paths so we specify `audio_dir="/"` . diff --git a/tests/collections/tts/models/test_magpietts_dataloader_config.py b/tests/collections/tts/models/test_magpietts_dataloader_config.py new file mode 100644 index 000000000000..6c0d6ae9c8c5 --- /dev/null +++ b/tests/collections/tts/models/test_magpietts_dataloader_config.py @@ -0,0 +1,355 @@ +# Copyright (c) 2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +""" +Unit tests for MagpieTTSModel dataloader setup methods. + +Tests for setup_multiple_validation_data (lhotse — datasets is a list): +1. Single lhotse dataset entry +2. Multiple lhotse dataset entries with default/custom names + +Tests for setup_multiple_validation_data (non-lhotse — datasets is a dict): +3. Single dataset_meta in datasets dict +4. Multiple dataset_meta entries in datasets dict +5. Error cases: missing or empty 'datasets' key + +Tests for setup_training_data (non-lhotse — datasets is a dict): +6. Single dataset_meta with num_workers=0 (persistent_workers=False, inline tokenizer setup) +7. Multiple dataset_meta with num_workers>0 (persistent_workers=True, deferred tokenizer setup) + +Tests for setup_training_data (lhotse): +8. Single lhotse input_cfg: sample_rate injection, get_lhotse_dataloader dispatch +9. Multiple lhotse input_cfg entries: weighted multi-source config passes through +""" + +from unittest.mock import MagicMock, patch + +import pytest +from omegaconf import OmegaConf + +from nemo.collections.tts.models.magpietts import MagpieTTSModel + + +class TestSetupMultipleValidationData: + """Test cases for MagpieTTSModel.setup_multiple_validation_data method.""" + + @pytest.fixture + def mock_model(self): + """Create a mock MagpieTTSModel instance with required methods mocked.""" + model = MagicMock(spec=MagpieTTSModel) + model._update_dataset_config = MagicMock() + model._setup_test_dataloader = MagicMock(return_value=MagicMock()) + return model + + @pytest.mark.run_only_on('CPU') + @pytest.mark.unit + def test_single_dataset_merges_shared_config(self, mock_model): + """Single dataset entry merges shared config, strips 'name', and stores dataloader as a list.""" + config = OmegaConf.create( + { + 'use_lhotse': True, + 'batch_duration': 100, # Shared config + 'num_workers': 2, # Shared config + 'datasets': [ + { + 'name': 'custom_single_val', + 'batch_duration': 50, # Override shared + 'input_cfg': [{'type': 'lhotse_shar', 'shar_path': 
'/path/to/data'}], + } + ], + } + ) + + MagpieTTSModel.setup_multiple_validation_data(mock_model, config) + + # Single dataset goes through the same unified loop as multi-dataset + mock_model._setup_test_dataloader.assert_called_once() + passed_config = mock_model._setup_test_dataloader.call_args[0][0] + assert passed_config.batch_duration == 50 # Dataset override wins + assert passed_config.num_workers == 2 # Shared config preserved + assert 'name' not in passed_config # 'name' stripped before dataloader setup + assert isinstance(mock_model._validation_dl, list) + assert len(mock_model._validation_dl) == 1 + assert mock_model._validation_names == ['custom_single_val'] + + @pytest.mark.run_only_on('CPU') + @pytest.mark.unit + def test_multiple_datasets_with_default_and_custom_names(self, mock_model): + """Multiple dataset entries create separate dataloaders and assign default names when unspecified.""" + config = OmegaConf.create( + { + 'use_lhotse': True, + 'batch_duration': 100, + 'datasets': [ + {'input_cfg': [{'type': 'lhotse_shar', 'shar_path': '/path/to/data0'}]}, # No name + {'name': 'custom_name', 'input_cfg': [{'type': 'lhotse_shar', 'shar_path': '/path/to/data1'}]}, + ], + } + ) + + MagpieTTSModel.setup_multiple_validation_data(mock_model, config) + + # Should call _setup_test_dataloader twice (once per dataset) + assert mock_model._setup_test_dataloader.call_count == 2 + assert isinstance(mock_model._validation_dl, list) + assert len(mock_model._validation_dl) == 2 + # First dataset gets default name, second uses explicit name + assert mock_model._validation_names == ['val_set_0', 'custom_name'] + + # ==================== Non-Lhotse (NeMo Manifest) Tests ==================== + + @pytest.mark.run_only_on('CPU') + @pytest.mark.unit + def test_non_lhotse_datasets_dict_creates_single_dataloader(self, mock_model): + """Non-lhotse: datasets as dict creates a single dataloader, name derived from dataset_meta key.""" + config = OmegaConf.create( + { + 
'datasets': { + '_target_': 'nemo.collections.tts.data.text_to_speech_dataset.MagpieTTSDataset', + 'dataset_meta': {'an4': {'manifest_path': '/data/val.json', 'audio_dir': '/'}}, + 'min_duration': 0.2, + 'max_duration': 20.0, + }, + 'dataloader_params': {'batch_size': 4, 'num_workers': 0, 'pin_memory': True}, + } + ) + + MagpieTTSModel.setup_multiple_validation_data(mock_model, config) + + mock_model._setup_test_dataloader.assert_called_once() + passed_config = mock_model._setup_test_dataloader.call_args[0][0] + assert passed_config.datasets._target_ == 'nemo.collections.tts.data.text_to_speech_dataset.MagpieTTSDataset' + assert passed_config.datasets.dataset_meta.an4.manifest_path == '/data/val.json' + assert passed_config.dataloader_params.batch_size == 4 + assert mock_model._validation_names == ['an4'] + assert len(mock_model._validation_dl) == 1 + + @pytest.mark.run_only_on('CPU') + @pytest.mark.unit + def test_non_lhotse_datasets_dict_with_multiple_dataset_meta(self, mock_model): + """Non-lhotse: datasets dict with multiple dataset_meta entries, name joined with '+'.""" + config = OmegaConf.create( + { + 'datasets': { + '_target_': 'nemo.collections.tts.data.text_to_speech_dataset.MagpieTTSDataset', + 'dataset_meta': { + 'en': {'manifest_path': '/data/en_val.json', 'audio_dir': '/'}, + 'es': {'manifest_path': '/data/es_val.json', 'audio_dir': '/'}, + }, + 'min_duration': 0.2, + 'max_duration': 20.0, + }, + 'dataloader_params': {'batch_size': 8, 'num_workers': 2, 'pin_memory': True}, + } + ) + + MagpieTTSModel.setup_multiple_validation_data(mock_model, config) + + mock_model._setup_test_dataloader.assert_called_once() + passed_config = mock_model._setup_test_dataloader.call_args[0][0] + assert passed_config.datasets.dataset_meta.en.manifest_path == '/data/en_val.json' + assert passed_config.datasets.dataset_meta.es.manifest_path == '/data/es_val.json' + assert passed_config.dataloader_params.batch_size == 8 + assert mock_model._validation_names == ['en+es'] + 
assert len(mock_model._validation_dl) == 1 + + # ==================== Error Case Tests ==================== + + @pytest.mark.run_only_on('CPU') + @pytest.mark.unit + def test_missing_datasets_key_raises_value_error(self, mock_model): + """Config without 'datasets' key raises ValueError.""" + config = OmegaConf.create({'use_lhotse': True, 'batch_duration': 100}) + + with pytest.raises(ValueError) as exc_info: + MagpieTTSModel.setup_multiple_validation_data(mock_model, config) + + assert "datasets" in str(exc_info.value).lower() + + @pytest.mark.run_only_on('CPU') + @pytest.mark.unit + def test_empty_datasets_list_raises_value_error(self, mock_model): + """Empty 'datasets' list raises ValueError.""" + config = OmegaConf.create({'use_lhotse': True, 'datasets': []}) + + with pytest.raises(ValueError) as exc_info: + MagpieTTSModel.setup_multiple_validation_data(mock_model, config) + + assert "non-empty list" in str(exc_info.value).lower() + + +class TestSetupTrainingData: + """Test cases for MagpieTTSModel.setup_training_data method (lhotse and non-lhotse paths).""" + + @pytest.fixture + def mock_model(self): + """Create a mock MagpieTTSModel with get_dataset and get_lhotse_dataloader mocked.""" + model = MagicMock(spec=MagpieTTSModel) + mock_dataset = MagicMock() + mock_dataset.get_sampler.return_value = MagicMock() + mock_dataset.collate_fn = MagicMock() + model.get_dataset.return_value = mock_dataset + model.get_lhotse_dataloader.return_value = MagicMock() + model.sample_rate = 22050 + model.trainer.world_size = 1 + return model + + @pytest.mark.run_only_on('CPU') + @pytest.mark.unit + @patch('torch.utils.data.DataLoader', return_value=MagicMock()) + @patch('nemo.collections.tts.models.magpietts.setup_tokenizers', return_value=MagicMock()) + def test_single_non_lhotse_train_dataset_num_workers_zero( + self, mock_setup_tokenizers, mock_dataloader_cls, mock_model + ): + """Single dataset_meta, num_workers=0: tokenizer set up inline, persistent_workers=False.""" + 
config = OmegaConf.create( + { + 'datasets': { + '_target_': 'nemo.collections.tts.data.text_to_speech_dataset.MagpieTTSDataset', + 'dataset_meta': {'an4': {'manifest_path': '/data/train.json', 'audio_dir': '/'}}, + 'min_duration': 0.2, + 'max_duration': 20.0, + }, + 'dataloader_params': {'batch_size': 4, 'num_workers': 0, 'pin_memory': True, 'drop_last': True}, + } + ) + + MagpieTTSModel.setup_training_data(mock_model, config) + + mock_model.get_dataset.assert_called_once_with(config, dataset_type='train') + mock_dataset = mock_model.get_dataset.return_value + mock_dataset.get_sampler.assert_called_once_with(4, world_size=1) + + mock_setup_tokenizers.assert_called_once() + assert mock_dataset.text_tokenizer == mock_setup_tokenizers.return_value + + mock_dataloader_cls.assert_called_once() + dl_call_kwargs = mock_dataloader_cls.call_args + assert dl_call_kwargs.kwargs['persistent_workers'] is False + assert dl_call_kwargs.kwargs['batch_size'] == 4 + assert dl_call_kwargs.kwargs['num_workers'] == 0 + + @pytest.mark.run_only_on('CPU') + @pytest.mark.unit + @patch('torch.utils.data.DataLoader', return_value=MagicMock()) + @patch('nemo.collections.tts.models.magpietts.setup_tokenizers', return_value=MagicMock()) + def test_multiple_non_lhotse_train_datasets_num_workers_positive( + self, mock_setup_tokenizers, mock_dataloader_cls, mock_model + ): + """Multiple dataset_meta entries, num_workers>0: tokenizer deferred to worker_init_fn, persistent_workers=True.""" + config = OmegaConf.create( + { + 'datasets': { + '_target_': 'nemo.collections.tts.data.text_to_speech_dataset.MagpieTTSDataset', + 'dataset_meta': { + 'en': {'manifest_path': '/data/en_train.json', 'audio_dir': '/'}, + 'es': {'manifest_path': '/data/es_train.json', 'audio_dir': '/'}, + }, + 'weighted_sampling_steps_per_epoch': 1000, + 'min_duration': 0.2, + 'max_duration': 20.0, + }, + 'dataloader_params': {'batch_size': 16, 'num_workers': 4, 'pin_memory': True, 'drop_last': True}, + } + ) + + 
MagpieTTSModel.setup_training_data(mock_model, config) + + mock_model.get_dataset.assert_called_once_with(config, dataset_type='train') + mock_dataset = mock_model.get_dataset.return_value + mock_dataset.get_sampler.assert_called_once_with(16, world_size=1) + + mock_setup_tokenizers.assert_not_called() + + mock_dataloader_cls.assert_called_once() + dl_call_kwargs = mock_dataloader_cls.call_args + assert dl_call_kwargs.kwargs['persistent_workers'] is True + assert dl_call_kwargs.kwargs['batch_size'] == 16 + assert dl_call_kwargs.kwargs['num_workers'] == 4 + + # ==================== Lhotse Tests ==================== + + @pytest.mark.run_only_on('CPU') + @pytest.mark.unit + def test_single_lhotse_train_dataset(self, mock_model): + """Single lhotse input_cfg: sample_rate injected into config and get_lhotse_dataloader called.""" + config = OmegaConf.create( + { + 'use_lhotse': True, + 'volume_norm': True, + 'min_duration': 0.2, + 'batch_duration': 100, + 'num_workers': 4, + 'input_cfg': [ + { + 'type': 'lhotse_shar', + 'shar_path': '/path/to/data', + 'weight': 1.0, + 'tags': {'tokenizer_names': ['english_phoneme']}, + } + ], + } + ) + + MagpieTTSModel.setup_training_data(mock_model, config) + + mock_model.get_lhotse_dataloader.assert_called_once() + passed_config = mock_model.get_lhotse_dataloader.call_args.args[0] + assert passed_config.sample_rate == 22050 + assert passed_config.use_lhotse is True + assert len(passed_config.input_cfg) == 1 + assert passed_config.input_cfg[0].shar_path == '/path/to/data' + assert mock_model.get_lhotse_dataloader.call_args.kwargs['mode'] == 'train' + mock_model.get_dataset.assert_not_called() + + @pytest.mark.run_only_on('CPU') + @pytest.mark.unit + def test_multiple_lhotse_train_datasets(self, mock_model): + """Multiple lhotse input_cfg entries: all data sources passed through with sample_rate injected.""" + config = OmegaConf.create( + { + 'use_lhotse': True, + 'volume_norm': True, + 'min_duration': 0.2, + 'batch_duration': 200, + 
'num_workers': 6, + 'input_cfg': [ + { + 'type': 'lhotse_shar', + 'shar_path': '/path/to/en_data', + 'weight': 0.7, + 'tags': {'tokenizer_names': ['english_phoneme']}, + }, + { + 'type': 'lhotse_shar', + 'shar_path': '/path/to/es_data', + 'weight': 0.3, + 'tags': {'tokenizer_names': ['spanish_phoneme']}, + }, + ], + } + ) + + MagpieTTSModel.setup_training_data(mock_model, config) + + mock_model.get_lhotse_dataloader.assert_called_once() + passed_config = mock_model.get_lhotse_dataloader.call_args.args[0] + assert passed_config.sample_rate == 22050 + assert len(passed_config.input_cfg) == 2 + assert passed_config.input_cfg[0].shar_path == '/path/to/en_data' + assert passed_config.input_cfg[0].weight == 0.7 + assert passed_config.input_cfg[1].shar_path == '/path/to/es_data' + assert passed_config.input_cfg[1].weight == 0.3 + assert mock_model.get_lhotse_dataloader.call_args.kwargs['mode'] == 'train' + mock_model.get_dataset.assert_not_called()