Commit 41ad3a2
* Update data.py
If a column has been defined as a `ClassLabel`, `sample_dataset` strips that information, so the label names are lost.
Test code
```python
from datasets import load_dataset
import datasets
from sentence_transformers.losses import CosineSimilarityLoss
from setfit import SetFitModel, SetFitTrainer, sample_dataset
# Load a dataset from the Hugging Face Hub
dataset: datasets.DatasetDict = load_dataset("SetFit/sst5")
dataset = dataset.class_encode_column("label_text")
# Simulate the few-shot regime by sampling 8 examples per class
train_dataset: datasets.Dataset = sample_dataset(dataset["train"], label_column="label_text", num_samples=8)
eval_dataset: datasets.Dataset = dataset["validation"]
# Load a SetFit model from Hub
model = SetFitModel.from_pretrained("sentence-transformers/paraphrase-mpnet-base-v2")
# Create trainer
trainer = SetFitTrainer(
    model=model,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    loss_class=CosineSimilarityLoss,
    metric="accuracy",
    batch_size=16,
    num_iterations=20,  # The number of text pairs to generate for contrastive learning
    num_epochs=1,  # The number of epochs to use for contrastive learning
    column_mapping={"text": "text", "label_text": "label"},  # Map dataset columns to text/label expected by trainer
)
# Train and evaluate
trainer.train()
metrics = trainer.evaluate()
# Run inference
preds = model(["i loved the spiderman movie!", "pineapple on pizza is the worst 🤮"])
print(preds)
print(list(map(lambda x: train_dataset.features["label_text"].names[x], preds)))
```
* Preserve features when calling sample_dataset + tests
---------
Co-authored-by: Tom Aarsen <[email protected]>
1 parent a0b69b4 commit 41ad3a2
2 files changed: +9 -1 lines changed
First file: line 176 replaced (1 addition, 1 deletion). Second file: lines 228-235 added (8 additions).