Skip to content

Commit abdfbf0

Browse files
committed
remove resize code from prev. trained model - Did not work with all model types
1 parent 7f62f4b commit abdfbf0

File tree

1 file changed

+1
-14
lines changed

1 file changed

+1
-14
lines changed

sentence_transformers/cross_encoder/CrossEncoder.py

Lines changed: 1 addition & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -43,23 +43,11 @@ def __init__(self, model_name:str, num_labels:int = None, max_length:int = None,
4343
if num_labels is None and not classifier_trained:
4444
num_labels = 1
4545

46-
resize_num_labels = None
4746
if num_labels is not None:
48-
if hasattr(self.config, 'num_labels') and self.config.num_labels is not None and self.config.num_labels != num_labels:
49-
#Resize classifier head
50-
resize_num_labels = num_labels
51-
else:
52-
self.config.num_labels = num_labels
47+
self.config.num_labels = num_labels
5348

5449
self.model = AutoModelForSequenceClassification.from_pretrained(model_name, config=self.config)
5550
self.tokenizer = AutoTokenizer.from_pretrained(model_name, **tokenizer_args)
56-
57-
if resize_num_labels is not None:
58-
print("Warning: Loaded model was trained for {} labels. Resize CrossEncoder to {} labels. You must re-train this model to get meaningful predictions".format(self.config.num_labels, resize_num_labels))
59-
self.config.num_labels = resize_num_labels
60-
self.model.config.num_labels = resize_num_labels
61-
self.model.classifier = torch.nn.Linear(self.config.hidden_size, resize_num_labels)
62-
6351
self.max_length = max_length
6452

6553
if device is None:
@@ -214,7 +202,6 @@ def fit(self,
214202
logits = activation_fct(model_predictions.logits)
215203
if self.config.num_labels == 1:
216204
logits = logits.view(-1)
217-
218205
loss_value = loss_fct(logits, labels)
219206
loss_value.backward()
220207
torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_grad_norm)

0 commit comments

Comments
 (0)