Merged
18 commits
fec7616
added multiple validation dataloaders and log metrics per val data. (…
XuesongYang Dec 16, 2025
0d7723a
removed dataclass for media data.
XuesongYang Feb 28, 2026
9ed637b
add moe expert usage monitoring during training
XuesongYang Feb 28, 2026
eaecadb
unify single/multi validation dataloader handling via local property …
XuesongYang Mar 1, 2026
15d4c3d
update moe yaml config.
XuesongYang Mar 1, 2026
daa251e
fix DDP crash: keep MoE expert usage stats on GPU for sync_dist
XuesongYang Mar 1, 2026
f4123c3
make media artifact names using two digits for examples.
XuesongYang Mar 2, 2026
8f08024
replace wandb.Table MoE monitoring with per-expert scalars and layer-…
XuesongYang Mar 2, 2026
5294495
fix MoE heatmap rendering and clean up MoE metric logging
XuesongYang Mar 2, 2026
d3136c1
update layer index with two digits
XuesongYang Mar 2, 2026
85822f9
fix WandB step alignment for validation media and heatmaps and remove…
XuesongYang Mar 3, 2026
8a3b2dc
Apply isort and black reformatting
XuesongYang Mar 3, 2026
9508546
refactor validation optional metrics into loop-driven collection for …
XuesongYang Mar 3, 2026
d193417
unify non-lhotse config key from `dataset` to `datasets` (dict) for train/val/test
XuesongYang Mar 3, 2026
b0d908e
fix docs and readmes.
XuesongYang Mar 4, 2026
8606555
bugfix: fix docstring reST formatting and update stale config key in …
XuesongYang Mar 4, 2026
b44a526
fix PO models to use list-of-lists validation_step_outputs
XuesongYang Mar 4, 2026
7b2d1af
fixed longform docs bug.
XuesongYang Mar 7, 2026
16 changes: 8 additions & 8 deletions docs/source/tts/magpietts-longform.rst
@@ -68,7 +68,7 @@ The input text is split into individual sentences using punctuation markers (``.
Step 2: State Initialization
----------------------------

A ``LongformChunkState`` object is created to track information across sentence chunks:
A ``ChunkState`` object is created to track information across sentence chunks:

- **History text tokens**: Text from previous chunks for context
- **History encoder context**: Encoder outputs that provide continuity
@@ -112,7 +112,7 @@ Key Components

1. **Sentence Splitting** (``split_by_sentence``): Intelligently splits text on sentence boundaries while handling abbreviations (e.g., "Dr.", "Mr.").

2. **Chunk State** (``LongformChunkState``): Maintains context across chunks:
2. **Chunk State** (``ChunkState``): Maintains context across chunks:

- ``history_text``: Text tokens from previous chunks
- ``history_context_tensor``: Encoder outputs for continuity
@@ -211,24 +211,24 @@ Configuration Dataclasses
#########################


``LongformConfig``
------------------
``ChunkedInferenceConfig``
--------------------------

Immutable tuning parameters (set in model):

.. literalinclude:: ../../../nemo/collections/tts/models/magpietts.py
:language: python
:pyobject: LongformConfig
:pyobject: ChunkedInferenceConfig


``LongformChunkState``
----------------------
``ChunkState``
--------------

Mutable state passed between chunk iterations:

.. literalinclude:: ../../../nemo/collections/tts/models/magpietts.py
:language: python
:pyobject: LongformChunkState
:pyobject: ChunkState


Best Practices
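The chunk-state mechanism this file documents can be sketched in plain Python. This is an illustrative approximation, not the actual ``ChunkState`` from ``magpietts.py``; the real class stores encoder outputs as tensors, and the field and method names below are assumptions:

```python
from dataclasses import dataclass, field
from typing import Any, List, Optional


@dataclass
class ChunkStateSketch:
    """Mutable state carried across sentence chunks (illustrative only)."""

    history_text: List[int] = field(default_factory=list)  # text tokens from previous chunks
    history_context: Optional[Any] = None  # stands in for the encoder-output tensor

    def update(self, new_tokens: List[int], encoder_out: Any) -> None:
        # After synthesizing a chunk, extend the text history and replace
        # the encoder context used to condition the next chunk.
        self.history_text.extend(new_tokens)
        self.history_context = encoder_out


state = ChunkStateSketch()
state.update([5, 9, 12], "enc_out_chunk_00")
state.update([7, 3], "enc_out_chunk_01")
print(state.history_text)  # [5, 9, 12, 7, 3]
```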
4 changes: 2 additions & 2 deletions docs/source/tts/magpietts-po.rst
@@ -96,8 +96,8 @@ The final step is fine-tuning the base model on the preference pairs using the D
max_epochs=10 \
exp_manager.exp_dir=/path/to/dpo_experiment \
exp_manager.checkpoint_callback_params.always_save_nemo=false \
model.train_ds.dataset._target_="nemo.collections.tts.data.text_to_speech_dataset.MagpieTTSDatasetDPO" \
model.validation_ds.dataset._target_="nemo.collections.tts.data.text_to_speech_dataset.MagpieTTSDatasetDPO" \
model.train_ds.datasets._target_="nemo.collections.tts.data.text_to_speech_dataset.MagpieTTSDatasetDPO" \
model.validation_ds.datasets._target_="nemo.collections.tts.data.text_to_speech_dataset.MagpieTTSDatasetDPO" \
+train_ds_meta.dpopreftrain.manifest_path="/path/to/manifests/" \
+train_ds_meta.dpopreftrain.audio_dir="/" \
+train_ds_meta.dpopreftrain.feature_dir="/" \
7 changes: 5 additions & 2 deletions examples/tts/conf/magpietts/magpietts.yaml
@@ -80,7 +80,7 @@ model:
# pretrained_model: "google/byt5-small"

train_ds:
dataset:
datasets:
_target_: nemo.collections.tts.data.text_to_speech_dataset.MagpieTTSDataset
dataset_meta: ${train_ds_meta}
weighted_sampling_steps_per_epoch: ${weighted_sampling_steps_per_epoch}
@@ -93,8 +93,11 @@ model:
drop_last: true
pin_memory: true

# Non-lhotse validation uses a single dataloader. All dataset_meta entries are mixed
# together, so validation metrics are logged jointly. For per-dataset validation
# metrics, use the lhotse config (magpietts_lhotse.yaml) with separate datasets entries.
validation_ds:
dataset:
datasets:
_target_: nemo.collections.tts.data.text_to_speech_dataset.MagpieTTSDataset
dataset_meta: ${val_ds_meta}
min_duration: 0.2
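The behavior described in the comment above — joint metrics with a single mixed dataloader, per-dataset metrics with the lhotse config — can be illustrated with a small naming helper. `metric_name` and the dataset names here are hypothetical, not NeMo APIs:

```python
from typing import List


def metric_name(base: str, dataset_names: List[str], dataloader_idx: int) -> str:
    # With one dataloader all validation data is mixed together, so a single
    # joint metric is logged; with several dataloaders each metric is suffixed
    # with its dataset name so every dataset gets its own curve.
    if len(dataset_names) <= 1:
        return base
    return f"{base}_{dataset_names[dataloader_idx]}"


print(metric_name("val_loss", ["val_set_0"], 0))               # val_loss
print(metric_name("val_loss", ["val_set_0", "val_set_1"], 1))  # val_loss_val_set_1
```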
105 changes: 54 additions & 51 deletions examples/tts/conf/magpietts/magpietts_lhotse.yaml
@@ -1,6 +1,7 @@
name: Magpie-TTS

quadratic_duration: 20 # training and validation datasets can use the same quadratic_duration.

model:
use_lhotse: true
model_type: "decoder_ce" # decoder_context_tts or decoder_ce
@@ -16,7 +17,7 @@ model:
alignment_loss_scale: 0.002
embedding_dim: 768
codecmodel_path: ???
cfg_unconditional_prob: 0.1
cfg_unconditional_prob: 0.1 # enable classifier-free guidance during training by dropping out conditionals with this probability

# Alignment encoder parameters, to binarize the prior
# This is used for attention-constrained training and inference
@@ -70,57 +71,60 @@ model:
train_ds:
use_lhotse: ${model.use_lhotse}
volume_norm: true

dataset:
min_duration: 0.2
min_context_speaker_similarity: 0.6
max_cer: 0.03
batch_duration : ??? # in seconds. Adjust based on your GPU memory.
quadratic_duration: ${quadratic_duration}
use_bucketing: true
num_buckets: 20
bucket_buffer_size: 20_000
shuffle_buffer_size: 20_000
num_cuts_for_bins_estimate: 20_000
shard_seed: "trng"
drop_last: true
shuffle: true
num_workers: 6
pin_memory: true

input_cfg:
- type: lhotse_shar
shar_path: ???
weight: 1.0
tags:
tokenizer_names: ["english_phoneme"]
min_duration: 0.2
min_context_speaker_similarity: 0.6
max_cer: 0.03
batch_duration: ??? # in seconds. Adjust based on your GPU memory.
quadratic_duration: ${quadratic_duration}
use_bucketing: true
num_buckets: 20
bucket_buffer_size: 20_000
shuffle_buffer_size: 20_000
num_cuts_for_bins_estimate: 20_000
shard_seed: "trng"
drop_last: true
shuffle: true
num_workers: 6
pin_memory: true

input_cfg:
- type: lhotse_shar
shar_path: ???
weight: 1.0
tags:
tokenizer_names: ["english_phoneme"]

validation_ds:
# the entries under 'datasets' are a list of separate dataloaders.
# The structure is:
# - name: '<dataset-name>'
# <dataloader-dict-config>
# They inherit all settings from validation_ds, but can individually override them.
use_lhotse: ${model.use_lhotse}
volume_norm: true

dataset:
min_duration: 0.2
min_context_speaker_similarity: 0.6
max_cer: 0.03
batch_duration: ??? # recommend to use smaller batch_duration for validation dataset than training dataset.
quadratic_duration: ${quadratic_duration}
use_bucketing: false
force_finite: true
force_map_dataset: true
seed: 42
shard_seed: "randomized"
drop_last: false
shuffle: false
num_workers: 2
pin_memory: true

input_cfg:
- type: lhotse_shar
shar_path: ???
weight: 1.0
tags:
tokenizer_names: ["english_phoneme"]
min_duration: 0.2
min_context_speaker_similarity: 0.6
max_cer: 0.03
batch_duration: ??? # a smaller batch_duration is recommended for validation than for training.
quadratic_duration: ${quadratic_duration}
use_bucketing: false
force_finite: true
force_map_dataset: true
seed: 42
shard_seed: "randomized"
drop_last: false
shuffle: false
num_workers: 2
pin_memory: true

datasets:
- name: "val_set_0" # rename to your dataset name, add more as needed
input_cfg:
- type: lhotse_shar
shar_path: ???
weight: 1.0
tags:
tokenizer_names: ["english_phoneme"]

encoder:
n_layers: 6
@@ -185,10 +189,9 @@ trainer:
precision: 32
max_steps: ???
accumulate_grad_batches: 1
enable_checkpointing: False # Provided by exp_manager
logger: false # Provided by exp_manager
enable_checkpointing: false # Provided by exp_manager
logger: false # Provided by exp_manager
log_every_n_steps: 100
check_val_every_n_epoch: 1
limit_train_batches: 1_000
val_check_interval: 1_000
num_sanity_val_steps: 0
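The inheritance rule stated in the `validation_ds` comment — each `datasets` entry inherits the parent settings but can individually override them — amounts to a shallow dict merge. The sketch below uses hypothetical names and is not the actual NeMo resolution code:

```python
def resolve_entry_cfg(parent_cfg: dict, entry: dict) -> dict:
    # Start from the shared validation_ds settings, then let the per-dataset
    # entry override any of them; 'name' labels the dataloader and is not
    # itself a dataloader setting.
    overrides = {k: v for k, v in entry.items() if k != "name"}
    return {**parent_cfg, **overrides}


parent = {"num_workers": 2, "shuffle": False, "batch_duration": 100}
entry = {"name": "val_set_0", "batch_duration": 50}
print(resolve_entry_cfg(parent, entry))
# {'num_workers': 2, 'shuffle': False, 'batch_duration': 50}
```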
97 changes: 50 additions & 47 deletions examples/tts/conf/magpietts/magpietts_lhotse_moe.yaml
@@ -74,57 +74,60 @@ model:
train_ds:
use_lhotse: ${model.use_lhotse}
volume_norm: true

dataset:
min_duration: 0.2
min_context_speaker_similarity: 0.6
max_cer: 0.03
batch_duration : ??? # in seconds. Adjust based on your GPU memory.
quadratic_duration: ${quadratic_duration}
use_bucketing: true
num_buckets: 20
bucket_buffer_size: 20_000
shuffle_buffer_size: 20_000
num_cuts_for_bins_estimate: 20_000
shard_seed: "trng"
drop_last: true
shuffle: true
num_workers: 6
pin_memory: true

input_cfg:
- type: lhotse_shar
shar_path: ???
weight: 1.0
tags:
tokenizer_names: ["english_phoneme"]
min_duration: 0.2
min_context_speaker_similarity: 0.6
max_cer: 0.03
batch_duration: ??? # in seconds. Adjust based on your GPU memory.
quadratic_duration: ${quadratic_duration}
use_bucketing: true
num_buckets: 20
bucket_buffer_size: 20_000
shuffle_buffer_size: 20_000
num_cuts_for_bins_estimate: 20_000
shard_seed: "trng"
drop_last: true
shuffle: true
num_workers: 6
pin_memory: true

input_cfg:
- type: lhotse_shar
shar_path: ???
weight: 1.0
tags:
tokenizer_names: ["english_phoneme"]

validation_ds:
# the entries under 'datasets' are a list of separate dataloaders.
# The structure is:
# - name: '<dataset-name>'
# <dataloader-dict-config>
# They inherit all settings from validation_ds, but can individually override them.
use_lhotse: ${model.use_lhotse}
volume_norm: true

dataset:
min_duration: 0.2
min_context_speaker_similarity: 0.6
max_cer: 0.03
batch_duration: ??? # recommend to use smaller batch_duration for validation dataset than training dataset.
quadratic_duration: ${quadratic_duration}
use_bucketing: false
force_finite: true
force_map_dataset: true
seed: 42
shard_seed: "randomized"
drop_last: false
shuffle: false
num_workers: 2
pin_memory: true

input_cfg:
- type: lhotse_shar
shar_path: ???
weight: 1.0
tags:
tokenizer_names: ["english_phoneme"]
min_duration: 0.2
min_context_speaker_similarity: 0.6
max_cer: 0.03
batch_duration: ??? # a smaller batch_duration is recommended for validation than for training.
quadratic_duration: ${quadratic_duration}
use_bucketing: false
force_finite: true
force_map_dataset: true
seed: 42
shard_seed: "randomized"
drop_last: false
shuffle: false
num_workers: 2
pin_memory: true

datasets:
- name: "val_set_0" # rename to your dataset name, add more as needed
input_cfg:
- type: lhotse_shar
shar_path: ???
weight: 1.0
tags:
tokenizer_names: ["english_phoneme"]

encoder:
n_layers: 6
2 changes: 1 addition & 1 deletion examples/tts/conf/magpietts/magpietts_po_inference.yaml
@@ -88,7 +88,7 @@ model:
# pretrained_model: "google/byt5-small"

test_ds:
dataset:
datasets:
_target_: nemo.collections.tts.data.text_to_speech_dataset.MagpieTTSDataset
dataset_meta: ${test_ds_meta}
min_duration: 0.2
6 changes: 3 additions & 3 deletions nemo/collections/tts/losses/moe_loss.py
@@ -16,7 +16,7 @@
import torch.nn.functional as F

from nemo.core.classes import Loss, typecheck
from nemo.core.neural_types.elements import LossType, ProbsType
from nemo.core.neural_types.elements import LogitsType, LossType, ProbsType
from nemo.core.neural_types.neural_type import NeuralType


@@ -122,7 +122,7 @@ def __init__(self, loss_scale: float = 0.001):
@property
def input_types(self):
return {
"router_logits": NeuralType(('B', 'T', 'D'), ProbsType()), # D = num_experts
"router_logits": NeuralType(('B', 'T', 'D'), LogitsType()), # D = num_experts
"x_mask": NeuralType(('B', 'T'), ProbsType(), optional=True),
}

@@ -194,7 +194,7 @@ def __init__(
@property
def input_types(self):
return {
"router_logits": NeuralType(('B', 'T', 'D'), ProbsType()),
"router_logits": NeuralType(('B', 'T', 'D'), LogitsType()),
"router_probs": NeuralType(('B', 'T', 'D'), ProbsType()),
"x_mask": NeuralType(('B', 'T'), ProbsType(), optional=True),
}
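The losses touched in this diff consume router logits. A router load-balancing penalty of the general kind used in MoE training can be sketched in pure Python; this is an illustrative stand-in, not the implementation in `moe_loss.py`:

```python
import math
from typing import List


def softmax(logits: List[float]) -> List[float]:
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    s = sum(exps)
    return [e / s for e in exps]


def load_balance_penalty(router_logits: List[List[float]], loss_scale: float = 0.001) -> float:
    # Convert per-token logits to routing probabilities, average the
    # probability mass each expert receives, and penalize deviation from
    # uniform usage so no expert is starved or overloaded.
    probs = [softmax(row) for row in router_logits]
    num_tokens, num_experts = len(probs), len(probs[0])
    usage = [sum(row[e] for row in probs) / num_tokens for e in range(num_experts)]
    uniform = 1.0 / num_experts
    return loss_scale * sum((u - uniform) ** 2 for u in usage)


print(load_balance_penalty([[0.0, 0.0], [0.0, 0.0]]))  # 0.0 (perfectly balanced routing)
```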