diff --git a/clip_benchmark/metrics/zeroshot_classification.py b/clip_benchmark/metrics/zeroshot_classification.py index 9c70e1e..4b974df 100644 --- a/clip_benchmark/metrics/zeroshot_classification.py +++ b/clip_benchmark/metrics/zeroshot_classification.py @@ -34,8 +34,8 @@ def zero_shot_classifier(model, tokenizer, classnames, templates, device, amp=Tr Returns ------- - torch.Tensor of shape (N,C) where N is the number - of templates, and C is the number of classes. + torch.Tensor of shape (D,C) where D is the projection + dimension, and C is the number of classes. """ with torch.no_grad(), torch.autocast(device, enabled=amp): zeroshot_weights = []