We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
2 parents 7c30dd5 + ffb7d05 commit aa6f482Copy full SHA for aa6f482
egs/librispeech/ASR/zipformer/model.py
@@ -210,10 +210,10 @@ def forward_cr_ctc(
210
)
211
212
# Compute consistency regularization loss
213
- exchanged_targets = ctc_output.detach().chunk(2, dim=0)
214
- exchanged_targets = torch.cat(
215
- [exchanged_targets[1], exchanged_targets[0]], dim=0
216
- ) # exchange: [x1, x2] -> [x2, x1]
+ batch_size = ctc_output.shape[0]
+ assert batch_size % 2 == 0, batch_size
+ # exchange: [x1, x2] -> [x2, x1]
+ exchanged_targets = torch.roll(ctc_output.detach(), batch_size // 2, dims=0)
217
cr_loss = nn.functional.kl_div(
218
input=ctc_output,
219
target=exchanged_targets,
0 commit comments