From 47d6ead9f1d9700f5e6ea423367ed0deadf8e72c Mon Sep 17 00:00:00 2001
From: Antoine Chaffin
Date: Thu, 9 Jan 2025 15:29:46 +0000
Subject: [PATCH] fix circular import and remove duplicate code

---
 pylate/evaluation/nano_beir_evaluator.py | 274 +----------------------
 1 file changed, 3 insertions(+), 271 deletions(-)

diff --git a/pylate/evaluation/nano_beir_evaluator.py b/pylate/evaluation/nano_beir_evaluator.py
index 714257d..ff4b588 100644
--- a/pylate/evaluation/nano_beir_evaluator.py
+++ b/pylate/evaluation/nano_beir_evaluator.py
@@ -1,29 +1,13 @@
 from __future__ import annotations

 import logging
-import os
-from typing import TYPE_CHECKING, Callable, Literal
+from typing import Literal

-import numpy as np
-from sentence_transformers import SentenceTransformer
 from sentence_transformers.evaluation.SentenceEvaluator import \
     SentenceEvaluator
-from sentence_transformers.similarity_functions import SimilarityFunction
 from sentence_transformers.util import is_datasets_available
-from torch import Tensor
-from tqdm import tqdm

-from ..evaluation import PyLateInformationRetrievalEvaluator
-from ..scores import colbert_scores
-
-# from sentence_transformers.evaluation.InformationRetrievalEvaluator import \
-#     InformationRetrievalEvaluator
-
-
-
-
-if TYPE_CHECKING:
-    from sentence_transformers.SentenceTransformer import SentenceTransformer
+from .pylate_information_retrieval_evaluator import PyLateInformationRetrievalEvaluator

 logger = logging.getLogger(__name__)

@@ -176,226 +160,6 @@ class NanoBEIREvaluator(SentenceEvaluator):
         # => 0.8084508771660436
     """

-    def __init__(
-        self,
-        dataset_names: list[DatasetNameType] | None = None,
-        mrr_at_k: list[int] = [10],
-        ndcg_at_k: list[int] = [10],
-        accuracy_at_k: list[int] = [1, 3, 5, 10],
-        precision_recall_at_k: list[int] = [1, 3, 5, 10],
-        map_at_k: list[int] = [100],
-        show_progress_bar: bool = False,
-        batch_size: int = 32,
-        write_csv: bool = True,
-        truncate_dim: int | None = None,
-        score_functions: dict[str, Callable[[Tensor, Tensor], Tensor]] = None,
-        main_score_function: str | SimilarityFunction | None = None,
-        aggregate_fn: Callable[[list[float]], float] = np.mean,
-        aggregate_key: str = "mean",
-        query_prompts: str | dict[str, str] | None = None,
-        corpus_prompts: str | dict[str, str] | None = None,
-    ):
-        """
-        Initializes the NanoBEIREvaluator.
-
-        Args:
-            dataset_names (List[str]): The names of the datasets to evaluate on.
-            mrr_at_k (List[int]): A list of integers representing the values of k for MRR calculation. Defaults to [10].
-            ndcg_at_k (List[int]): A list of integers representing the values of k for NDCG calculation. Defaults to [10].
-            accuracy_at_k (List[int]): A list of integers representing the values of k for accuracy calculation. Defaults to [1, 3, 5, 10].
-            precision_recall_at_k (List[int]): A list of integers representing the values of k for precision and recall calculation. Defaults to [1, 3, 5, 10].
-            map_at_k (List[int]): A list of integers representing the values of k for MAP calculation. Defaults to [100].
-            show_progress_bar (bool): Whether to show a progress bar during evaluation. Defaults to False.
-            batch_size (int): The batch size for evaluation. Defaults to 32.
-            write_csv (bool): Whether to write the evaluation results to a CSV file. Defaults to True.
-            truncate_dim (int, optional): The dimension to truncate the embeddings to. Defaults to None.
-            score_functions (Dict[str, Callable[[Tensor, Tensor], Tensor]]): A dictionary mapping score function names to score functions. Defaults to {SimilarityFunction.COSINE.value: cos_sim, SimilarityFunction.DOT_PRODUCT.value: dot_score}.
-            main_score_function (Union[str, SimilarityFunction], optional): The main score function to use for evaluation. Defaults to None.
-            aggregate_fn (Callable[[list[float]], float]): The function to aggregate the scores. Defaults to np.mean.
-            aggregate_key (str): The key to use for the aggregated score. Defaults to "mean".
-            query_prompts (str | dict[str, str], optional): The prompts to add to the queries. If a string, will add the same prompt to all queries. If a dict, expects that all datasets in dataset_names are keys.
-            corpus_prompts (str | dict[str, str], optional): The prompts to add to the corpus. If a string, will add the same prompt to all corpus. If a dict, expects that all datasets in dataset_names are keys.
-        """
-        super().__init__()
-        if dataset_names is None:
-            dataset_names = list(dataset_name_to_id.keys())
-        self.dataset_names = dataset_names
-        self.aggregate_fn = aggregate_fn
-        self.aggregate_key = aggregate_key
-        self.write_csv = write_csv
-        self.query_prompts = query_prompts
-        self.corpus_prompts = corpus_prompts
-        self.show_progress_bar = show_progress_bar
-        self.write_csv = write_csv
-        self.score_functions = score_functions
-        self.score_function_names = sorted(list(self.score_functions.keys())) if score_functions else []
-        self.main_score_function = main_score_function
-        self.truncate_dim = truncate_dim
-        self.name = f"NanoBEIR_{aggregate_key}"
-        if self.truncate_dim:
-            self.name += f"_{self.truncate_dim}"
-
-        self.mrr_at_k = mrr_at_k
-        self.ndcg_at_k = ndcg_at_k
-        self.accuracy_at_k = accuracy_at_k
-        self.precision_recall_at_k = precision_recall_at_k
-        self.map_at_k = map_at_k
-
-        self._validate_dataset_names()
-        self._validate_prompts()
-
-        ir_evaluator_kwargs = {
-            "mrr_at_k": mrr_at_k,
-            "ndcg_at_k": ndcg_at_k,
-            "accuracy_at_k": accuracy_at_k,
-            "precision_recall_at_k": precision_recall_at_k,
-            "map_at_k": map_at_k,
-            "show_progress_bar": show_progress_bar,
-            "batch_size": batch_size,
-            "write_csv": write_csv,
-            "truncate_dim": truncate_dim,
-            "score_functions": score_functions,
-            "main_score_function": main_score_function,
-        }
-
-        self.evaluators = [self._load_dataset(name, **ir_evaluator_kwargs) for name in self.dataset_names]
-
-        self.csv_file: str = f"NanoBEIR_evaluation_{aggregate_key}_results.csv"
-        self.csv_headers = ["epoch", "steps"]
-
-        self._append_csv_headers(self.score_function_names)
-
-    def _append_csv_headers(self, score_function_names):
-        for score_name in score_function_names:
-            for k in self.accuracy_at_k:
-                self.csv_headers.append(f"{score_name}-Accuracy@{k}")
-
-            for k in self.precision_recall_at_k:
-                self.csv_headers.append(f"{score_name}-Precision@{k}")
-                self.csv_headers.append(f"{score_name}-Recall@{k}")
-
-            for k in self.mrr_at_k:
-                self.csv_headers.append(f"{score_name}-MRR@{k}")
-
-            for k in self.ndcg_at_k:
-                self.csv_headers.append(f"{score_name}-NDCG@{k}")
-
-            for k in self.map_at_k:
-                self.csv_headers.append(f"{score_name}-MAP@{k}")
-
-    def __call__(
-        self, model: SentenceTransformer, output_path: str = None, epoch: int = -1, steps: int = -1, *args, **kwargs
-    ) -> dict[str, float]:
-        per_metric_results = {}
-        per_dataset_results = {}
-        if epoch != -1:
-            if steps == -1:
-                out_txt = f" after epoch {epoch}"
-            else:
-                out_txt = f" in epoch {epoch} after {steps} steps"
-        else:
-            out_txt = ""
-        if self.truncate_dim is not None:
-            out_txt += f" (truncated to {self.truncate_dim})"
-        logger.info(f"NanoBEIR Evaluation of the model on {self.dataset_names} dataset{out_txt}:")
-        if self.score_functions is None:
-            self.score_functions = {model.similarity_fn_name: model.similarity}
-            self.score_function_names = [model.similarity_fn_name]
-            self._append_csv_headers(self.score_function_names)
-
-        for evaluator in tqdm(self.evaluators, desc="Evaluating datasets", disable=not self.show_progress_bar):
-            logger.info(f"Evaluating {evaluator.name}")
-            evaluation = evaluator(model, output_path, epoch, steps)
-            for k in evaluation:
-                if self.truncate_dim:
-                    dataset, _, metric = k.split("_", maxsplit=2)
-                else:
-                    dataset, metric = k.split("_", maxsplit=1)
-                if metric not in per_metric_results:
-                    per_metric_results[metric] = []
-                per_dataset_results[dataset + "_" + metric] = evaluation[k]
-                per_metric_results[metric].append(evaluation[k])
-
-        agg_results = {}
-        for metric in per_metric_results:
-            agg_results[metric] = self.aggregate_fn(per_metric_results[metric])
-
-        if output_path is not None and self.write_csv:
-            csv_path = os.path.join(output_path, self.csv_file)
-            if not os.path.isfile(csv_path):
-                fOut = open(csv_path, mode="w", encoding="utf-8")
-                fOut.write(",".join(self.csv_headers))
-                fOut.write("\n")
-
-            else:
-                fOut = open(csv_path, mode="a", encoding="utf-8")
-
-            output_data = [epoch, steps]
-            for name in self.score_function_names:
-                for k in self.accuracy_at_k:
-                    output_data.append(agg_results[f"{name}_accuracy@{k}"])
-
-                for k in self.precision_recall_at_k:
-                    output_data.append(agg_results[f"{name}_precision@{k}"])
-                    output_data.append(agg_results[f"{name}_recall@{k}"])
-
-                for k in self.mrr_at_k:
-                    output_data.append(agg_results[f"{name}_mrr@{k}"])
-
-                for k in self.ndcg_at_k:
-                    output_data.append(agg_results[f"{name}_ndcg@{k}"])
-
-                for k in self.map_at_k:
-                    output_data.append(agg_results[f"{name}_map@{k}"])
-
-            fOut.write(",".join(map(str, output_data)))
-            fOut.write("\n")
-            fOut.close()
-
-        if not self.primary_metric:
-            if self.main_score_function is None:
-                score_function = max(
-                    [(name, agg_results[f"{name}_ndcg@{max(self.ndcg_at_k)}"]) for name in self.score_function_names],
-                    key=lambda x: x[1],
-                )[0]
-                self.primary_metric = f"{score_function}_ndcg@{max(self.ndcg_at_k)}"
-            else:
-                self.primary_metric = f"{self.main_score_function.value}_ndcg@{max(self.ndcg_at_k)}"
-
-        avg_queries = np.mean([len(evaluator.queries) for evaluator in self.evaluators])
-        avg_corpus = np.mean([len(evaluator.corpus) for evaluator in self.evaluators])
-        logger.info(f"Average Queries: {avg_queries}")
-        logger.info(f"Average Corpus: {avg_corpus}\n")
-
-        for name in self.score_function_names:
-            logger.info(f"Aggregated for Score Function: {name}")
-            for k in self.accuracy_at_k:
-                logger.info("Accuracy@{}: {:.2f}%".format(k, agg_results[f"{name}_accuracy@{k}"] * 100))
-
-            for k in self.precision_recall_at_k:
-                logger.info("Precision@{}: {:.2f}%".format(k, agg_results[f"{name}_precision@{k}"] * 100))
-                logger.info("Recall@{}: {:.2f}%".format(k, agg_results[f"{name}_recall@{k}"] * 100))
-
-            for k in self.mrr_at_k:
-                logger.info("MRR@{}: {:.4f}".format(k, agg_results[f"{name}_mrr@{k}"]))
-
-            for k in self.ndcg_at_k:
-                logger.info("NDCG@{}: {:.4f}".format(k, agg_results[f"{name}_ndcg@{k}"]))
-
-        # TODO: Ensure this primary_metric works as expected, also with bolding the right thing in the model card
-        agg_results = self.prefix_name_to_metrics(agg_results, self.name)
-        self.store_metrics_in_model_card_data(model, agg_results)
-
-        per_dataset_results.update(agg_results)
-
-        return per_dataset_results
-
-    def _get_human_readable_name(self, dataset_name: DatasetNameType) -> str:
-        human_readable_name = f"Nano{dataset_name_to_human_readable[dataset_name.lower()]}"
-        if self.truncate_dim is not None:
-            human_readable_name += f"_{self.truncate_dim}"
-        return human_readable_name
-
     def _load_dataset(self, dataset_name: DatasetNameType, **ir_evaluator_kwargs) -> PyLateInformationRetrievalEvaluator:
         if not is_datasets_available():
             raise ValueError("datasets is not available. Please install it to use the NanoBEIREvaluator.")
@@ -424,36 +188,4 @@ def _load_dataset(dataset_name: DatasetNameType, **ir_evaluator_kwargs) ->
             relevant_docs=qrels_dict,
             name=human_readable_name,
             **ir_evaluator_kwargs,
-        )
-
-    def _validate_dataset_names(self):
-        if len(self.dataset_names) == 0:
-            raise ValueError("dataset_names cannot be empty. Use None to evaluate on all datasets.")
-        if missing_datasets := [
-            dataset_name for dataset_name in self.dataset_names if dataset_name.lower() not in dataset_name_to_id
-        ]:
-            raise ValueError(
-                f"Dataset(s) {missing_datasets} not found in the NanoBEIR collection. "
-                f"Valid dataset names are: {list(dataset_name_to_id.keys())}"
-            )
-
-    def _validate_prompts(self):
-        error_msg = ""
-        if self.query_prompts is not None:
-            if isinstance(self.query_prompts, str):
-                self.query_prompts = {dataset_name: self.query_prompts for dataset_name in self.dataset_names}
-            elif missing_query_prompts := [
-                dataset_name for dataset_name in self.dataset_names if dataset_name not in self.query_prompts
-            ]:
-                error_msg += f"The following datasets are missing query prompts: {missing_query_prompts}\n"
-
-        if self.corpus_prompts is not None:
-            if isinstance(self.corpus_prompts, str):
-                self.corpus_prompts = {dataset_name: self.corpus_prompts for dataset_name in self.dataset_names}
-            elif missing_corpus_prompts := [
-                dataset_name for dataset_name in self.dataset_names if dataset_name not in self.corpus_prompts
-            ]:
-                error_msg += f"The following datasets are missing corpus prompts: {missing_corpus_prompts}\n"
-
-        if error_msg:
-            raise ValueError(error_msg.strip())
\ No newline at end of file
+        )
\ No newline at end of file
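
Note on the fix: the removed "from ..evaluation import PyLateInformationRetrievalEvaluator" goes through the pylate.evaluation package __init__, which (presumably, since it re-exports this evaluator module) is still being initialised when this file is imported, hence the circular import; the added relative import targets the sibling module directly. Below is a minimal usage sketch of the evaluator after this patch; the checkpoint and dataset names are illustrative placeholders, not part of the patch.

    from pylate import evaluation, models

    # Any ColBERT-style checkpoint can be used here; this name is a placeholder.
    model = models.ColBERT(model_name_or_path="lightonai/colbertv2.0")

    # Evaluate on a subset of the NanoBEIR collection and aggregate the metrics.
    evaluator = evaluation.NanoBEIREvaluator(dataset_names=["SciFact", "NFCorpus"])
    results = evaluator(model)
    print(results)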