@@ -49,7 +49,7 @@ def __init__(
49
49
hf_split : str ,
50
50
hf_subset : str ,
51
51
** kwargs ,
52
- ):
52
+ ) -> None :
53
53
super ().__init__ (** kwargs )
54
54
self .sentences1 = sentences1
55
55
self .sentences2 = sentences2
@@ -67,7 +67,7 @@ def __call__(
67
67
self ,
68
68
model : Encoder ,
69
69
encode_kwargs : dict [str , Any ],
70
- ):
70
+ ) -> dict [ str , float ] :
71
71
scores = self .compute_metrics (model , encode_kwargs = encode_kwargs )
72
72
73
73
# Main score is the max of Average Precision (AP)
@@ -83,7 +83,7 @@ def _encode_unique_texts(
83
83
hf_split : str ,
84
84
hf_subset : str ,
85
85
** encode_kwargs : Any ,
86
- ):
86
+ ) -> np . ndarray :
87
87
index_map , all_unique_texts , all_texts_indexes = {}, [], []
88
88
for text in all_texts :
89
89
text_hash = hash (text )
@@ -110,7 +110,7 @@ def compute_metrics(
110
110
model : Encoder ,
111
111
* ,
112
112
encode_kwargs : dict [str , Any ],
113
- ):
113
+ ) -> dict [ str , float ] :
114
114
all_sentences = self .sentences1 + self .sentences2
115
115
len_sentences1 = len (self .sentences1 )
116
116
embeddings = self ._encode_unique_texts (
@@ -215,7 +215,9 @@ def _compute_metrics(
215
215
}
216
216
217
217
@staticmethod
218
- def find_best_acc_and_threshold (scores , labels , high_score_more_similar : bool ):
218
+ def find_best_acc_and_threshold (
219
+ scores : np .ndarray , labels : np .ndarray , high_score_more_similar : bool
220
+ ) -> tuple [float , float ]:
219
221
assert len (scores ) == len (labels )
220
222
rows = list (zip (scores , labels ))
221
223
@@ -242,7 +244,9 @@ def find_best_acc_and_threshold(scores, labels, high_score_more_similar: bool):
242
244
return max_acc , best_threshold
243
245
244
246
@staticmethod
245
- def find_best_f1_and_threshold (scores , labels , high_score_more_similar : bool ):
247
+ def find_best_f1_and_threshold (
248
+ scores , labels , high_score_more_similar : bool
249
+ ) -> tuple [float , float , float , float ]:
246
250
assert len (scores ) == len (labels )
247
251
248
252
scores = np .asarray (scores )
@@ -278,7 +282,7 @@ def find_best_f1_and_threshold(scores, labels, high_score_more_similar: bool):
278
282
return best_f1 , best_precision , best_recall , threshold
279
283
280
284
@staticmethod
281
- def ap_score (scores , labels , high_score_more_similar : bool ):
285
+ def ap_score (scores , labels , high_score_more_similar : bool ) -> float :
282
286
return average_precision_score (
283
287
labels , scores * (1 if high_score_more_similar else - 1 )
284
288
)
0 commit comments