Skip to content

Commit 0450f05

Browse files
authored
[GKD] Fix batchmean reduce op in GKDTrainer's loss (#4105)
1 parent 7e20753 commit 0450f05

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

trl/trainer/gkd_trainer.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -270,7 +270,7 @@ def generalized_jsd_loss(
270270

271271
# Apply reduction
272272
if reduction == "batchmean":
273-
return jsd.sum() / mask.sum() if labels is not None else jsd.sum() / (jsd.size(0) * jsd.size(1))
273+
return jsd.sum() / mask.sum() if labels is not None else jsd.sum() / jsd.size(0)
274274
elif reduction == "sum":
275275
return jsd.sum()
276276
elif reduction == "mean":

0 commit comments

Comments
 (0)