Skip to content

Commit 1ecd91c

Browse files
groftetomaarsen
andauthored
Allow other datasets in trainer.evaluate() (#402)
* evaluate with other dataset - update test_trainer.py * Evaluate on a different dataset - update trainer.py * Add missing Dataset import + make style --------- Co-authored-by: Tom Aarsen <[email protected]>
1 parent 41ad3a2 commit 1ecd91c

File tree

2 files changed

+21
-6
lines changed

2 files changed

+21
-6
lines changed

src/setfit/trainer.py

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33

44
import evaluate
55
import numpy as np
6-
from datasets import DatasetDict
6+
from datasets import Dataset, DatasetDict
77
from sentence_transformers import InputExample, losses
88
from sentence_transformers.datasets import SentenceLabelDataset
99
from sentence_transformers.losses.BatchHardTripletLoss import BatchHardTripletLossDistanceFunction
@@ -19,7 +19,6 @@
1919

2020
if TYPE_CHECKING:
2121
import optuna
22-
from datasets import Dataset
2322

2423
from .modeling import SetFitModel
2524

@@ -415,20 +414,24 @@ def train(
415414
show_progress_bar=True,
416415
)
417416

418-
def evaluate(self) -> Dict[str, float]:
417+
def evaluate(self, dataset: Optional[Dataset] = None) -> Dict[str, float]:
419418
"""
420419
Computes the metrics for a given classifier.
421420
421+
Args:
422+
dataset (`Dataset`, *optional*):
423+
The dataset to compute the metrics on. If not provided, will use the evaluation dataset passed in the eval_dataset argument at `SetFitTrainer` initialization.
424+
422425
Returns:
423426
`Dict[str, float]`: The evaluation metrics.
424427
"""
425428

426-
self._validate_column_mapping(self.eval_dataset)
427-
eval_dataset = self.eval_dataset
429+
eval_dataset = dataset or self.eval_dataset
430+
self._validate_column_mapping(eval_dataset)
428431

429432
if self.column_mapping is not None:
430433
logger.info("Applying column mapping to evaluation dataset")
431-
eval_dataset = self._apply_column_mapping(self.eval_dataset, self.column_mapping)
434+
eval_dataset = self._apply_column_mapping(eval_dataset, self.column_mapping)
432435

433436
x_test = eval_dataset["text"]
434437
y_test = eval_dataset["label"]

tests/test_trainer.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,18 @@ def test_trainer_works_with_default_columns(self):
6868
metrics = trainer.evaluate()
6969
self.assertEqual(metrics["accuracy"], 1.0)
7070

71+
def test_trainer_works_with_alternate_dataset_for_evaluate(self):
72+
dataset = Dataset.from_dict({"text": ["a", "b", "c"], "label": [0, 1, 2], "extra_column": ["d", "e", "f"]})
73+
alternate_dataset = Dataset.from_dict(
74+
{"text": ["x", "y", "z"], "label": [0, 1, 2], "extra_column": ["d", "e", "f"]}
75+
)
76+
trainer = SetFitTrainer(
77+
model=self.model, train_dataset=dataset, eval_dataset=dataset, num_iterations=self.num_iterations
78+
)
79+
trainer.train()
80+
metrics = trainer.evaluate(alternate_dataset)
81+
self.assertNotEqual(metrics["accuracy"], 1.0)
82+
7183
def test_trainer_raises_error_with_missing_label(self):
7284
dataset = Dataset.from_dict({"text": ["a", "b", "c"], "extra_column": ["d", "e", "f"]})
7385
trainer = SetFitTrainer(

0 commit comments

Comments
 (0)