Merged
Changes from 6 commits
4 changes: 2 additions & 2 deletions mteb/_evaluators/any_sts_evaluator.py
@@ -31,7 +31,7 @@ def __init__(
hf_split: str,
hf_subset: str,
**kwargs,
):
) -> None:
super().__init__(**kwargs)
self.first_column = create_dataloader(
dataset,
@@ -53,7 +53,7 @@ def __call__(
model: Encoder,
*,
encode_kwargs: dict[str, Any],
):
) -> dict[str, float]:
embeddings1 = model.encode(
self.first_column,
task_metadata=self.task_metadata,
4 changes: 2 additions & 2 deletions mteb/_evaluators/classification_evaluator.py
@@ -111,13 +111,13 @@ def calculate_scores(
)
return scores

def __call__(
def __call__( # type: ignore[override]
self,
model: Encoder,
*,
encode_kwargs: dict[str, Any],
test_cache: np.ndarray | None = None,
) -> tuple[dict[str, float], Any]:
) -> tuple[dict[str, float], np.ndarray | None]:
Contributor: a docstring would be good here

Member Author: Added

dataloader_train, dataloader_test = self.create_dataloaders(
batch_size=encode_kwargs["batch_size"]
)
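Note on the `# type: ignore[override]` added above (and in the regression and retrieval evaluators below): once the base `Evaluator.__call__` is annotated as returning `dict[str, float]`, a subclass whose `__call__` returns something else is flagged by mypy as an incompatible override. A minimal, self-contained sketch of the pattern — class names and the placeholder score are illustrative, not mteb's code:

```python
from __future__ import annotations

from abc import ABC, abstractmethod
from typing import Any

import numpy as np


class Base(ABC):
    """Stand-in for the Evaluator base class with the new return annotation."""

    @abstractmethod
    def __call__(self, model: Any, *, encode_kwargs: dict[str, Any]) -> dict[str, float]: ...


class CachingEvaluator(Base):
    # The tuple return type is not compatible with the base class's
    # dict[str, float], so mypy reports an [override] error without the ignore.
    # The extra optional keyword-only parameter alone would be accepted.
    def __call__(  # type: ignore[override]
        self,
        model: Any,
        *,
        encode_kwargs: dict[str, Any],
        test_cache: np.ndarray | None = None,
    ) -> tuple[dict[str, float], np.ndarray | None]:
        scores = {"accuracy": 0.0}  # placeholder for the real evaluation
        return scores, test_cache
```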
4 changes: 2 additions & 2 deletions mteb/_evaluators/clustering_evaluator.py
@@ -27,7 +27,7 @@ def __init__(
hf_subset: str,
clustering_batch_size: int = 500,
**kwargs,
):
) -> None:
super().__init__(**kwargs)
self.dataset = dataset
self.clustering_batch_size = clustering_batch_size
@@ -43,7 +43,7 @@ def __call__(
*,
encode_kwargs: dict[str, Any],
v_measure_only: bool = False,
):
) -> dict[str, float]:
data_loader = create_dataloader(
self.dataset,
self.task_metadata,
6 changes: 4 additions & 2 deletions mteb/_evaluators/evaluator.py
@@ -12,12 +12,14 @@ class Evaluator(ABC):
Extend this class and implement __call__ for custom evaluators.
"""

def __init__(self, seed: int = 42, **kwargs: Any):
def __init__(self, seed: int = 42, **kwargs: Any) -> None:
self.seed = seed
self.rng_state, self.np_rng = set_seed(seed)

@abstractmethod
def __call__(self, model: Encoder, *, encode_kwargs: dict[str, Any]):
def __call__(
self, model: Encoder, *, encode_kwargs: dict[str, Any]
) -> dict[str, float]:
"""This is called during training to evaluate the model.
It returns scores.

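For contrast, a concrete evaluator that follows the newly annotated contract needs no ignore comment. A rough sketch, assuming `Evaluator` is importable from the module path shown in this diff (the class name and toy metric are hypothetical):

```python
from typing import Any

from mteb._evaluators.evaluator import Evaluator  # module path as in this diff


class ConstantScoreEvaluator(Evaluator):
    """Toy evaluator matching the annotated __call__ contract."""

    def __init__(self, seed: int = 42, **kwargs: Any) -> None:
        super().__init__(seed=seed, **kwargs)  # sets self.seed, self.rng_state, self.np_rng

    def __call__(self, model: Any, *, encode_kwargs: dict[str, Any]) -> dict[str, float]:
        # A real evaluator would encode data with `model` here; a fixed score
        # is returned only to illustrate the dict[str, float] return shape.
        return {"main_score": 1.0}
```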
2 changes: 1 addition & 1 deletion mteb/_evaluators/regression_evaluator.py
@@ -52,7 +52,7 @@ def __init__(
self.task_metadata = task_metadata
self.regressor = regressor

def __call__(
def __call__( # type: ignore[override]
self,
model: Encoder,
*,
5 changes: 2 additions & 3 deletions mteb/_evaluators/retrieval_evaluator.py
@@ -34,7 +34,7 @@ def __init__(
top_ranked: TopRankedDocumentsType | None = None,
qid: str | None = None,
**kwargs,
):
) -> None:
super().__init__(**kwargs)
self.corpus = corpus
self.queries = queries
@@ -46,11 +46,10 @@ def __init__(
self.qid = qid
self.top_k = top_k

def __call__(
def __call__( # type: ignore[override]
self,
search_model: SearchProtocol,
encode_kwargs: dict[str, Any],
**kwargs: Any,
) -> RetrievalOutputType:
search_model.index(
corpus=self.corpus,
10 changes: 7 additions & 3 deletions mteb/_evaluators/text/bitext_mining_evaluator.py
@@ -30,7 +30,7 @@ def __init__(
hf_subset: str,
pair_columns: list[tuple[str, str]] = DEFAULT_PAIR,
**kwargs,
):
) -> None:
super().__init__(**kwargs)
self.pairs = pair_columns
self.n = len(sentences)
@@ -45,11 +45,15 @@ def __init__(
self.hf_subset = hf_subset
self.task_metadata = task_metadata

def __call__(self, model: Encoder, *, encode_kwargs: dict[str, Any]):
def __call__(
self, model: Encoder, *, encode_kwargs: dict[str, Any]
) -> dict[str, float]:
scores = self.compute_metrics(model, encode_kwargs=encode_kwargs)
return scores

def compute_metrics(self, model: Encoder, encode_kwargs: dict[str, Any]):
def compute_metrics(
self, model: Encoder, encode_kwargs: dict[str, Any]
) -> dict[str, float]:
pair_elements = {p for pair in self.pairs for p in pair}
if isinstance(self.sentences, Dataset):
subsets = [
18 changes: 11 additions & 7 deletions mteb/_evaluators/text/pair_classification_evaluator.py
@@ -49,7 +49,7 @@ def __init__(
hf_split: str,
hf_subset: str,
**kwargs,
):
) -> None:
super().__init__(**kwargs)
self.sentences1 = sentences1
self.sentences2 = sentences2
@@ -67,7 +67,7 @@ def __call__(
self,
model: Encoder,
encode_kwargs: dict[str, Any],
):
) -> dict[str, float]:
scores = self.compute_metrics(model, encode_kwargs=encode_kwargs)

# Main score is the max of Average Precision (AP)
@@ -83,7 +83,7 @@ def _encode_unique_texts(
hf_split: str,
hf_subset: str,
**encode_kwargs: Any,
):
) -> np.ndarray:
index_map, all_unique_texts, all_texts_indexes = {}, [], []
for text in all_texts:
text_hash = hash(text)
@@ -110,7 +110,7 @@ def compute_metrics(
model: Encoder,
*,
encode_kwargs: dict[str, Any],
):
) -> dict[str, float]:
all_sentences = self.sentences1 + self.sentences2
len_sentences1 = len(self.sentences1)
embeddings = self._encode_unique_texts(
@@ -215,7 +215,9 @@ def _compute_metrics(
}

@staticmethod
def find_best_acc_and_threshold(scores, labels, high_score_more_similar: bool):
def find_best_acc_and_threshold(
scores: np.ndarray, labels: np.ndarray, high_score_more_similar: bool
) -> tuple[float, float]:
assert len(scores) == len(labels)
rows = list(zip(scores, labels))

@@ -242,7 +244,9 @@ def find_best_f1_and_threshold(scores, labels, high_score_more_similar: bool):
return max_acc, best_threshold

@staticmethod
def find_best_f1_and_threshold(scores, labels, high_score_more_similar: bool):
def find_best_f1_and_threshold(
scores, labels, high_score_more_similar: bool
) -> tuple[float, float, float, float]:
assert len(scores) == len(labels)

scores = np.asarray(scores)
@@ -278,7 +282,7 @@ def find_best_f1_and_threshold(scores, labels, high_score_more_similar: bool):
return best_f1, best_precision, best_recall, threshold

@staticmethod
def ap_score(scores, labels, high_score_more_similar: bool):
def ap_score(scores, labels, high_score_more_similar: bool) -> float:
return average_precision_score(
labels, scores * (1 if high_score_more_similar else -1)
)
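The threshold helpers above now advertise their tuple shapes, `(max_acc, best_threshold)` and `(best_f1, best_precision, best_recall, threshold)`. For readers unfamiliar with them, here is a minimal sketch of the accuracy-style threshold sweep they perform — the general technique, not the file's exact implementation:

```python
import numpy as np


def best_accuracy_threshold(
    scores: np.ndarray, labels: np.ndarray, high_score_more_similar: bool
) -> tuple[float, float]:
    """Sweep thresholds between consecutive sorted scores; return (best_acc, best_threshold)."""
    assert len(scores) == len(labels)
    order = np.argsort(scores)
    if high_score_more_similar:
        order = order[::-1]  # most-similar pairs first
    sorted_scores = scores[order]
    sorted_labels = labels[order]

    n = len(labels)
    total_positives = int(sorted_labels.sum())
    best_acc = 0.0
    best_threshold = float(sorted_scores[0])
    positives_seen = 0  # positives among the i+1 pairs currently predicted "similar"
    for i in range(n - 1):
        positives_seen += int(sorted_labels[i])
        # Items 0..i are predicted positive, the rest negative.
        correct = positives_seen + (n - i - 1) - (total_positives - positives_seen)
        acc = correct / n
        if acc > best_acc:
            best_acc = acc
            best_threshold = float((sorted_scores[i] + sorted_scores[i + 1]) / 2)
    return best_acc, best_threshold
```

The F1 variant runs the same sweep while tracking precision and recall at each candidate cut, which is where the four-element return tuple comes from.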
8 changes: 4 additions & 4 deletions mteb/_evaluators/text/summarization_evaluator.py
@@ -36,7 +36,7 @@ def __init__(
hf_split: str,
hf_subset: str,
**kwargs,
):
) -> None:
"""Summarization Evaluator

Args:
@@ -63,7 +63,7 @@ def __call__(
model: Encoder,
*,
encode_kwargs: dict[str, Any],
):
) -> dict[str, float]:
cosine_spearman_scores = []
cosine_pearson_scores = []
dot_spearman_scores = []
@@ -196,7 +196,7 @@ def __init__(
hf_split: str | None = None,
hf_subset: str | None = None,
**kwargs,
):
) -> None:
# human_summaries shape: (None, num_human_summaries)
# machine_summaries shape: (None, num_machine_summaries)
# gold scores shape: (None, num_machine_summaries)
@@ -220,7 +220,7 @@ def __call__(
model: Encoder,
*,
encode_kwargs: dict[str, Any],
):
) -> dict[str, float]:
cosine_spearman_scores = []
cosine_pearson_scores = []
dot_spearman_scores = []
6 changes: 4 additions & 2 deletions mteb/_evaluators/zeroshot_classification_evaluator.py
@@ -31,7 +31,7 @@ def __init__(
hf_split: str,
hf_subset: str,
**kwargs,
):
) -> None:
super().__init__(**kwargs)

self.dataset = dataset
@@ -42,7 +42,9 @@ def __init__(
self.hf_split = hf_split
self.hf_subset = hf_subset

def __call__(self, model: Encoder, *, encode_kwargs: dict[str, Any]):
def __call__(
self, model: Encoder, *, encode_kwargs: dict[str, Any]
) -> dict[str, float]:
if "image" in self.task_metadata.modalities:
dataloader = create_image_dataloader(
self.dataset,
17 changes: 9 additions & 8 deletions mteb/abstasks/AbsTask.py
@@ -23,8 +23,8 @@
SearchProtocol,
)
from mteb.set_seed import set_seed
from mteb.types import HFSubset, ScoresDict
from mteb.types.statistics import DescriptiveStatistics
from mteb.types import HFSubset, Modalities, ScoresDict
from mteb.types.statistics import DescriptiveStatistics, SplitDescriptiveStatistics

logger = logging.getLogger(__name__)

@@ -191,8 +191,9 @@ def evaluate(
@abstractmethod
def _evaluate_subset(
self,
model: MTEBModels,
model: Encoder,
data_split: Dataset,
*,
encode_kwargs: dict[str, Any],
hf_split: str,
hf_subset: str,
Expand Down Expand Up @@ -336,7 +337,7 @@ def fast_load(self, **kwargs: Any) -> None:

def calculate_descriptive_statistics(
self, overwrite_results: bool = False
) -> dict[str, DescriptiveStatistics | dict[str, DescriptiveStatistics]]:
) -> dict[str, DescriptiveStatistics]:
"""Calculates descriptive statistics from the dataset."""
from mteb.abstasks import AbsTaskAnyClassification

@@ -347,7 +348,7 @@ def calculate_descriptive_statistics(
if not self.data_loaded:
self.load_data()

descriptive_stats = {}
descriptive_stats: dict[str, DescriptiveStatistics] = {}
hf_subset_stat = "hf_subset_descriptive_stats"
eval_splits = self.metadata.eval_splits
if isinstance(self, AbsTaskAnyClassification):
@@ -387,15 +388,15 @@

def calculate_metadata_metrics(
self, overwrite_results: bool = False
) -> dict[str, DescriptiveStatistics | dict[str, DescriptiveStatistics]]:
) -> dict[str, DescriptiveStatistics]:
return self.calculate_descriptive_statistics(
overwrite_results=overwrite_results
)

@abstractmethod
def _calculate_descriptive_statistics_from_split(
self, split: str, hf_subset: str | None = None, compute_overall: bool = False
) -> DescriptiveStatistics:
) -> SplitDescriptiveStatistics:
raise NotImplementedError

@property
@@ -584,7 +585,7 @@ def eval_splits(self) -> list[str]:
return self.metadata.eval_splits

@property
def modalities(self) -> list[str]:
def modalities(self) -> list[Modalities]:
"""Returns the modalities of the task."""
return self.metadata.modalities

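Typing `modalities` as `list[Modalities]` rather than `list[str]` lets type checkers catch unsupported modality names. A small illustration, assuming `Modalities` is a `Literal` alias along these lines (the exact members are an assumption, not taken from this PR):

```python
from typing import Literal

# Hypothetical stand-in for mteb.types.Modalities; the real alias may differ.
Modalities = Literal["text", "image"]


def first_modality(modalities: list[Modalities]) -> Modalities:
    return modalities[0]


first_modality(["text", "image"])  # accepted by the type checker
first_modality(["txt"])  # runs, but mypy flags the argument: "txt" is not a Literal member
```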