Potential Bug in TextMotionMatchTrainer #39

@mpiseno

I am training my own text and motion embedding models for evaluation. I noticed a potential bug in the TextMotionMatchTrainer class, in the shift applied to create negative examples for the contrastive loss.

def backward(self):
    batch_size = self.text_embedding.shape[0]
    '''Positive pairs'''
    pos_labels = torch.zeros(batch_size).to(self.text_embedding.device)
    self.loss_pos = self.contrastive_loss(self.text_embedding, self.motion_embedding, pos_labels)

    '''Negative Pairs, shifting index'''
    neg_labels = torch.ones(batch_size).to(self.text_embedding.device)
    shift = np.random.randint(0, batch_size-1) # BUG
    new_idx = np.arange(shift, batch_size + shift) % batch_size
    self.mis_motion_embedding = self.motion_embedding.clone()[new_idx]
    self.loss_neg = self.contrastive_loss(self.text_embedding, self.mis_motion_embedding, neg_labels)
    self.loss = self.loss_pos + self.loss_neg

    loss_logs = OrderedDict({})
    loss_logs['loss'] = self.loss.item()
    loss_logs['loss_pos'] = self.loss_pos.item()
    loss_logs['loss_neg'] = self.loss_neg.item()
    return loss_logs

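If the shift is 0, each "negative" example is compared with itself, so matching text-motion pairs are labeled as negatives for that step. This is especially problematic when training with small batch sizes, like in the README (batch size 8): np.random.randint(0, 7) returns one of 7 values, so a shift of 0 is drawn about one step in seven. The correction is: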

shift = np.random.randint(1, batch_size-1)
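The effect of the shift on the indices can be checked in isolation with a small NumPy-only sketch. The helper names `shifted_indices` and `sample_negative_indices` are hypothetical and not part of the repository. One caveat worth noting: `np.random.randint` excludes its upper bound, so `randint(1, batch_size-1)` draws from 1 to batch_size-2 and raises an error when the batch size is 2. The sketch below assumes an upper bound of `batch_size` instead, which may or may not be what the maintainers want.

```python
import numpy as np


def shifted_indices(batch_size, shift):
    """Index permutation that pairs text i with motion (i + shift) % batch_size."""
    return np.arange(shift, batch_size + shift) % batch_size


def sample_negative_indices(batch_size, rng):
    """Draw a non-zero shift so that no sample is paired with itself.

    Assumption (not from the repository): the exclusive upper bound is
    batch_size, so the shift ranges over 1 .. batch_size - 1.
    """
    if batch_size < 2:
        raise ValueError("need at least 2 samples to form negative pairs")
    shift = int(rng.integers(1, batch_size))  # upper bound is exclusive
    return shifted_indices(batch_size, shift)


batch_size = 8
identity = np.arange(batch_size)

# A shift of 0 maps every index to itself, so the "negatives" are positives.
assert np.array_equal(shifted_indices(batch_size, 0), identity)

# Any shift in 1 .. batch_size - 1 moves every index to a different position.
rng = np.random.default_rng(0)
new_idx = sample_negative_indices(batch_size, rng)
assert not np.any(new_idx == identity)
```

In the trainer, the same idea would amount to changing only the bounds passed to the random draw; the rest of `backward` stays as it is.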

After doing this, I see improved training curves. Below, the grey curve is with the bug fix and the purple curve is with the original code. Batch size is 8.

[Screenshot, 2024-01-18: training curves with the bug fix (grey) and with the original code (purple)]
