Skip to content

Commit 4ebee43

Browse files
authored
Normalize device to CPU when evaluating (#363)
1 parent 1ecd91c commit 4ebee43

File tree

2 files changed

+28
-0
lines changed

2 files changed

+28
-0
lines changed

src/setfit/trainer.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33

44
import evaluate
55
import numpy as np
6+
import torch
67
from datasets import Dataset, DatasetDict
78
from sentence_transformers import InputExample, losses
89
from sentence_transformers.datasets import SentenceLabelDataset
@@ -438,6 +439,8 @@ def evaluate(self, dataset: Optional[Dataset] = None) -> Dict[str, float]:
438439

439440
logger.info("***** Running evaluation *****")
440441
y_pred = self.model.predict(x_test)
442+
if isinstance(y_pred, torch.Tensor):
443+
y_pred = y_pred.cpu()
441444

442445
if isinstance(self.metric, str):
443446
metric_config = "multilabel" if self.model.multi_target_strategy is not None else None

tests/test_trainer.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55

66
import evaluate
77
import pytest
8+
import torch
89
from datasets import Dataset, load_dataset
910
from sentence_transformers import losses
1011
from transformers.testing_utils import require_optuna
@@ -497,3 +498,27 @@ def test_trainer_evaluate_multilabel_f1():
497498
trainer.train()
498499
metrics = trainer.evaluate()
499500
assert metrics == {"f1": 1.0}
501+
502+
503+
def test_trainer_evaluate_on_cpu() -> None:
504+
# This test used to fail if CUDA was available
505+
dataset = Dataset.from_dict(
506+
{"text": ["positive sentence", "negative sentence"], "label": ["positive", "negative"]}
507+
)
508+
model = SetFitModel.from_pretrained(
509+
"sentence-transformers/paraphrase-albert-small-v2", use_differentiable_head=True
510+
)
511+
512+
def compute_metric(y_pred, y_test) -> None:
513+
assert y_pred.device == torch.device("cpu")
514+
return 1.0
515+
516+
trainer = SetFitTrainer(
517+
model=model,
518+
train_dataset=dataset,
519+
eval_dataset=dataset,
520+
metric=compute_metric,
521+
num_iterations=5,
522+
)
523+
trainer.train()
524+
trainer.evaluate()

0 commit comments

Comments
 (0)