Skip to content

Commit aa6f482

Browse files
authored
Merge branch 'k2-fsa:master' into dev/speechllm
2 parents 7c30dd5 + ffb7d05 commit aa6f482

File tree

1 file changed

+4
-4
lines changed

1 file changed

+4
-4
lines changed

egs/librispeech/ASR/zipformer/model.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -210,10 +210,10 @@ def forward_cr_ctc(
210210
)
211211

212212
# Compute consistency regularization loss
213-
exchanged_targets = ctc_output.detach().chunk(2, dim=0)
214-
exchanged_targets = torch.cat(
215-
[exchanged_targets[1], exchanged_targets[0]], dim=0
216-
) # exchange: [x1, x2] -> [x2, x1]
213+
batch_size = ctc_output.shape[0]
214+
assert batch_size % 2 == 0, batch_size
215+
# exchange: [x1, x2] -> [x2, x1]
216+
exchanged_targets = torch.roll(ctc_output.detach(), batch_size // 2, dims=0)
217217
cr_loss = nn.functional.kl_div(
218218
input=ctc_output,
219219
target=exchanged_targets,

0 commit comments

Comments
 (0)