Commit 5d9b873

[v2] Select required columns to run (#3190)
* select required columns to run
* fix tasks

1 parent 713d661 · commit 5d9b873

11 files changed (+140, −90 lines)
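The recurring pattern in this commit is to call `select_columns` on the Hugging Face dataset so that only the input and label columns reach the evaluators and dataloaders. A minimal toy sketch of that behaviour (column names here are hypothetical, not taken from any task):

from datasets import Dataset

ds = Dataset.from_dict(
    {
        "text": ["a cat", "a dog"],
        "label": [0, 1],
        "source": ["unused", "unused"],  # extra column the evaluator never needs
    }
)

# Keep only the columns the evaluator consumes; everything else is dropped before
# the data is handed to the encoder / dataloader.
ds = ds.select_columns(["text", "label"])
print(ds.column_names)  # only 'text' and 'label' remain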

mteb/_evaluators/clustering_evaluator.py

Lines changed: 7 additions & 18 deletions
@@ -6,10 +6,9 @@
 from datasets import Dataset
 from scipy.optimize import linear_sum_assignment
 from sklearn import cluster, metrics
-from torch.utils.data import DataLoader

 from mteb.abstasks.task_metadata import TaskMetadata
-from mteb.create_dataloaders import create_image_dataloader
+from mteb.create_dataloaders import create_dataloader
 from mteb.models import Encoder

 from .evaluator import Evaluator
@@ -38,29 +37,19 @@ def __init__(
         self.hf_split = hf_split
         self.hf_subset = hf_subset

-    def create_dataloader(self, batch_size: int) -> DataLoader:
-        if self.task_metadata.modalities == ["image"]:
-            return create_image_dataloader(
-                self.dataset,
-                image_column_name=self.input_column_name,
-                batch_size=batch_size,
-            )
-        elif self.task_metadata.modalities == ["text"]:
-            return DataLoader(self.dataset)
-        else:
-            raise ValueError(
-                f"Unsupported modality {self.task_metadata.modalities}. "
-                "Currently only 'image' modality is supported."
-            )
-
     def __call__(
         self,
         model: Encoder,
         *,
         encode_kwargs: dict[str, Any],
         v_measure_only: bool = False,
     ):
-        data_loader = self.create_dataloader(batch_size=encode_kwargs["batch_size"])
+        data_loader = create_dataloader(
+            self.dataset,
+            self.task_metadata,
+            input_column=self.input_column_name,
+            batch_size=encode_kwargs["batch_size"],
+        )

         embeddings = model.encode(
             data_loader,
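The evaluator no longer builds its own modality-specific dataloader; it hands the dataset, task metadata, input column and batch size to the shared `create_dataloader`. The sketch below is not the real `mteb.create_dataloaders.create_dataloader`, only an assumption-laden illustration of the kind of dispatch such a helper can centralize (`image_collate` is hypothetical):

from datasets import Dataset
from torch.utils.data import DataLoader

from mteb.abstasks.task_metadata import TaskMetadata


def image_collate(batch: list[dict]) -> list[dict]:
    # Placeholder: a real image pipeline would decode / transform images here.
    return batch


def create_dataloader_sketch(
    dataset: Dataset,
    task_metadata: TaskMetadata,
    *,
    input_column: str,
    batch_size: int,
) -> DataLoader:
    # Trim to the single column the encoder reads, then dispatch on modality.
    dataset = dataset.select_columns([input_column])
    if task_metadata.modalities == ["image"]:
        return DataLoader(dataset, batch_size=batch_size, collate_fn=image_collate)
    return DataLoader(dataset, batch_size=batch_size)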

mteb/abstasks/AbsTaskAnyClassification.py

Lines changed: 3 additions & 0 deletions
@@ -109,6 +109,9 @@ def evaluate(
                 ds = self.dataset
             else:
                 ds = self.dataset[hf_subset]
+
+            if isinstance(ds, (Dataset, DatasetDict)):
+                ds = ds.select_columns([self.label_column_name, self.input_column_name])
             scores[hf_subset] = self._evaluate_subset(
                 model,
                 ds,
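The `isinstance` guard matters because `select_columns` is available on both `Dataset` and `DatasetDict`; on a `DatasetDict` it is applied to every split. A toy illustration (not from the repository):

from datasets import Dataset, DatasetDict

dd = DatasetDict(
    {
        "train": Dataset.from_dict({"text": ["a"], "label": [0], "extra": ["x"]}),
        "test": Dataset.from_dict({"text": ["b"], "label": [1], "extra": ["y"]}),
    }
)

# Applies to every split at once; the 'extra' column is gone from both.
dd = dd.select_columns(["text", "label"])
print({split: d.column_names for split, d in dd.items()})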

mteb/abstasks/AbsTaskAnyClustering.py

Lines changed: 5 additions & 2 deletions
@@ -76,8 +76,8 @@ def _evaluate_subset(
        ):
            v_measures = []
            for cluster_set in tqdm.tqdm(dataset, desc="Clustering"):
-                clustering_dataset = Dataset.from_dict(cluster_set).rename_column(
-                    original_column_name="sentences", new_column_name="text"
+                clustering_dataset = Dataset.from_dict(cluster_set).select_columns(
+                    [self.input_column_name, self.label_column_name]
                )
                evaluator = self.evaluator(
                    clustering_dataset,
@@ -103,6 +103,9 @@ def _evaluate_subset(
            self._add_main_score(scores)
            return scores

+        dataset = dataset.select_columns(
+            [self.input_column_name, self.label_column_name]
+        )
         evaluator = self.evaluator(
             dataset,
             input_column_name=self.input_column_name,
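Each row of a legacy clustering split is a dict of parallel lists (one cluster set). Instead of renaming `sentences` to `text`, the evaluator now keeps the configurable input and label columns and drops anything else in the row. A toy sketch with a hypothetical extra field:

from datasets import Dataset

cluster_set = {
    "sentences": ["doc one", "doc two", "doc three"],
    "labels": [0, 1, 0],
    "source": ["wiki", "wiki", "news"],  # hypothetical extra field, never encoded
}

clustering_dataset = Dataset.from_dict(cluster_set).select_columns(
    ["sentences", "labels"]
)
print(clustering_dataset.column_names)  # 'source' has been dropped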

mteb/abstasks/AbsTaskAnySTS.py

Lines changed: 2 additions & 0 deletions
@@ -79,6 +79,8 @@ def _evaluate_subset(
         **kwargs: Any,
     ) -> ScoresDict:
         normalized_scores = list(map(self.normalize, data_split["score"]))
+        data_split = data_split.select_columns(list(self.column_names))
+
         evaluator = AnySTSEvaluator(
             data_split,
             self.column_names,
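Note that the scores are normalized from the `score` column before the split is trimmed; assuming `self.column_names` holds only the sentence columns, `select_columns` would otherwise drop the scores. A toy sketch with stand-in column names:

from datasets import Dataset

data_split = Dataset.from_dict(
    {
        "sentence1": ["a cat", "a dog"],
        "sentence2": ["a kitten", "a puppy"],
        "score": [4.5, 3.0],
    }
)

normalized_scores = [s / 5.0 for s in data_split["score"]]  # read before trimming
data_split = data_split.select_columns(["sentence1", "sentence2"])  # 'score' is dropped here
print(normalized_scores, data_split.column_names)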

mteb/abstasks/AbsTaskAnyZeroShotClassification.py

Lines changed: 3 additions & 6 deletions
@@ -60,12 +60,6 @@ class AbsTaskAnyZeroShotClassification(AbsTask):
     input_column_name: str = "image"
     label_column_name: str = "label"

-    def __init__(self, **kwargs):
-        super().__init__(**kwargs)
-
-    def _add_main_score(self, scores) -> None:
-        scores["main_score"] = scores[self.metadata.main_score]
-
     def _calculate_descriptive_statistics_from_split(
         self, split: str, hf_subset: str | None = None, compute_overall: bool = False
     ) -> ZeroShotClassificationDescriptiveStatistics:
@@ -114,6 +108,9 @@ def _evaluate_subset(
         **kwargs,
     ) -> ScoresDict:
         candidate_labels = self.get_candidate_labels()
+        dataset = dataset.select_columns(
+            [self.input_column_name, self.label_column_name]
+        )
         evaluator = ZeroShotClassificationEvaluator(
             dataset,
             self.input_column_name,
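For context, the class leaves the column names and candidate labels to subclasses. A hypothetical, partial subclass might look like the sketch below; a real task would also define its `metadata` and data loading:

from mteb.abstasks.AbsTaskAnyZeroShotClassification import (
    AbsTaskAnyZeroShotClassification,
)


class AnimalZeroShotClassification(AbsTaskAnyZeroShotClassification):
    # Hypothetical task: points the column names at this dataset's fields and
    # supplies the prompt-style candidate labels the evaluator compares against.
    input_column_name: str = "image"
    label_column_name: str = "label"

    def get_candidate_labels(self) -> list[str]:
        return ["a photo of a cat", "a photo of a dog", "a photo of a bird"]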

mteb/abstasks/AbsTaskClusteringFast.py

Lines changed: 41 additions & 23 deletions
@@ -10,12 +10,12 @@
 from datasets import Dataset, DatasetDict
 from sklearn.cluster import MiniBatchKMeans
 from sklearn.metrics.cluster import v_measure_score
-from torch.utils.data import DataLoader

 from mteb.models import Encoder
 from mteb.types import HFSubset
 from mteb.types.statistics import DescriptiveStatistics, LabelStatistics, TextStatistics

+from ..create_dataloaders import create_dataloader
 from ._statistics_calculation import (
     calculate_label_statistics,
     calculate_text_statistics,
@@ -126,6 +126,8 @@ class AbsTaskClusteringFast(AbsTask):
     k_mean_batch_size: int = 512
     max_depth = None
     abstask_prompt = "Identify categories in user passages."
+    input_column_name: str = "sentences"
+    label_column_name: str = "labels"

     def _evaluate_subset(
         self,
@@ -164,19 +166,24 @@ def _evaluate_subset(
         )
         downsampled_dataset = dataset.select(example_indices)  # type: ignore

-        downsampled_dataset = downsampled_dataset.rename_column(
-            original_column_name="sentences", new_column_name="text"
+        downsampled_dataset = downsampled_dataset.select_columns(
+            [self.input_column_name, self.label_column_name]
         )
         embeddings = model.encode(
-            DataLoader(downsampled_dataset),
+            create_dataloader(
+                downsampled_dataset,
+                self.metadata,
+                input_column=self.input_column_name,
+                batch_size=encode_kwargs["batch_size"],
+            ),
             task_metadata=self.metadata,
             hf_subset=hf_subset,
             hf_split=hf_split,
             **encode_kwargs,
         )

         labels = []
-        for label in downsampled_dataset["labels"]:
+        for label in downsampled_dataset[self.label_column_name]:
             if not isinstance(label, list):
                 label = [label]
             labels.append(label)
@@ -194,29 +201,27 @@ def _evaluate_subset(

         mean_v_measure = np.mean(v_measures)
         v_std = np.std(v_measures)
-        scores = {
+        return {
             "v_measures": all_v_scores,
             "v_measure": float(mean_v_measure),
             "v_measure_std": v_std,
         }
-        self._add_main_score(scores)
-        return scores

     def _calculate_descriptive_statistics_from_split(
         self, split: str, hf_subset: str | None = None, compute_overall: bool = False
     ) -> ClusteringFastDescriptiveStatistics:
         if hf_subset:
-            sentences = self.dataset[hf_subset][split]["sentences"]
-            labels = self.dataset[hf_subset][split]["labels"]
+            sentences = self.dataset[hf_subset][split][self.input_column_name]
+            labels = self.dataset[hf_subset][split][self.label_column_name]
         elif compute_overall:
             sentences = []
             labels = []
             for hf_subset in self.metadata.eval_langs:
-                sentences.extend(self.dataset[hf_subset][split]["sentences"])
-                labels.extend(self.dataset[hf_subset][split]["labels"])
+                sentences.extend(self.dataset[hf_subset][split][self.input_column_name])
+                labels.extend(self.dataset[hf_subset][split][self.label_column_name])
         else:
-            sentences = self.dataset[split]["sentences"]
-            labels = self.dataset[split]["labels"]
+            sentences = self.dataset[split][self.input_column_name]
+            labels = self.dataset[split][self.label_column_name]

         return ClusteringFastDescriptiveStatistics(
             num_samples=len(sentences),
@@ -225,11 +230,17 @@ def _calculate_descriptive_statistics_from_split(
         )

     def _push_dataset_to_hub(self, repo_name: str) -> None:
-        self._upload_dataset_to_hub(repo_name, ["sentences", "labels"])
+        self._upload_dataset_to_hub(
+            repo_name, [self.input_column_name, self.label_column_name]
+        )


 def convert_to_fast(
-    dataset: DatasetDict, seed: int, max_size: int = 100_000
+    dataset: DatasetDict,
+    input_column_name: str,
+    label_column_name: str,
+    seed: int,
+    max_size: int = 100_000,
 ) -> DatasetDict:
     """Converts a clustering dataset to a fast version. This concats the cluster into two columns, sentences and labels.
     It additionally downsamples the dataset to max_size.
@@ -242,10 +253,12 @@ def convert_to_fast(
         labels = []
         sentences = []
         n_clusters = len(dataset[split])
-        all_labels_set = set(itertools.chain.from_iterable(dataset[split]["labels"]))
+        all_labels_set = set(
+            itertools.chain.from_iterable(dataset[split][label_column_name])
+        )
         for i in range(n_clusters):
-            lab = dataset[split]["labels"][i]
-            sents = dataset[split]["sentences"][i]
+            lab = dataset[split][label_column_name][i]
+            sents = dataset[split][input_column_name][i]

             # check that it is the same distribution
             row_label_set = set(lab)
@@ -259,7 +272,9 @@ def convert_to_fast(
             sentences.append(s)
             sent_set.add(s)  # ensuring no duplicates

-        ds[split] = Dataset.from_dict({"sentences": sentences, "labels": labels})
+        ds[split] = Dataset.from_dict(
+            {input_column_name: sentences, label_column_name: labels}
+        )

         if len(ds[split]) > max_size:
             idxs = rng_state.sample(range(len(ds[split])), max_size)
@@ -268,17 +283,20 @@ def convert_to_fast(
     return DatasetDict(ds)


-def check_label_distribution(ds: DatasetDict) -> None:
+def check_label_distribution(
+    ds: DatasetDict,
+    label_column_name: str = "labels",
+) -> None:
     """For older clustering dataset versions.
     ds is a DatasetDict at the split level
     """
     n_clusters = len(ds)
     if n_clusters > 50:
         return
-    all_labels_set = set(itertools.chain.from_iterable(ds["labels"]))
+    all_labels_set = set(itertools.chain.from_iterable(ds[label_column_name]))

     for i in range(n_clusters):
-        lab = ds["labels"][i]
+        lab = ds[label_column_name][i]

         # check that it is the same distribution
         row_label_set = set(lab)
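`convert_to_fast` and `check_label_distribution` now take the column names explicitly instead of assuming `sentences`/`labels`. A usage sketch with toy data, assuming the toy split passes the label-distribution check above (keyword arguments mirror the new signature):

from datasets import Dataset, DatasetDict

from mteb.abstasks.AbsTaskClusteringFast import convert_to_fast

legacy = DatasetDict(
    {
        "test": Dataset.from_dict(
            {
                # one row per cluster set: parallel lists of documents and labels
                "sentences": [["doc a", "doc b"], ["doc c", "doc d"]],
                "labels": [[0, 1], [0, 1]],
            }
        )
    }
)

fast = convert_to_fast(
    legacy,
    input_column_name="sentences",
    label_column_name="labels",
    seed=42,
    max_size=100_000,
)
print(fast["test"].column_names)  # flattened 'sentences' and 'labels' columns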

mteb/abstasks/AbsTaskMultilabelClassification.py

Lines changed: 4 additions & 0 deletions
@@ -77,6 +77,10 @@ def _evaluate_subset(
         encode_kwargs: dict[str, Any],
         **kwargs: Any,
     ) -> ScoresDict:
+        if isinstance(dataset, (Dataset, DatasetDict)):
+            dataset = dataset.select_columns(
+                [self.input_column_name, self.label_column_name]
+            )
         train_split = dataset[self.train_split]
         eval_split = dataset[hf_split]

mteb/abstasks/AbsTaskPairClassification.py

Lines changed: 23 additions & 11 deletions
@@ -53,6 +53,9 @@ class AbsTaskPairClassification(AbsTask):
     """

     abstask_prompt = "Retrieve text that are semantically similar to the given text."
+    sentence1_column_name: str = "sentence1"
+    sentence2_column_name: str = "sentence2"
+    label_column_name: str = "labels"

     def _evaluate_subset(
         self,
@@ -69,9 +72,9 @@ def _evaluate_subset(
             "sentence_transformers.evaluation.PairClassificationEvaluator"
         ).setLevel(logging.WARN)
         evaluator = PairClassificationEvaluator(
-            data_split["sentence1"],
-            data_split["sentence2"],
-            data_split["labels"],
+            data_split[self.sentence1_column_name],
+            data_split[self.sentence2_column_name],
+            data_split[self.label_column_name],
             task_metadata=self.metadata,
             hf_split=hf_split,
             hf_subset=hf_subset,
@@ -102,17 +105,19 @@ def _calculate_descriptive_statistics_from_split(
             dataset = dataset[0]

         sentence1 = (
-            dataset["sentence1"][0]
-            if len(dataset["sentence1"]) == 1
-            else dataset["sentence1"]
+            dataset[self.sentence1_column_name][0]
+            if len(dataset[self.sentence1_column_name]) == 1
+            else dataset[self.sentence1_column_name]
         )
         sentence2 = (
-            dataset["sentence2"][0]
-            if len(dataset["sentence2"]) == 1
-            else dataset["sentence2"]
+            dataset[self.sentence2_column_name][0]
+            if len(dataset[self.sentence2_column_name]) == 1
+            else dataset[self.sentence2_column_name]
         )
         labels = (
-            dataset["labels"][0] if len(dataset["labels"]) == 1 else dataset["labels"]
+            dataset[self.label_column_name][0]
+            if len(dataset[self.label_column_name]) == 1
+            else dataset[self.label_column_name]
         )

         text1_statistics = calculate_text_statistics(sentence1)
@@ -140,4 +145,11 @@ def _push_dataset_to_hub(self, repo_name: str) -> None:
         for split in self.dataset:
             if len(self.dataset[split]) == 1:
                 self.dataset[split] = self.dataset[split][0]
-        self._upload_dataset_to_hub(repo_name, ["sentence1", "sentence2", "labels"])
+        self._upload_dataset_to_hub(
+            repo_name,
+            [
+                self.sentence1_column_name,
+                self.sentence2_column_name,
+                self.label_column_name,
+            ],
+        )
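Because the pair columns are now class attributes, a task whose dataset stores its pairs under different names can override them instead of renaming columns. A hypothetical subclass sketch (metadata and data loading omitted):

from mteb.abstasks.AbsTaskPairClassification import AbsTaskPairClassification


class NLIPairClassification(AbsTaskPairClassification):
    # Hypothetical dataset layout: premise/hypothesis pairs with an integer label.
    sentence1_column_name: str = "premise"
    sentence2_column_name: str = "hypothesis"
    label_column_name: str = "label"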
