We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
batchmean
1 parent 7e20753 commit 0450f05Copy full SHA for 0450f05
trl/trainer/gkd_trainer.py
@@ -270,7 +270,7 @@ def generalized_jsd_loss(
270
271
# Apply reduction
272
if reduction == "batchmean":
273
- return jsd.sum() / mask.sum() if labels is not None else jsd.sum() / (jsd.size(0) * jsd.size(1))
+ return jsd.sum() / mask.sum() if labels is not None else jsd.sum() / jsd.size(0)
274
elif reduction == "sum":
275
return jsd.sum()
276
elif reduction == "mean":
0 commit comments