Commit 01a86e9

[v2] Start type checking (#3176)
* start type check integration
* fix imports
* fix validation
* fix required import
* fix comments
1 parent 45114a5 commit 01a86e9

33 files changed (309 additions, 206 deletions)

Makefile

Lines changed: 5 additions & 0 deletions
@@ -80,3 +80,8 @@ format-citations:
 check: ## Run code quality tools.
 	@echo "--- 🧹 Running code quality tools ---"
 	@pre-commit run -a
+
+.PHONY: typecheck
+typecheck:
+	@echo "--- 🔍 Running type checks ---"
+	mypy mteb
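With this target in place, make typecheck runs mypy over the mteb package, separate from the existing make check target, which runs the pre-commit code quality hooks.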

mteb/_evaluators/any_sts_evaluator.py

Lines changed: 2 additions & 2 deletions
@@ -31,7 +31,7 @@ def __init__(
         hf_split: str,
         hf_subset: str,
         **kwargs,
-    ):
+    ) -> None:
         super().__init__(**kwargs)
         self.first_column = create_dataloader(
             dataset,
@@ -53,7 +53,7 @@ def __call__(
         model: Encoder,
         *,
         encode_kwargs: dict[str, Any],
-    ):
+    ) -> dict[str, float]:
         embeddings1 = model.encode(
             self.first_column,
             task_metadata=self.task_metadata,

mteb/_evaluators/classification_evaluator.py

Lines changed: 14 additions & 2 deletions
@@ -111,13 +111,25 @@ def calculate_scores(
         )
         return scores

-    def __call__(
+    def __call__(  # type: ignore[override]
         self,
         model: Encoder,
         *,
         encode_kwargs: dict[str, Any],
         test_cache: np.ndarray | None = None,
-    ) -> tuple[dict[str, float], Any]:
+    ) -> tuple[dict[str, float], np.ndarray]:
+        """Classification evaluation by training a sklearn classifier on the
+        embeddings of the training set and evaluating on the embeddings of the test set.
+
+        Args:
+            model: Encoder
+            encode_kwargs: encode kwargs
+            test_cache: embeddings of the test set, if already computed
+
+        Returns:
+            Tuple of scores and test embeddings
+
+        """
         dataloader_train, dataloader_test = self.create_dataloaders(
             batch_size=encode_kwargs["batch_size"]
         )
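The "# type: ignore[override]" comment added here (and in the regression and retrieval evaluators below) silences mypy's override check: this subclass changes the __call__ signature relative to the abstract Evaluator.__call__. A minimal standalone sketch of that situation, using placeholder class names rather than the actual mteb classes:

from abc import ABC, abstractmethod
from typing import Any


class Base(ABC):
    @abstractmethod
    def __call__(self, model: Any, *, encode_kwargs: dict[str, Any]) -> dict[str, float]: ...


class Sub(Base):
    # Extra keyword parameter and a different return type than Base.__call__,
    # so mypy reports an [override] error unless it is explicitly ignored.
    def __call__(  # type: ignore[override]
        self, model: Any, *, encode_kwargs: dict[str, Any], test_cache: Any = None
    ) -> tuple[dict[str, float], Any]:
        return {"accuracy": 0.0}, test_cache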

mteb/_evaluators/clustering_evaluator.py

Lines changed: 2 additions & 2 deletions
@@ -27,7 +27,7 @@ def __init__(
         hf_subset: str,
         clustering_batch_size: int = 500,
         **kwargs,
-    ):
+    ) -> None:
         super().__init__(**kwargs)
         self.dataset = dataset
         self.clustering_batch_size = clustering_batch_size
@@ -43,7 +43,7 @@ def __call__(
         *,
         encode_kwargs: dict[str, Any],
         v_measure_only: bool = False,
-    ):
+    ) -> dict[str, float]:
         data_loader = create_dataloader(
             self.dataset,
             self.task_metadata,

mteb/_evaluators/evaluator.py

Lines changed: 4 additions & 2 deletions
@@ -12,12 +12,14 @@ class Evaluator(ABC):
     Extend this class and implement __call__ for custom evaluators.
     """

-    def __init__(self, seed: int = 42, **kwargs: Any):
+    def __init__(self, seed: int = 42, **kwargs: Any) -> None:
         self.seed = seed
         self.rng_state, self.np_rng = set_seed(seed)

     @abstractmethod
-    def __call__(self, model: Encoder, *, encode_kwargs: dict[str, Any]):
+    def __call__(
+        self, model: Encoder, *, encode_kwargs: dict[str, Any]
+    ) -> dict[str, float]:
         """This is called during training to evaluate the model.
         It returns scores.
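Annotating the abstract __call__ with "-> dict[str, float]" lets mypy check call sites of any evaluator. A small sketch of the effect, assuming the import path taken from the file header above and typing the model loosely as Any:

from typing import Any

from mteb._evaluators.evaluator import Evaluator  # import path assumed from the file path above


def main_score(evaluator: Evaluator, model: Any, batch_size: int = 32) -> float:
    scores = evaluator(model, encode_kwargs={"batch_size": batch_size})
    # mypy now knows `scores` is dict[str, float], so indexing yields a float.
    # "main_score" is a hypothetical key used only for illustration here.
    return scores["main_score"]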

mteb/_evaluators/regression_evaluator.py

Lines changed: 1 addition & 1 deletion
@@ -52,7 +52,7 @@ def __init__(
         self.task_metadata = task_metadata
         self.regressor = regressor

-    def __call__(
+    def __call__(  # type: ignore[override]
         self,
         model: Encoder,
         *,

mteb/_evaluators/retrieval_evaluator.py

Lines changed: 2 additions & 3 deletions
@@ -34,7 +34,7 @@ def __init__(
         top_ranked: TopRankedDocumentsType | None = None,
         qid: str | None = None,
         **kwargs,
-    ):
+    ) -> None:
         super().__init__(**kwargs)
         self.corpus = corpus
         self.queries = queries
@@ -46,11 +46,10 @@ def __init__(
         self.qid = qid
         self.top_k = top_k

-    def __call__(
+    def __call__(  # type: ignore[override]
         self,
         search_model: SearchProtocol,
         encode_kwargs: dict[str, Any],
-        **kwargs: Any,
     ) -> RetrievalOutputType:
         search_model.index(
             corpus=self.corpus,
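Besides the override ignore, this hunk drops the unused "**kwargs: Any" from __call__. A minimal sketch (hypothetical function names, not the mteb API) of why removing a catch-all parameter tightens checking:

from typing import Any


def call_with_catchall(encode_kwargs: dict[str, Any], **kwargs: Any) -> None:
    """Accepts any extra keyword, so a misspelled argument is silently swallowed."""


def call_strict(encode_kwargs: dict[str, Any]) -> None:
    """Only the declared keyword is accepted; typos become errors."""


call_with_catchall(encode_kwargs={}, encode_kwarg={})  # typo goes unnoticed
call_strict(encode_kwargs={})  # fine
# call_strict(encode_kwarg={})  # rejected by mypy (and by Python at runtime)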

mteb/_evaluators/text/bitext_mining_evaluator.py

Lines changed: 7 additions & 3 deletions
@@ -30,7 +30,7 @@ def __init__(
         hf_subset: str,
         pair_columns: list[tuple[str, str]] = DEFAULT_PAIR,
         **kwargs,
-    ):
+    ) -> None:
         super().__init__(**kwargs)
         self.pairs = pair_columns
         self.n = len(sentences)
@@ -45,11 +45,15 @@ def __init__(
         self.hf_subset = hf_subset
         self.task_metadata = task_metadata

-    def __call__(self, model: Encoder, *, encode_kwargs: dict[str, Any]):
+    def __call__(
+        self, model: Encoder, *, encode_kwargs: dict[str, Any]
+    ) -> dict[str, float]:
         scores = self.compute_metrics(model, encode_kwargs=encode_kwargs)
         return scores

-    def compute_metrics(self, model: Encoder, encode_kwargs: dict[str, Any]):
+    def compute_metrics(
+        self, model: Encoder, encode_kwargs: dict[str, Any]
+    ) -> dict[str, float]:
         pair_elements = {p for pair in self.pairs for p in pair}
         if isinstance(self.sentences, Dataset):
             subsets = [

mteb/_evaluators/text/pair_classification_evaluator.py

Lines changed: 11 additions & 7 deletions
@@ -49,7 +49,7 @@ def __init__(
         hf_split: str,
         hf_subset: str,
         **kwargs,
-    ):
+    ) -> None:
         super().__init__(**kwargs)
         self.sentences1 = sentences1
         self.sentences2 = sentences2
@@ -67,7 +67,7 @@ def __call__(
         self,
         model: Encoder,
         encode_kwargs: dict[str, Any],
-    ):
+    ) -> dict[str, float]:
         scores = self.compute_metrics(model, encode_kwargs=encode_kwargs)

         # Main score is the max of Average Precision (AP)
@@ -83,7 +83,7 @@ def _encode_unique_texts(
         hf_split: str,
         hf_subset: str,
         **encode_kwargs: Any,
-    ):
+    ) -> np.ndarray:
         index_map, all_unique_texts, all_texts_indexes = {}, [], []
         for text in all_texts:
             text_hash = hash(text)
@@ -110,7 +110,7 @@ def compute_metrics(
         model: Encoder,
         *,
         encode_kwargs: dict[str, Any],
-    ):
+    ) -> dict[str, float]:
         all_sentences = self.sentences1 + self.sentences2
         len_sentences1 = len(self.sentences1)
         embeddings = self._encode_unique_texts(
@@ -215,7 +215,9 @@ def _compute_metrics(
         }

     @staticmethod
-    def find_best_acc_and_threshold(scores, labels, high_score_more_similar: bool):
+    def find_best_acc_and_threshold(
+        scores: np.ndarray, labels: np.ndarray, high_score_more_similar: bool
+    ) -> tuple[float, float]:
         assert len(scores) == len(labels)
         rows = list(zip(scores, labels))

@@ -242,7 +244,9 @@ def find_best_acc_and_threshold(scores, labels, high_score_more_similar: bool):
         return max_acc, best_threshold

     @staticmethod
-    def find_best_f1_and_threshold(scores, labels, high_score_more_similar: bool):
+    def find_best_f1_and_threshold(
+        scores, labels, high_score_more_similar: bool
+    ) -> tuple[float, float, float, float]:
         assert len(scores) == len(labels)

         scores = np.asarray(scores)
@@ -278,7 +282,7 @@ def find_best_f1_and_threshold(scores, labels, high_score_more_similar: bool):
         return best_f1, best_precision, best_recall, threshold

     @staticmethod
-    def ap_score(scores, labels, high_score_more_similar: bool):
+    def ap_score(scores, labels, high_score_more_similar: bool) -> float:
         return average_precision_score(
             labels, scores * (1 if high_score_more_similar else -1)
         )
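The tuple return annotations on the threshold helpers record the arity and element types of their results, so unpacking at call sites is now checked. A tiny illustration with a stub standing in for find_best_acc_and_threshold (the values are made up):

def best_acc_and_threshold_stub() -> tuple[float, float]:
    # Stand-in for find_best_acc_and_threshold; returns (max_acc, best_threshold).
    return 0.83, 0.5


max_acc, best_threshold = best_acc_and_threshold_stub()
print(f"accuracy={max_acc:.2f} at threshold={best_threshold:.2f}")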

mteb/_evaluators/text/summarization_evaluator.py

Lines changed: 4 additions & 4 deletions
@@ -36,7 +36,7 @@ def __init__(
         hf_split: str,
         hf_subset: str,
         **kwargs,
-    ):
+    ) -> None:
         """Summarization Evaluator

         Args:
@@ -63,7 +63,7 @@ def __call__(
         model: Encoder,
         *,
         encode_kwargs: dict[str, Any],
-    ):
+    ) -> dict[str, float]:
         cosine_spearman_scores = []
         cosine_pearson_scores = []
         dot_spearman_scores = []
@@ -196,7 +196,7 @@ def __init__(
         hf_split: str | None = None,
         hf_subset: str | None = None,
         **kwargs,
-    ):
+    ) -> None:
         # human_summaries shape: (None, num_human_summaries)
         # machine_summaries shape: (None, num_machine_summaries)
         # gold scores shape: (None, num_machine_summaries)
@@ -220,7 +220,7 @@ def __call__(
         model: Encoder,
         *,
         encode_kwargs: dict[str, Any],
-    ):
+    ) -> dict[str, float]:
         cosine_spearman_scores = []
         cosine_pearson_scores = []
         dot_spearman_scores = []
