|
28 | 28 | from beartype.door import is_bearable |
29 | 29 |
|
30 | 30 | from naturalspeech2_pytorch.attend import Attend |
31 | | -from naturalspeech2_pytorch.aligner import Aligner, ForwardSumLoss |
| 31 | +from naturalspeech2_pytorch.aligner import Aligner, ForwardSumLoss, BinLoss |
32 | 32 | from naturalspeech2_pytorch.utils.tokenizer import Tokenizer, ESpeak |
33 | 33 | from naturalspeech2_pytorch.utils.utils import average_over_durations, create_mask |
34 | 34 | from naturalspeech2_pytorch.version import __version__ |
@@ -1192,7 +1192,8 @@ def __init__( |
1192 | 1192 | scale = 1., # this will be set to < 1. for better convergence when training on higher resolution images |
1193 | 1193 | duration_loss_weight = 1., |
1194 | 1194 | pitch_loss_weight = 1., |
1195 | | - aligner_loss_weight = 1. |
| 1195 | + aligner_loss_weight = 1., |
| 1196 | + aligner_bin_loss_weight = 0. |
1196 | 1197 | ): |
1197 | 1198 | super().__init__() |
1198 | 1199 |
|
@@ -1233,7 +1234,10 @@ def __init__( |
1233 | 1234 | self.duration_pitch = DurationPitchPredictor(dim=duration_pitch_dim) |
1234 | 1235 | self.aligner = Aligner(dim_in=aligner_dim_in, dim_hidden=aligner_dim_hidden, attn_channels=aligner_attn_channels) |
1235 | 1236 | self.pitch_emb = nn.Embedding(pitch_emb_dim, pitch_emb_pp_hidden_dim) |
| 1237 | + |
1236 | 1238 | self.aligner_loss = ForwardSumLoss() |
| 1239 | + self.bin_loss = BinLoss() |
| 1240 | + self.aligner_bin_loss_weight = aligner_bin_loss_weight |
1237 | 1241 |
|
1238 | 1242 | # rest of ddpm |
1239 | 1243 |
|
@@ -1584,7 +1588,12 @@ def forward( |
1584 | 1588 |
|
1585 | 1589 | pitch = rearrange(pitch, 'b 1 d -> b d') |
1586 | 1590 | pitch_loss = F.l1_loss(pitch, pitch_pred) |
1587 | | - align_loss = self.aligner_loss(aln_log , text_lens, mel_lens) |
| 1591 | + |
| 1592 | + align_loss = self.aligner_loss(aln_log, text_lens, mel_lens) |
| 1593 | + |
| 1594 | + if self.aligner_bin_loss_weight > 0.: |
| 1595 | + align_bin_loss = self.bin_loss(aln_mask, aln_log, text_lens) * self.aligner_bin_loss_weight |
| 1596 | + align_loss = align_loss + align_bin_loss |
1588 | 1597 |
|
1589 | 1598 | # weigh the losses |
1590 | 1599 |
|
|
0 commit comments