
Commit 41ad3a2

grofte and tomaarsen authored
Preserve dataset.features when using sample_dataset (#396)
* Update data.py

If a column has been defined as a ClassLabel, then `sample_dataset` strips that information away and you lose the names.

Test code:

```python
from datasets import load_dataset
import datasets
from sentence_transformers.losses import CosineSimilarityLoss

from setfit import SetFitModel, SetFitTrainer, sample_dataset

# Load a dataset from the Hugging Face Hub
dataset: datasets.DatasetDict = load_dataset("SetFit/sst5")
dataset = dataset.class_encode_column("label_text")

# Simulate the few-shot regime by sampling 8 examples per class
train_dataset: datasets.Dataset = sample_dataset(dataset["train"], label_column="label_text", num_samples=8)
eval_dataset: datasets.Dataset = dataset["validation"]

# Load a SetFit model from Hub
model = SetFitModel.from_pretrained("sentence-transformers/paraphrase-mpnet-base-v2")

# Create trainer
trainer = SetFitTrainer(
    model=model,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    loss_class=CosineSimilarityLoss,
    metric="accuracy",
    batch_size=16,
    num_iterations=20,  # The number of text pairs to generate for contrastive learning
    num_epochs=1,  # The number of epochs to use for contrastive learning
    column_mapping={"text": "text", "label_text": "label"}  # Map dataset columns to text/label expected by trainer
)

# Train and evaluate
trainer.train()
metrics = trainer.evaluate()

# Run inference
preds = model(["i loved the spiderman movie!", "pineapple on pizza is the worst 🤮"])
print(preds)
print(list(map(lambda x: train_dataset.features["label_text"].names[x], preds)))
```

* Preserve features when calling sample_dataset + tests

---------

Co-authored-by: Tom Aarsen <[email protected]>
1 parent a0b69b4 commit 41ad3a2

File tree

2 files changed: 9 additions, 1 deletion

- src/setfit/data.py
- tests/test_data.py


src/setfit/data.py

Lines changed: 1 addition & 1 deletion

@@ -173,7 +173,7 @@ def sample_dataset(dataset: Dataset, label_column: str = "label", num_samples: i
     df = df.apply(lambda x: x.sample(min(num_samples, len(x))))
     df = df.reset_index(drop=True)

-    all_samples = Dataset.from_pandas(df)
+    all_samples = Dataset.from_pandas(df, features=dataset.features)
     return all_samples.shuffle(seed=seed)
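The one-line fix hinges on how `Dataset.from_pandas` handles features: unless the original `Features` are passed back in, the label column is re-inferred as a plain integer type and the ClassLabel names are lost. Below is a minimal sketch, not taken from the commit, that isolates this behaviour with a made-up two-class ClassLabel.

```python
from datasets import ClassLabel, Dataset, Features, Value

# Made-up two-class features, purely for illustration.
features = Features({"text": Value("string"), "label": ClassLabel(names=["negative", "positive"])})
dataset = Dataset.from_dict({"text": ["great movie", "terrible movie"], "label": [1, 0]}, features=features)

df = dataset.to_pandas()

# Without `features`, the label column is re-inferred as a plain integer Value and the names are lost.
inferred = Dataset.from_pandas(df)

# Passing the original features back in (the approach taken by this commit) keeps the ClassLabel intact.
preserved = Dataset.from_pandas(df, features=dataset.features)

print(inferred.features["label"])   # a plain integer Value
print(preserved.features["label"])  # ClassLabel(names=['negative', 'positive'], ...)
```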

tests/test_data.py

Lines changed: 8 additions & 0 deletions

@@ -225,3 +225,11 @@ def test_correct_model_inputs(tokenizer_name):
     # Verify that the x_batch contains exactly those keys that the model requires
     x_batch, _ = next(iter(dataloader))
     assert set(x_batch.keys()) == set(tokenizer.model_input_names)
+
+
+def test_preserve_features() -> None:
+    dataset = load_dataset("SetFit/sst5", split="train[:100]")
+    label_column = "label_text"
+    dataset = dataset.class_encode_column(label_column)
+    train_dataset = sample_dataset(dataset, label_column=label_column, num_samples=8)
+    assert train_dataset.features[label_column] == dataset.features[label_column]
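In practice, what the new test guarantees is that the ClassLabel survives sampling, so label ids can still be decoded back to names afterwards. The sketch below mirrors the test's setup; it is not part of the commit and assumes network access to the Hugging Face Hub for `SetFit/sst5`.

```python
from datasets import load_dataset

from setfit import sample_dataset

# Same setup as the new test; requires network access to the Hugging Face Hub.
dataset = load_dataset("SetFit/sst5", split="train[:100]")
dataset = dataset.class_encode_column("label_text")

train_dataset = sample_dataset(dataset, label_column="label_text", num_samples=8)

# Because the ClassLabel survives sampling, label ids can be decoded back to their names.
label_feature = train_dataset.features["label_text"]
print(label_feature.int2str(train_dataset["label_text"][:5]))
```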

0 commit comments
