Commit 684b6b5

Fix smart_batching_collate Inefficiency (#2556)
* Fix smart_batching_collate Inefficiency

  SentenceTransformer.py:846 throws an inefficiency warning: ".....Creating a tensor from a list of numpy.ndarrays is extremely slow. Please consider converting the list to a single numpy.ndarray with numpy.array() before converting to a tensor. (Triggered internally at ../torch/csrc/utils/tensor_new.cpp:275.)
  labels = torch.tensor([example.label for example in batch])"

* Update SentenceTransformer.py

* Remove some comments; add edge case (if labels is empty)

---------

Co-authored-by: Tom Aarsen <[email protected]>
1 parent 5f75ce5 commit 684b6b5
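
The warning quoted in the commit message recommends building one numpy.ndarray before creating the tensor. A minimal sketch of that pattern follows, assuming every label in a batch has the same shape and dtype; the label values are made up for illustration and are not taken from the repository.

```python
import numpy as np
import torch

# Hypothetical per-example labels: four float32 vectors of length 2.
labels = [np.array([0.1, 0.9], dtype=np.float32) for _ in range(4)]

# np.stack copies the list into one contiguous array of shape (4, 2).
stacked = np.stack(labels)

# torch.from_numpy wraps that array without copying and keeps its dtype.
labels_tensor = torch.from_numpy(stacked)

print(labels_tensor.shape)  # torch.Size([4, 2])
print(labels_tensor.dtype)  # torch.float32

# The tensor shares memory with `stacked`, so a write to one shows up in the other.
stacked[0, 0] = 5.0
print(labels_tensor[0, 0].item())  # 5.0
```

Because the result shares memory with the stacked array, code that later mutates that array in place would need an explicit copy first.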

1 file changed: +10 -2 lines changed

sentence_transformers/SentenceTransformer.py

Lines changed: 10 additions & 2 deletions
@@ -1000,8 +1000,16 @@ def smart_batching_collate(self, batch: List["InputExample"]) -> Tuple[List[Dict
         """
         texts = [example.texts for example in batch]
         sentence_features = [self.tokenize(sentence) for sentence in zip(*texts)]
-        labels = torch.tensor([example.label for example in batch])
-        return sentence_features, labels
+        labels = [example.label for example in batch]
+
+        # Use torch.from_numpy to convert the numpy array directly to a tensor,
+        # which is the recommended approach for converting numpy arrays to tensors
+        if labels and isinstance(labels[0], np.ndarray):
+            labels_tensor = torch.from_numpy(np.stack(labels))
+        else:
+            labels_tensor = torch.tensor(labels)
+
+        return sentence_features, labels_tensor
 
     def _text_length(self, text: Union[List[int], List[List[int]]]):
         """