from datasets import Dataset, DatasetDict
from sklearn.cluster import MiniBatchKMeans
from sklearn.metrics.cluster import v_measure_score
- from torch.utils.data import DataLoader

from mteb.models import Encoder
from mteb.types import HFSubset
from mteb.types.statistics import DescriptiveStatistics, LabelStatistics, TextStatistics

+ from ..create_dataloaders import create_dataloader
from ._statistics_calculation import (
    calculate_label_statistics,
    calculate_text_statistics,
@@ -126,6 +126,8 @@ class AbsTaskClusteringFast(AbsTask):
    k_mean_batch_size: int = 512
    max_depth = None
    abstask_prompt = "Identify categories in user passages."
+     input_column_name: str = "sentences"
+     label_column_name: str = "labels"

    def _evaluate_subset(
        self,
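
A minimal sketch of how a concrete task might use the two new attributes to keep its dataset's own column names (the class name, column names, and import path below are hypothetical, not part of this change):

    # Hypothetical subclass; the import path and names are assumptions,
    # only input_column_name/label_column_name come from this change.
    from mteb.abstasks import AbsTaskClusteringFast  # assumed import location

    class ArxivClusteringByCategory(AbsTaskClusteringFast):
        # Point the evaluator at the dataset's native columns instead of
        # renaming them to "sentences"/"labels" at load time.
        input_column_name = "abstract"
        label_column_name = "category"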
@@ -164,19 +166,24 @@ def _evaluate_subset(
        )
        downsampled_dataset = dataset.select(example_indices)  # type: ignore

-         downsampled_dataset = downsampled_dataset.rename_column(
-             original_column_name="sentences", new_column_name="text"
+         downsampled_dataset = downsampled_dataset.select_columns(
+             [self.input_column_name, self.label_column_name]
        )
        embeddings = model.encode(
-             DataLoader(downsampled_dataset),
+             create_dataloader(
+                 downsampled_dataset,
+                 self.metadata,
+                 input_column=self.input_column_name,
+                 batch_size=encode_kwargs["batch_size"],
+             ),
            task_metadata=self.metadata,
            hf_subset=hf_subset,
            hf_split=hf_split,
            **encode_kwargs,
        )

        labels = []
-         for label in downsampled_dataset["labels"]:
+         for label in downsampled_dataset[self.label_column_name]:
            if not isinstance(label, list):
                label = [label]
            labels.append(label)
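
The loop above normalizes single labels into lists so that single-label and multi-label datasets share one code path; a toy illustration of the same transformation:

    # Toy illustration of the label normalization in the loop above.
    raw_labels = ["news", ["sports", "football"], "news"]
    labels = [lab if isinstance(lab, list) else [lab] for lab in raw_labels]
    # labels -> [["news"], ["sports", "football"], ["news"]]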
@@ -194,29 +201,27 @@ def _evaluate_subset(

        mean_v_measure = np.mean(v_measures)
        v_std = np.std(v_measures)
-         scores = {
+         return {
            "v_measures": all_v_scores,
            "v_measure": float(mean_v_measure),
            "v_measure_std": v_std,
        }
-         self._add_main_score(scores)
-         return scores

    def _calculate_descriptive_statistics_from_split(
        self, split: str, hf_subset: str | None = None, compute_overall: bool = False
    ) -> ClusteringFastDescriptiveStatistics:
        if hf_subset:
-             sentences = self.dataset[hf_subset][split]["sentences"]
-             labels = self.dataset[hf_subset][split]["labels"]
+             sentences = self.dataset[hf_subset][split][self.input_column_name]
+             labels = self.dataset[hf_subset][split][self.label_column_name]
        elif compute_overall:
            sentences = []
            labels = []
            for hf_subset in self.metadata.eval_langs:
-                 sentences.extend(self.dataset[hf_subset][split]["sentences"])
-                 labels.extend(self.dataset[hf_subset][split]["labels"])
+                 sentences.extend(self.dataset[hf_subset][split][self.input_column_name])
+                 labels.extend(self.dataset[hf_subset][split][self.label_column_name])
        else:
-             sentences = self.dataset[split]["sentences"]
-             labels = self.dataset[split]["labels"]
+             sentences = self.dataset[split][self.input_column_name]
+             labels = self.dataset[split][self.label_column_name]

        return ClusteringFastDescriptiveStatistics(
            num_samples=len(sentences),
@@ -225,11 +230,17 @@ def _calculate_descriptive_statistics_from_split(
        )

    def _push_dataset_to_hub(self, repo_name: str) -> None:
-         self._upload_dataset_to_hub(repo_name, ["sentences", "labels"])
+         self._upload_dataset_to_hub(
+             repo_name, [self.input_column_name, self.label_column_name]
+         )


def convert_to_fast(
-     dataset: DatasetDict, seed: int, max_size: int = 100_000
+     dataset: DatasetDict,
+     input_column_name: str,
+     label_column_name: str,
+     seed: int,
+     max_size: int = 100_000,
) -> DatasetDict:
    """Converts a clustering dataset to a fast version. This concatenates the clusters into two columns, sentences and labels.
    It additionally downsamples the dataset to max_size.
@@ -242,10 +253,12 @@ def convert_to_fast(
        labels = []
        sentences = []
        n_clusters = len(dataset[split])
-         all_labels_set = set(itertools.chain.from_iterable(dataset[split]["labels"]))
+         all_labels_set = set(
+             itertools.chain.from_iterable(dataset[split][label_column_name])
+         )
        for i in range(n_clusters):
-             lab = dataset[split]["labels"][i]
-             sents = dataset[split]["sentences"][i]
+             lab = dataset[split][label_column_name][i]
+             sents = dataset[split][input_column_name][i]

            # check that it is the same distribution
            row_label_set = set(lab)
@@ -259,7 +272,9 @@ def convert_to_fast(
                    sentences.append(s)
                    sent_set.add(s)  # ensuring no duplicates

-         ds[split] = Dataset.from_dict({"sentences": sentences, "labels": labels})
+         ds[split] = Dataset.from_dict(
+             {input_column_name: sentences, label_column_name: labels}
+         )

        if len(ds[split]) > max_size:
            idxs = rng_state.sample(range(len(ds[split])), max_size)
@@ -268,17 +283,20 @@ def convert_to_fast(
    return DatasetDict(ds)


- def check_label_distribution(ds: DatasetDict) -> None:
+ def check_label_distribution(
+     ds: DatasetDict,
+     label_column_name: str = "labels",
+ ) -> None:
    """For older clustering dataset versions.
    ds is a DatasetDict at the split level
    """
    n_clusters = len(ds)
    if n_clusters > 50:
        return
-     all_labels_set = set(itertools.chain.from_iterable(ds["labels"]))
+     all_labels_set = set(itertools.chain.from_iterable(ds[label_column_name]))

    for i in range(n_clusters):
-         lab = ds["labels"][i]
+         lab = ds[label_column_name][i]

        # check that it is the same distribution
        row_label_set = set(lab)
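
A hedged sketch of how the updated convert_to_fast signature might be called when migrating an older clustering dataset (the repo id and seed are placeholders, and convert_to_fast is assumed to be imported from this module):

    # Hypothetical conversion snippet; only the explicit column-name
    # arguments reflect this change, everything else is a placeholder.
    from datasets import load_dataset

    raw = load_dataset("org/legacy-clustering-dataset")  # placeholder repo id
    fast = convert_to_fast(
        raw,
        input_column_name="sentences",
        label_column_name="labels",
        seed=42,
        max_size=100_000,
    )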