I am training my own text and motion embedding models for evaluation. I noticed that in the `TextMotionMatchTrainer` class there is a potential bug in the shift applied to create negative examples for the contrastive loss:
```python
def backward(self):
    batch_size = self.text_embedding.shape[0]
    '''Positive pairs'''
    pos_labels = torch.zeros(batch_size).to(self.text_embedding.device)
    self.loss_pos = self.contrastive_loss(self.text_embedding, self.motion_embedding, pos_labels)
    '''Negative pairs, shifting index'''
    neg_labels = torch.ones(batch_size).to(self.text_embedding.device)
    shift = np.random.randint(0, batch_size - 1)  # BUG: shift can be 0
    new_idx = np.arange(shift, batch_size + shift) % batch_size
    self.mis_motion_embedding = self.motion_embedding.clone()[new_idx]
    self.loss_neg = self.contrastive_loss(self.text_embedding, self.mis_motion_embedding, neg_labels)
    self.loss = self.loss_pos + self.loss_neg
    loss_logs = OrderedDict({})
    loss_logs['loss'] = self.loss.item()
    loss_logs['loss_pos'] = self.loss_pos.item()
    loss_logs['loss_neg'] = self.loss_neg.item()
    return loss_logs
```

If the sampled shift is 0, each "negative" motion is compared with its own text, so the "negative" pairs are actually positive pairs labeled as negatives. This is especially problematic when training with small batch sizes, such as the batch size of 8 used in the README. The correction is
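To make the failure mode concrete, here is a small standalone sketch (plain NumPy, not code from the repo) showing both how often the original sampling produces a zero shift and that a zero shift yields the identity permutation, i.e. every "negative" index points back to the same example:

```python
import numpy as np

batch_size = 8  # the README's default batch size

# np.random.randint(0, batch_size - 1) samples uniformly from
# {0, 1, ..., batch_size - 2}, so shift == 0 is a valid draw.
draws = [np.random.randint(0, batch_size - 1) for _ in range(10000)]
p_zero = sum(d == 0 for d in draws) / len(draws)
# Expected frequency of a degenerate batch: 1 / (batch_size - 1) ~= 0.14

# When shift == 0, the cyclic shift is the identity permutation,
# so mis_motion_embedding would equal motion_embedding exactly.
shift = 0
new_idx = np.arange(shift, batch_size + shift) % batch_size
print(np.array_equal(new_idx, np.arange(batch_size)))  # True
```

So with batch size 8, roughly one in seven backward passes trains the negative term of the contrastive loss on self-pairs.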
```python
shift = np.random.randint(1, batch_size - 1)  # shift is now always nonzero
```

After making this change I see improved training curves. Below, the grey curve is with the bug fix and the purple curve is with the original code. Batch size is 8.
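As a quick sanity check (a sketch, not code from the repo), the corrected call can only draw nonzero shifts, and under any nonzero cyclic shift no index maps to itself, so every negative pair really mismatches text and motion. One side note: since `np.random.randint`'s upper bound is exclusive, `randint(1, batch_size - 1)` samples from `{1, ..., batch_size - 2}`, leaving the equally valid shift `batch_size - 1` unused; that is harmless for correctness.

```python
import numpy as np

batch_size = 8
for _ in range(1000):
    shift = np.random.randint(1, batch_size - 1)  # draws from {1, ..., 6}
    new_idx = np.arange(shift, batch_size + shift) % batch_size
    # A nonzero cyclic shift of arange(batch_size) has no fixed points,
    # so no text is paired with its own motion.
    assert not np.any(new_idx == np.arange(batch_size))
print("no self-pairs in 1000 trials")
```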