diff --git a/docs/preset_configs.py b/docs/preset_configs.py index d8ac5997..9bd84f59 100644 --- a/docs/preset_configs.py +++ b/docs/preset_configs.py @@ -60,10 +60,7 @@ def _get_level_prefix(level: int) -> str: with mkdocs_gen_files.open(Path("preset_configs") / "index.md", "w") as fd: fd.write("# Preset Configs\n") fd.write( - "You can check the config using the following command:\n" - "```bash\n" - "flexeval_presets \n" - "```\n", + "You can check the config using the following command:\n```bash\nflexeval_presets \n```\n", ) fd.write(_nested_dict_to_markdown(all_pages)) diff --git a/examples/format_following/src/ifeval/combination.py b/examples/format_following/src/ifeval/combination.py index 6c71b24c..3d12a7ec 100644 --- a/examples/format_following/src/ifeval/combination.py +++ b/examples/format_following/src/ifeval/combination.py @@ -31,9 +31,7 @@ def __init__(self, prompt_to_repeat: str) -> None: self.prompt_to_repeat = prompt_to_repeat def check(self, response: str) -> bool: - if response.strip().lower().startswith(self.prompt_to_repeat.strip().lower()): - return True - return False + return response.strip().lower().startswith(self.prompt_to_repeat.strip().lower()) class TwoResponses(ResponseConstraint): diff --git a/examples/format_following/src/ifeval/length_constraints.py b/examples/format_following/src/ifeval/length_constraints.py index cb578db5..ca1da0d3 100644 --- a/examples/format_following/src/ifeval/length_constraints.py +++ b/examples/format_following/src/ifeval/length_constraints.py @@ -130,7 +130,7 @@ def check(self, response: str) -> bool: return False -@functools.lru_cache(maxsize=None) +@functools.lru_cache def _get_sentence_tokenizer() -> nltk.tokenize.punkt.PunktSentenceTokenizer: nltk.download("punkt_tab", quiet=True) return nltk.data.load("nltk:tokenizers/punkt/english.pickle") diff --git a/examples/format_following/src/metric/instruction_following_eval.py b/examples/format_following/src/metric/instruction_following_eval.py index 737fec2e..c2b3a161 100644 --- a/examples/format_following/src/metric/instruction_following_eval.py +++ b/examples/format_following/src/metric/instruction_following_eval.py @@ -36,11 +36,11 @@ def evaluate( ) -> MetricResult: is_check_passed_list: list[list[bool]] = [] is_check_passed_per_checker = defaultdict(list) - for lm_output, extra_info in zip(lm_outputs, extra_info_list): + for lm_output, extra_info in zip(lm_outputs, extra_info_list, strict=True): constraints = [self._instantiate_checker_from_params(params) for params in extra_info["constraints"]] is_check_passed = [checker.check(lm_output) for checker in constraints] is_check_passed_list.append(is_check_passed) - for checker, is_passed in zip(constraints, is_check_passed): + for checker, is_passed in zip(constraints, is_check_passed, strict=True): is_check_passed_per_checker[checker.__class__.__name__].append(is_passed) num_items = len(is_check_passed_list) diff --git a/examples/format_following/src/multilingual_ifeval/ja.py b/examples/format_following/src/multilingual_ifeval/ja.py index 43ab3e0f..c36e84a1 100644 --- a/examples/format_following/src/multilingual_ifeval/ja.py +++ b/examples/format_following/src/multilingual_ifeval/ja.py @@ -739,7 +739,7 @@ class KatakanaOnly(ResponseConstraint): def check(self, response: str) -> bool: def is_katakana(char: str) -> bool: - return "ァ" <= char <= "ン" or char == "ー" or char == "・" or "ヲ" <= char <= "゚" + return "ァ" <= char <= "ン" or char in {"ー", "・"} or "ヲ" <= char <= "゚" def is_ignorable(char: str) -> bool: return not 
unicodedata.category(char).startswith("L") diff --git a/flexeval/core/chat_dataset/base.py b/flexeval/core/chat_dataset/base.py index 747424a0..ec26eaba 100644 --- a/flexeval/core/chat_dataset/base.py +++ b/flexeval/core/chat_dataset/base.py @@ -2,8 +2,9 @@ import warnings from abc import ABC, abstractmethod +from collections.abc import Sequence from dataclasses import dataclass, field -from typing import Any, Sequence +from typing import Any @dataclass diff --git a/flexeval/core/chat_dataset/openai_messages.py b/flexeval/core/chat_dataset/openai_messages.py index aa743d01..22b41bf4 100644 --- a/flexeval/core/chat_dataset/openai_messages.py +++ b/flexeval/core/chat_dataset/openai_messages.py @@ -1,7 +1,8 @@ from __future__ import annotations import json -from typing import Any, Iterator +from collections.abc import Iterator +from typing import Any from .base import ChatDataset, ChatInstance diff --git a/flexeval/core/chat_dataset/template_based.py b/flexeval/core/chat_dataset/template_based.py index 27007c1a..55b6a3ee 100644 --- a/flexeval/core/chat_dataset/template_based.py +++ b/flexeval/core/chat_dataset/template_based.py @@ -8,7 +8,7 @@ import datasets from jinja2 import Template -from smart_open import open +from smart_open import open # noqa: A004 from flexeval.core.utils.jinja2_utils import JINJA2_ENV @@ -120,8 +120,7 @@ def __getitem__(self, i: int) -> ChatInstance: reference_list_string = self.reference_list_template.render(**item) if not (reference_list_string.startswith("[") and reference_list_string.endswith("]")): msg = ( - f"The reference_list_template should render a list of strings " - f"but we got `{reference_list_string}`." + f"The reference_list_template should render a list of strings but we got `{reference_list_string}`." ) raise ValueError(msg) reference_list.extend([str(ref) for ref in literal_eval(reference_list_string)]) diff --git a/flexeval/core/evaluate_from_data.py b/flexeval/core/evaluate_from_data.py index d4d9a54c..5e4cb46a 100644 --- a/flexeval/core/evaluate_from_data.py +++ b/flexeval/core/evaluate_from_data.py @@ -1,6 +1,7 @@ from __future__ import annotations -from typing import Any, Iterable +from collections.abc import Iterable +from typing import Any from loguru import logger diff --git a/flexeval/core/evaluate_generation.py b/flexeval/core/evaluate_generation.py index c7fc58cb..b453f3c0 100644 --- a/flexeval/core/evaluate_generation.py +++ b/flexeval/core/evaluate_generation.py @@ -1,6 +1,7 @@ from __future__ import annotations -from typing import Any, Sequence +from collections.abc import Sequence +from typing import Any from loguru import logger from tqdm import tqdm diff --git a/flexeval/core/evaluate_multiple_choice.py b/flexeval/core/evaluate_multiple_choice.py index 7d81730d..5b32c90e 100644 --- a/flexeval/core/evaluate_multiple_choice.py +++ b/flexeval/core/evaluate_multiple_choice.py @@ -59,7 +59,7 @@ def evaluate_multiple_choice( if batch_id == 0: logger.info("Example of the model inputs and outputs:") logger.info(f"prefix: {batch_prefixes[0]}") - logger.info(f"choices: {batch_choices[:len(eval_instance.choices)]}") + logger.info(f"choices: {batch_choices[: len(eval_instance.choices)]}") batch_log_probs = language_model.compute_log_probs( text_list=batch_choices, diff --git a/flexeval/core/evaluate_perplexity.py b/flexeval/core/evaluate_perplexity.py index c3c6afa8..9fe868fd 100644 --- a/flexeval/core/evaluate_perplexity.py +++ b/flexeval/core/evaluate_perplexity.py @@ -2,7 +2,7 @@ import math from collections import defaultdict -from typing 
import Sequence +from collections.abc import Sequence from loguru import logger from tqdm import tqdm diff --git a/flexeval/core/evaluate_reward_model.py b/flexeval/core/evaluate_reward_model.py index 002e3261..bc51f8f3 100644 --- a/flexeval/core/evaluate_reward_model.py +++ b/flexeval/core/evaluate_reward_model.py @@ -1,7 +1,8 @@ from __future__ import annotations from collections import defaultdict -from typing import Any, Sequence +from collections.abc import Sequence +from typing import Any from loguru import logger from tqdm import tqdm diff --git a/flexeval/core/few_shot_generator/base.py b/flexeval/core/few_shot_generator/base.py index 9895dc4d..228d2bd8 100644 --- a/flexeval/core/few_shot_generator/base.py +++ b/flexeval/core/few_shot_generator/base.py @@ -1,14 +1,14 @@ from __future__ import annotations from abc import ABC, abstractmethod -from typing import Any, Union +from typing import Any from flexeval.core.chat_dataset import ChatDataset, ChatInstance from flexeval.core.generation_dataset import GenerationDataset, GenerationInstance from flexeval.core.multiple_choice_dataset import MultipleChoiceDataset, MultipleChoiceInstance -Dataset = Union[GenerationDataset, MultipleChoiceDataset, ChatDataset] -Instance = Union[GenerationInstance, MultipleChoiceInstance, ChatInstance] +Dataset = GenerationDataset | MultipleChoiceDataset | ChatDataset +Instance = GenerationInstance | MultipleChoiceInstance | ChatInstance class FewShotGenerator(ABC): diff --git a/flexeval/core/generation_dataset/base.py b/flexeval/core/generation_dataset/base.py index f63ed4cc..949f8743 100644 --- a/flexeval/core/generation_dataset/base.py +++ b/flexeval/core/generation_dataset/base.py @@ -1,8 +1,9 @@ from __future__ import annotations from abc import ABC, abstractmethod +from collections.abc import Sequence from dataclasses import dataclass, field -from typing import Any, Sequence +from typing import Any @dataclass diff --git a/flexeval/core/generation_dataset/template_based.py b/flexeval/core/generation_dataset/template_based.py index a447bf10..63a1ba88 100644 --- a/flexeval/core/generation_dataset/template_based.py +++ b/flexeval/core/generation_dataset/template_based.py @@ -6,7 +6,7 @@ import datasets from jinja2 import Template -from smart_open import open +from smart_open import open # noqa: A004 from flexeval.core.utils.jinja2_utils import JINJA2_ENV @@ -82,8 +82,7 @@ def __getitem__(self, i: int) -> GenerationInstance: reference_list_string = self.reference_list_template.render(**item) if not (reference_list_string.startswith("[") and reference_list_string.endswith("]")): msg = ( - f"The reference_list_template should render a list of strings " - f"but we got `{reference_list_string}`." + f"The reference_list_template should render a list of strings but we got `{reference_list_string}`." 
) raise ValueError(msg) reference_list.extend([str(ref) for ref in literal_eval(reference_list_string)]) diff --git a/flexeval/core/language_model/hf_lm.py b/flexeval/core/language_model/hf_lm.py index ae8e09fb..a59ea2e1 100644 --- a/flexeval/core/language_model/hf_lm.py +++ b/flexeval/core/language_model/hf_lm.py @@ -4,7 +4,8 @@ import copy import gc import json -from typing import Any, Callable, Literal, TypeVar +from collections.abc import Callable +from typing import Any, Literal, TypeVar import torch import torch.nn.functional as F # noqa: N812 diff --git a/flexeval/core/language_model/openai_api.py b/flexeval/core/language_model/openai_api.py index cbf770ec..b7b8ad51 100644 --- a/flexeval/core/language_model/openai_api.py +++ b/flexeval/core/language_model/openai_api.py @@ -3,8 +3,9 @@ import itertools import os import time +from collections.abc import Callable from concurrent.futures import ThreadPoolExecutor, as_completed -from typing import Any, Callable, TypeVar +from typing import Any, TypeVar import openai import tiktoken diff --git a/flexeval/core/language_model/openai_batch_api.py b/flexeval/core/language_model/openai_batch_api.py index 5de202f5..0ca208b0 100644 --- a/flexeval/core/language_model/openai_batch_api.py +++ b/flexeval/core/language_model/openai_batch_api.py @@ -84,7 +84,7 @@ def __init__( # convert the flexeval-specific argument name to the OpenAI-specific name if "max_new_tokens" in self.default_gen_kwargs: self.default_gen_kwargs["max_completion_tokens"] = self.default_gen_kwargs.pop("max_new_tokens") - self.temp_jsonl_file = tempfile.NamedTemporaryFile(delete=False, suffix=".jsonl") + self.temp_jsonl_file = tempfile.NamedTemporaryFile(delete=False, suffix=".jsonl") # noqa: SIM115 self.polling_interval_seconds = polling_interval_seconds self.developer_message = developer_message @@ -146,7 +146,7 @@ async def _post_batch_requests( self.create_batch_file(custom_id_2_input, **gen_kwargs) # Update batch file - with open(self.temp_jsonl_file.name, "rb") as batch_file: # noqa: ASYNC101 + with open(self.temp_jsonl_file.name, "rb") as batch_file: # noqa: ASYNC230 batch_input_file = await self._client.files.create(file=batch_file, purpose="batch") # Run Job @@ -190,7 +190,7 @@ def _execute_batch_requests( # noqa: C901 for messages, tools in zip(messages_list, tools_list) } # The response will be an empty string if the API produces an error. 
- custom_id_2_response: dict[str, str | list[dict[str, Any]]] = {custom_id: "" for custom_id in custom_id_2_input} + custom_id_2_response: dict[str, str | list[dict[str, Any]]] = dict.fromkeys(custom_id_2_input, "") exec_cnt = 1 while len(custom_id_2_input) > 0: diff --git a/flexeval/core/language_model/vllm_model.py b/flexeval/core/language_model/vllm_model.py index aa9cb245..7dfadfec 100644 --- a/flexeval/core/language_model/vllm_model.py +++ b/flexeval/core/language_model/vllm_model.py @@ -1,7 +1,8 @@ from __future__ import annotations import time -from typing import TYPE_CHECKING, Any, Callable, Literal +from collections.abc import Callable +from typing import TYPE_CHECKING, Any, Literal import torch from loguru import logger diff --git a/flexeval/core/language_model/vllm_serve_lm.py b/flexeval/core/language_model/vllm_serve_lm.py index 524bd914..7bbd09c7 100644 --- a/flexeval/core/language_model/vllm_serve_lm.py +++ b/flexeval/core/language_model/vllm_serve_lm.py @@ -6,7 +6,8 @@ import subprocess import threading import time -from typing import IO, Any, Callable +from collections.abc import Callable +from typing import IO, Any import requests import torch diff --git a/flexeval/core/metric/llm_label.py b/flexeval/core/metric/llm_label.py index 030a993e..91028320 100644 --- a/flexeval/core/metric/llm_label.py +++ b/flexeval/core/metric/llm_label.py @@ -79,7 +79,7 @@ def summarize_evaluator_labels( category2mean_score[category] = score category2dist[category] = calc_label_dist(valid_labels, label_names) - for category in category2mean_score: + for category in category2mean_score: # noqa: PLC0206 summary[f"{score_key}/{category}"] = category2mean_score[category] summary[f"{dist_key}/{category}"] = category2dist[category] diff --git a/flexeval/core/metric/perspective_api.py b/flexeval/core/metric/perspective_api.py index ddc08991..73b7ad4f 100644 --- a/flexeval/core/metric/perspective_api.py +++ b/flexeval/core/metric/perspective_api.py @@ -2,7 +2,8 @@ import os import time -from typing import Any, Callable +from collections.abc import Callable +from typing import Any import numpy as np from googleapiclient import discovery @@ -83,7 +84,7 @@ def evaluate( instance_details = [] for lm_output in lm_outputs: if lm_output == "": - instance_details.append({att: 0.0 for att in self.attributes}) + instance_details.append(dict.fromkeys(self.attributes, 0.0)) continue analyze_request = { "comment": {"text": lm_output}, diff --git a/flexeval/core/metric/substring_match.py b/flexeval/core/metric/substring_match.py index b3f96252..01fbaf92 100644 --- a/flexeval/core/metric/substring_match.py +++ b/flexeval/core/metric/substring_match.py @@ -59,7 +59,7 @@ def evaluate( ] score = 0.0 - if len(match_list): + if match_list: score = sum(match_list) / len(match_list) summary = {f"substring_match-{self.mode}": score} diff --git a/flexeval/core/multiple_choice_dataset/base.py b/flexeval/core/multiple_choice_dataset/base.py index 611cd6f1..5182f41e 100644 --- a/flexeval/core/multiple_choice_dataset/base.py +++ b/flexeval/core/multiple_choice_dataset/base.py @@ -1,8 +1,8 @@ from __future__ import annotations from abc import ABC, abstractmethod +from collections.abc import Sequence from dataclasses import dataclass -from typing import Sequence @dataclass diff --git a/flexeval/core/multiple_choice_dataset/template_based.py b/flexeval/core/multiple_choice_dataset/template_based.py index e5ec1a54..66b623bb 100644 --- a/flexeval/core/multiple_choice_dataset/template_based.py +++ 
b/flexeval/core/multiple_choice_dataset/template_based.py @@ -78,7 +78,7 @@ def __getitem__(self, i: int) -> MultipleChoiceInstance: answer_index = int(self.answer_index_template.render(**item)) if not (answer_index >= 0 and answer_index < len(choices)): - msg = f"at least {answer_index+1} choices required, but got {choices}" + msg = f"at least {answer_index + 1} choices required, but got {choices}" raise ValueError(msg) return MultipleChoiceInstance( diff --git a/flexeval/core/pairwise_comparison/match_maker/all_combinations.py b/flexeval/core/pairwise_comparison/match_maker/all_combinations.py index f177080a..8c23d92a 100644 --- a/flexeval/core/pairwise_comparison/match_maker/all_combinations.py +++ b/flexeval/core/pairwise_comparison/match_maker/all_combinations.py @@ -1,7 +1,7 @@ from __future__ import annotations import itertools -from typing import Iterable +from collections.abc import Iterable from flexeval.core.pairwise_comparison.match import Match diff --git a/flexeval/core/pairwise_comparison/match_maker/base.py b/flexeval/core/pairwise_comparison/match_maker/base.py index 2f381961..57bb656c 100644 --- a/flexeval/core/pairwise_comparison/match_maker/base.py +++ b/flexeval/core/pairwise_comparison/match_maker/base.py @@ -1,7 +1,8 @@ from __future__ import annotations from abc import ABC, abstractmethod -from typing import Iterable, TypeVar +from collections.abc import Iterable +from typing import TypeVar from flexeval.core.pairwise_comparison.match import Match diff --git a/flexeval/core/pairwise_comparison/match_maker/random_combinations.py b/flexeval/core/pairwise_comparison/match_maker/random_combinations.py index 7ddacd17..ce781814 100644 --- a/flexeval/core/pairwise_comparison/match_maker/random_combinations.py +++ b/flexeval/core/pairwise_comparison/match_maker/random_combinations.py @@ -2,7 +2,7 @@ import itertools import random -from typing import Iterable +from collections.abc import Iterable from flexeval.core.pairwise_comparison.match import Match @@ -25,7 +25,7 @@ def generate_matches( cached_matches = cached_matches or [] cache_dict = {match.get_key_for_cache(): match for match in cached_matches} - model_match_counter: dict[str, int] = {name: 0 for name in model_names} + model_match_counter: dict[str, int] = dict.fromkeys(model_names, 0) possible_new_matches: list[Match] = [] matches: list[Match] = [] for m1, m2 in all_permutations: diff --git a/flexeval/core/pairwise_comparison/scorer/win_rate.py b/flexeval/core/pairwise_comparison/scorer/win_rate.py index f46903f3..6599adcd 100644 --- a/flexeval/core/pairwise_comparison/scorer/win_rate.py +++ b/flexeval/core/pairwise_comparison/scorer/win_rate.py @@ -30,7 +30,7 @@ def compute_scores( win_count_dict[model2] += 0.5 win_rate_dict = {} - for model in match_count_dict: + for model in match_count_dict: # noqa: PLC0206 win_rate_dict[model] = 100 * win_count_dict.get(model, 0.0) / match_count_dict[model] return dict(sorted(win_rate_dict.items(), key=lambda x: -x[1])) diff --git a/flexeval/core/reward_bench_dataset/base.py b/flexeval/core/reward_bench_dataset/base.py index ef77de44..894318a0 100644 --- a/flexeval/core/reward_bench_dataset/base.py +++ b/flexeval/core/reward_bench_dataset/base.py @@ -1,8 +1,9 @@ from __future__ import annotations from abc import ABC, abstractmethod +from collections.abc import Sequence from dataclasses import dataclass, field -from typing import Any, Sequence +from typing import Any @dataclass diff --git a/flexeval/core/reward_bench_dataset/template_based.py 
b/flexeval/core/reward_bench_dataset/template_based.py index 6ae0e6f6..3d505479 100644 --- a/flexeval/core/reward_bench_dataset/template_based.py +++ b/flexeval/core/reward_bench_dataset/template_based.py @@ -5,7 +5,7 @@ import datasets from jinja2 import Template -from smart_open import open +from smart_open import open # noqa: A004 from flexeval.core.utils.jinja2_utils import JINJA2_ENV diff --git a/flexeval/core/reward_model/pairwise_judge_reward_model.py b/flexeval/core/reward_model/pairwise_judge_reward_model.py index 3a10cb55..fca13aba 100644 --- a/flexeval/core/reward_model/pairwise_judge_reward_model.py +++ b/flexeval/core/reward_model/pairwise_judge_reward_model.py @@ -35,9 +35,7 @@ def evaluate_model_output(model_output: str, gold_label: PairwiseChoice) -> bool return False # If only gold label is in model output, then output is **correct** - if gold_label.value in model_output: - return True - return False + return gold_label.value in model_output def aggregate_judge_results( diff --git a/flexeval/core/text_dataset/base.py b/flexeval/core/text_dataset/base.py index 21d9b8d3..2d0823fe 100644 --- a/flexeval/core/text_dataset/base.py +++ b/flexeval/core/text_dataset/base.py @@ -1,8 +1,8 @@ from __future__ import annotations from abc import ABC, abstractmethod +from collections.abc import Sequence from dataclasses import dataclass -from typing import Sequence @dataclass diff --git a/flexeval/core/text_dataset/jsonl.py b/flexeval/core/text_dataset/jsonl.py index f79fc155..ea47a747 100644 --- a/flexeval/core/text_dataset/jsonl.py +++ b/flexeval/core/text_dataset/jsonl.py @@ -3,7 +3,7 @@ import json from os import PathLike -from smart_open import open +from smart_open import open # noqa: A004 from .base import TextDataset, TextInstance diff --git a/flexeval/core/utils/data_util.py b/flexeval/core/utils/data_util.py index 7f8fcac7..0eb5a2e3 100644 --- a/flexeval/core/utils/data_util.py +++ b/flexeval/core/utils/data_util.py @@ -1,6 +1,7 @@ from __future__ import annotations -from typing import Iterable, Iterator, TypeVar +from collections.abc import Iterable, Iterator +from typing import TypeVar T = TypeVar("T") diff --git a/flexeval/scripts/flexeval_file.py b/flexeval/scripts/flexeval_file.py index b14c1abf..0104e168 100644 --- a/flexeval/scripts/flexeval_file.py +++ b/flexeval/scripts/flexeval_file.py @@ -6,7 +6,7 @@ from abc import ABC, abstractmethod from importlib.metadata import version from pathlib import Path -from typing import Any, Dict, List, Union +from typing import Any import _jsonnet from jsonargparse import ActionConfigFile, ArgumentParser @@ -69,13 +69,13 @@ def main() -> None: # noqa: C901, PLR0912, PLR0915 ) parser.add_argument( "--metrics", - type=Union[List[Metric], Metric], + type=list[Metric] | Metric, required=True, help="You can specify the parameters, the path to the config file, or the name of the preset config.", ) parser.add_argument( "--eval_dataset", - type=Union[GenerationDataset, ChatDataset], + type=GenerationDataset | ChatDataset, default=None, help="If specified, override the references with the ones from the generation_dataset.", ) @@ -108,7 +108,7 @@ def main() -> None: # noqa: C901, PLR0912, PLR0915 # Metadata parser.add_argument( "--metadata", - type=Dict[str, Any], + type=dict[str, Any], default={}, help="Metadata to save in config.json", ) diff --git a/flexeval/scripts/flexeval_lm.py b/flexeval/scripts/flexeval_lm.py index dbfc3751..8e7c588e 100644 --- a/flexeval/scripts/flexeval_lm.py +++ b/flexeval/scripts/flexeval_lm.py @@ -10,7 +10,7 @@ from 
collections import defaultdict from importlib.metadata import version from pathlib import Path -from typing import Any, Dict +from typing import Any import _jsonnet from jsonargparse import ActionConfigFile, ArgumentParser, Namespace @@ -121,7 +121,7 @@ def main() -> None: # noqa: C901, PLR0912, PLR0915 ) parser.add_argument( "--eval_setups", - type=Dict[str, EvalSetup], + type=dict[str, EvalSetup], help="A dictionary of evaluation setups. " "The key is the folder name where the outputs will be saved, and the value is the EvalSetup object. ", enable_path=True, @@ -160,7 +160,7 @@ def main() -> None: # noqa: C901, PLR0912, PLR0915 # Metadata parser.add_argument( "--metadata", - type=Dict[str, Any], + type=dict[str, Any], default={}, help="Metadata to save in config.json", ) @@ -204,7 +204,7 @@ def main() -> None: # noqa: C901, PLR0912, PLR0915 overrides_for_eval_setups[setup_name][override_key] = sys.argv[i + 1] indices_to_pop += [i, i + 1] sys.argv = [a for i, a in enumerate(sys.argv) if i not in indices_to_pop] - for eval_key in params_for_eval_setups: + for eval_key in params_for_eval_setups: # noqa: PLC0206 for override_key, override_value in overrides_for_eval_setups[eval_key].items(): override_jsonargparse_params(params_for_eval_setups[eval_key], override_key, override_value) for eval_key, eval_config in params_for_eval_setups.items(): diff --git a/flexeval/scripts/flexeval_pairwise.py b/flexeval/scripts/flexeval_pairwise.py index 21a85baf..a36cfe01 100644 --- a/flexeval/scripts/flexeval_pairwise.py +++ b/flexeval/scripts/flexeval_pairwise.py @@ -4,7 +4,7 @@ import sys from importlib.metadata import version from pathlib import Path -from typing import Any, Dict, List +from typing import Any import _jsonnet from jsonargparse import ActionConfigFile, ArgumentParser @@ -22,12 +22,12 @@ def main() -> None: parser = ArgumentParser(parser_mode="jsonnet") - parser.add_argument("--lm_output_paths", type=Dict[str, str], required=True) + parser.add_argument("--lm_output_paths", type=dict[str, str], required=True) parser.add_argument("--judge", type=PairwiseJudge, required=True, enable_path=True) parser.add_argument("--match_maker", type=MatchMaker, default={"class_path": "AllCombinations"}) parser.add_argument( "--scorers", - type=List[PairwiseScorer], + type=list[PairwiseScorer], default=[{"class_path": "WinRateScorer"}, {"class_path": "BradleyTerryScorer"}], ) parser.add_argument("--batch_size", type=int, default=4) @@ -65,7 +65,7 @@ def main() -> None: # Metadata parser.add_argument( "--metadata", - type=Dict[str, Any], + type=dict[str, Any], default={}, help="Metadata to save in config.json", ) diff --git a/flexeval/scripts/flexeval_reward.py b/flexeval/scripts/flexeval_reward.py index b45ccaa5..f7850107 100644 --- a/flexeval/scripts/flexeval_reward.py +++ b/flexeval/scripts/flexeval_reward.py @@ -4,7 +4,7 @@ import sys from importlib.metadata import version from pathlib import Path -from typing import Any, Dict +from typing import Any from jsonargparse import ActionConfigFile, ArgumentParser from loguru import logger @@ -61,7 +61,7 @@ def main() -> None: # Metadata parser.add_argument( "--metadata", - type=Dict[str, Any], + type=dict[str, Any], default={}, help="Metadata to save in config.json", ) diff --git a/flexeval/utils/hf_utils.py b/flexeval/utils/hf_utils.py index f3d0c906..eb006cc9 100644 --- a/flexeval/utils/hf_utils.py +++ b/flexeval/utils/hf_utils.py @@ -47,8 +47,7 @@ def get_default_model_kwargs(model_kwargs: None | dict[str, Any] = None) -> dict # Convert string to torch.dtype 
# We allow either "bfloat16" or "torch.bfloat16" torch_dtype_str = model_kwargs["torch_dtype"] - if torch_dtype_str.startswith("torch."): - torch_dtype_str = torch_dtype_str[len("torch.") :] + torch_dtype_str = torch_dtype_str.removeprefix("torch.") model_kwargs["torch_dtype"] = getattr(torch, torch_dtype_str) if not isinstance(model_kwargs["torch_dtype"], torch.dtype): msg = f"Invalid torch_dtype: {model_kwargs['torch_dtype']}" diff --git a/poetry.lock b/poetry.lock index 85e90c59..7a2e47f3 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1,4 +1,4 @@ -# This file is automatically @generated by Poetry 2.1.3 and should not be changed by hand. +# This file is automatically @generated by Poetry 2.1.2 and should not be changed by hand. [[package]] name = "accelerate" @@ -6185,29 +6185,31 @@ pyasn1 = ">=0.1.3" [[package]] name = "ruff" -version = "0.4.10" +version = "0.14.7" description = "An extremely fast Python linter and code formatter, written in Rust." optional = false python-versions = ">=3.7" groups = ["dev"] files = [ - {file = "ruff-0.4.10-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:5c2c4d0859305ac5a16310eec40e4e9a9dec5dcdfbe92697acd99624e8638dac"}, - {file = "ruff-0.4.10-py3-none-macosx_11_0_arm64.whl", hash = "sha256:a79489607d1495685cdd911a323a35871abfb7a95d4f98fc6f85e799227ac46e"}, - {file = "ruff-0.4.10-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b1dd1681dfa90a41b8376a61af05cc4dc5ff32c8f14f5fe20dba9ff5deb80cd6"}, - {file = "ruff-0.4.10-py3-none-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:c75c53bb79d71310dc79fb69eb4902fba804a81f374bc86a9b117a8d077a1784"}, - {file = "ruff-0.4.10-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:18238c80ee3d9100d3535d8eb15a59c4a0753b45cc55f8bf38f38d6a597b9739"}, - {file = "ruff-0.4.10-py3-none-manylinux_2_17_ppc64.manylinux2014_ppc64.whl", hash = "sha256:d8f71885bce242da344989cae08e263de29752f094233f932d4f5cfb4ef36a81"}, - {file = "ruff-0.4.10-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:330421543bd3222cdfec481e8ff3460e8702ed1e58b494cf9d9e4bf90db52b9d"}, - {file = "ruff-0.4.10-py3-none-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:9e9b6fb3a37b772628415b00c4fc892f97954275394ed611056a4b8a2631365e"}, - {file = "ruff-0.4.10-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0f54c481b39a762d48f64d97351048e842861c6662d63ec599f67d515cb417f6"}, - {file = "ruff-0.4.10-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:67fe086b433b965c22de0b4259ddfe6fa541c95bf418499bedb9ad5fb8d1c631"}, - {file = "ruff-0.4.10-py3-none-musllinux_1_2_armv7l.whl", hash = "sha256:acfaaab59543382085f9eb51f8e87bac26bf96b164839955f244d07125a982ef"}, - {file = "ruff-0.4.10-py3-none-musllinux_1_2_i686.whl", hash = "sha256:3cea07079962b2941244191569cf3a05541477286f5cafea638cd3aa94b56815"}, - {file = "ruff-0.4.10-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:338a64ef0748f8c3a80d7f05785930f7965d71ca260904a9321d13be24b79695"}, - {file = "ruff-0.4.10-py3-none-win32.whl", hash = "sha256:ffe3cd2f89cb54561c62e5fa20e8f182c0a444934bf430515a4b422f1ab7b7ca"}, - {file = "ruff-0.4.10-py3-none-win_amd64.whl", hash = "sha256:67f67cef43c55ffc8cc59e8e0b97e9e60b4837c8f21e8ab5ffd5d66e196e25f7"}, - {file = "ruff-0.4.10-py3-none-win_arm64.whl", hash = "sha256:dd1fcee327c20addac7916ca4e2653fbbf2e8388d8a6477ce5b4e986b68ae6c0"}, - {file = "ruff-0.4.10.tar.gz", hash = "sha256:3aa4f2bc388a30d346c56524f7cacca85945ba124945fe489952aadb6b5cd804"}, + {file = 
"ruff-0.14.7-py3-none-linux_armv6l.whl", hash = "sha256:b9d5cb5a176c7236892ad7224bc1e63902e4842c460a0b5210701b13e3de4fca"}, + {file = "ruff-0.14.7-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:3f64fe375aefaf36ca7d7250292141e39b4cea8250427482ae779a2aa5d90015"}, + {file = "ruff-0.14.7-py3-none-macosx_11_0_arm64.whl", hash = "sha256:93e83bd3a9e1a3bda64cb771c0d47cda0e0d148165013ae2d3554d718632d554"}, + {file = "ruff-0.14.7-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:3838948e3facc59a6070795de2ae16e5786861850f78d5914a03f12659e88f94"}, + {file = "ruff-0.14.7-py3-none-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:24c8487194d38b6d71cd0fd17a5b6715cda29f59baca1defe1e3a03240f851d1"}, + {file = "ruff-0.14.7-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:79c73db6833f058a4be8ffe4a0913b6d4ad41f6324745179bd2aa09275b01d0b"}, + {file = "ruff-0.14.7-py3-none-manylinux_2_17_ppc64.manylinux2014_ppc64.whl", hash = "sha256:12eb7014fccff10fc62d15c79d8a6be4d0c2d60fe3f8e4d169a0d2def75f5dad"}, + {file = "ruff-0.14.7-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:6c623bbdc902de7ff715a93fa3bb377a4e42dd696937bf95669118773dbf0c50"}, + {file = "ruff-0.14.7-py3-none-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:f53accc02ed2d200fa621593cdb3c1ae06aa9b2c3cae70bc96f72f0000ae97a9"}, + {file = "ruff-0.14.7-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:281f0e61a23fcdcffca210591f0f53aafaa15f9025b5b3f9706879aaa8683bc4"}, + {file = "ruff-0.14.7-py3-none-manylinux_2_31_riscv64.whl", hash = "sha256:dbbaa5e14148965b91cb090236931182ee522a5fac9bc5575bafc5c07b9f9682"}, + {file = "ruff-0.14.7-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:1464b6e54880c0fe2f2d6eaefb6db15373331414eddf89d6b903767ae2458143"}, + {file = "ruff-0.14.7-py3-none-musllinux_1_2_armv7l.whl", hash = "sha256:f217ed871e4621ea6128460df57b19ce0580606c23aeab50f5de425d05226784"}, + {file = "ruff-0.14.7-py3-none-musllinux_1_2_i686.whl", hash = "sha256:6be02e849440ed3602d2eb478ff7ff07d53e3758f7948a2a598829660988619e"}, + {file = "ruff-0.14.7-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:19a0f116ee5e2b468dfe80c41c84e2bbd6b74f7b719bee86c2ecde0a34563bcc"}, + {file = "ruff-0.14.7-py3-none-win32.whl", hash = "sha256:e33052c9199b347c8937937163b9b149ef6ab2e4bb37b042e593da2e6f6cccfa"}, + {file = "ruff-0.14.7-py3-none-win_amd64.whl", hash = "sha256:e17a20ad0d3fad47a326d773a042b924d3ac31c6ca6deb6c72e9e6b5f661a7c6"}, + {file = "ruff-0.14.7-py3-none-win_arm64.whl", hash = "sha256:be4d653d3bea1b19742fcc6502354e32f65cd61ff2fbdb365803ef2c2aec6228"}, + {file = "ruff-0.14.7.tar.gz", hash = "sha256:3417deb75d23bd14a722b57b0a1435561db65f0ad97435b4cf9f85ffcef34ae5"}, ] [[package]] @@ -8518,4 +8520,4 @@ wandb = ["wandb"] [metadata] lock-version = "2.1" python-versions = ">=3.10,<3.13" -content-hash = "f9d97a4594fcc05e67206bc48e8abaf18d5396448b355246de7659a09eb430a6" +content-hash = "7c3ae4213c07a0b65a60f0211f216c0d0c537b8afeef172624a8999f832aad6c" diff --git a/pyproject.toml b/pyproject.toml index 83403c0d..027170b3 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -54,7 +54,7 @@ wandb = ["wandb"] pytest = "^7.4.3" pytest-mock = "^3.14.0" taskipy = "^1.12.0" -ruff = "^0.4.5" +ruff = "^0.14.7" pytest-xdist = "^3.7.0" [tool.poetry.group.docs.dependencies] @@ -101,7 +101,6 @@ ignore = [ "G", # flake8-logging-format "TD", # flake8-todos "ANN003", # missing-type-kwargs - "ANN101", # missing-type-self "ARG002", # unused-method-argument "FIX002", # 
line-contains-todo "PLR0913", # too-many-arguments @@ -112,9 +111,11 @@ ignore = [ "RUF002", # ambiguous-unicode-character-docstring "S311", # suspicious-non-cryptographic-random-usage "S603", # subprocess-without-shell-equals-true - "TCH001", # typing-only-first-party-import - "TCH002", # typing-only-third-party-import - "TCH003", # typing-only-standard-library-import + "TC001", # typing-only-first-party-import + "TC002", # typing-only-third-party-import + "TC003", # typing-only-standard-library-import + "PLC0415", # `import` should be at the top-level of a file + "B905", # `zip()` without an explicit `strict=` parameter ] [tool.ruff.lint.per-file-ignores] diff --git a/tests/conftest.py b/tests/conftest.py index bdabdf41..48bc285d 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,4 +1,5 @@ -from typing import Any, Generator +from collections.abc import Generator +from typing import Any import pytest from loguru import logger @@ -14,7 +15,7 @@ def is_vllm_enabled() -> bool: return False -@pytest.fixture() +@pytest.fixture def caplog(caplog: pytest.LogCaptureFixture) -> Generator[pytest.LogCaptureFixture, Any, None]: handler_id = logger.add(caplog.handler, format="{message}") yield caplog diff --git a/tests/core/chat_dataset/test_openai_messages.py b/tests/core/chat_dataset/test_openai_messages.py index bca26539..a1360873 100644 --- a/tests/core/chat_dataset/test_openai_messages.py +++ b/tests/core/chat_dataset/test_openai_messages.py @@ -2,8 +2,9 @@ import json import tempfile +from collections.abc import Callable from copy import deepcopy -from typing import Any, Callable +from typing import Any import pytest @@ -19,7 +20,7 @@ ] -@pytest.fixture() +@pytest.fixture def jsonl_data_factory(tmp_path) -> Callable: # noqa: ANN001 def _create( message_key: str, messages_list: list[dict], num_samples: int = 10, extra_info: dict | None = None @@ -160,7 +161,7 @@ def test_load_dataset_with_extra_info(jsonl_data_factory) -> None: # noqa: ANN0 ] -@pytest.fixture() +@pytest.fixture def mock_chat_messages_with_tools_data_path() -> None: with tempfile.NamedTemporaryFile(mode="w", suffix=".jsonl") as f: for messages in TEST_CHAT_MESSAGES_WITH_TOOLS: diff --git a/tests/core/chat_dataset/test_template_based.py b/tests/core/chat_dataset/test_template_based.py index 47308dfa..942eea2d 100644 --- a/tests/core/chat_dataset/test_template_based.py +++ b/tests/core/chat_dataset/test_template_based.py @@ -203,7 +203,7 @@ def test_remove_conditions( assert len(item.references) > 1 -@pytest.fixture() +@pytest.fixture def dummy_template_file(tmp_path: Path) -> Path: template_content = "Hello {{ name }}!" template_file = tmp_path / "dummy.j2" diff --git a/tests/core/language_model/base.py b/tests/core/language_model/base.py index cea48b4b..37da5f20 100644 --- a/tests/core/language_model/base.py +++ b/tests/core/language_model/base.py @@ -27,7 +27,7 @@ def model(self): ``` """ - @pytest.fixture() + @pytest.fixture @abstractmethod def lm(self, *args, **kwargs) -> LanguageModel: # noqa: ANN002 """Return an instance of the LanguageModel. @@ -36,7 +36,7 @@ def lm(self, *args, **kwargs) -> LanguageModel: # noqa: ANN002 msg = "Subclasses must implement model fixture" raise NotImplementedError(msg) - @pytest.fixture() + @pytest.fixture @abstractmethod def chat_lm(self, *args, **kwargs) -> LanguageModel: # noqa: ANN002 """Return an instance of the LanguageModel. 
@@ -45,7 +45,7 @@ def chat_lm(self, *args, **kwargs) -> LanguageModel: # noqa: ANN002 msg = "Subclasses must implement model fixture" raise NotImplementedError(msg) - @pytest.fixture() + @pytest.fixture @abstractmethod def chat_lm_for_tool_calling(self, *args, **kwargs) -> LanguageModel: # noqa: ANN002 """Return an instance of the LanguageModel. @@ -193,7 +193,7 @@ def test_generate_chat_response_if_number_of_tools_and_messages_not_equal( ] try: with pytest.raises( - ValueError, match="tools_list must be either None or a list of the same length as chat_messages_list." + ValueError, match=r"tools_list must be either None or a list of the same length as chat_messages_list." ): chat_lm_for_tool_calling.generate_chat_response( [ diff --git a/tests/core/language_model/test_hf_lm.py b/tests/core/language_model/test_hf_lm.py index 591285fc..8b9cb5d1 100644 --- a/tests/core/language_model/test_hf_lm.py +++ b/tests/core/language_model/test_hf_lm.py @@ -2,7 +2,8 @@ import functools import logging -from typing import Any, Callable +from collections.abc import Callable +from typing import Any from unittest.mock import patch import pytest @@ -87,15 +88,15 @@ def chat_lm_without_system_message(model_name: str = "sbintuitions/tiny-lm-chat" class TestHuggingFaceLM(BaseLanguageModelTest): - @pytest.fixture() + @pytest.fixture def lm(self, lm: HuggingFaceLM) -> LanguageModel: return lm - @pytest.fixture() + @pytest.fixture def chat_lm(self, chat_lm: HuggingFaceLM) -> LanguageModel: return chat_lm - @pytest.fixture() + @pytest.fixture def chat_lm_for_tool_calling(self, chat_lm_for_tool_calling: HuggingFaceLM) -> HuggingFaceLM: return chat_lm_for_tool_calling diff --git a/tests/core/language_model/test_litellm_api.py b/tests/core/language_model/test_litellm_api.py index 49dde7dc..67295ef3 100644 --- a/tests/core/language_model/test_litellm_api.py +++ b/tests/core/language_model/test_litellm_api.py @@ -23,7 +23,7 @@ def chat_lm() -> LiteLLMChatAPI: @pytest.mark.skipif(not is_openai_enabled(), reason="OpenAI API Key is not set") class TestLiteLLMChatAPI(BaseLanguageModelTest): - @pytest.fixture() + @pytest.fixture def lm(self) -> LanguageModel: return LiteLLMChatAPI( "gpt-4o-mini-2024-07-18", @@ -33,11 +33,11 @@ def lm(self) -> LanguageModel: "Do not provide the answer or any other information.", ) - @pytest.fixture() + @pytest.fixture def chat_lm(self, chat_lm: LiteLLMChatAPI) -> LanguageModel: return chat_lm - @pytest.fixture() + @pytest.fixture def chat_lm_for_tool_calling(self, chat_lm: LiteLLMChatAPI) -> LiteLLMChatAPI: return chat_lm diff --git a/tests/core/language_model/test_openai_api.py b/tests/core/language_model/test_openai_api.py index 07e0d7cd..cd21e438 100644 --- a/tests/core/language_model/test_openai_api.py +++ b/tests/core/language_model/test_openai_api.py @@ -26,7 +26,7 @@ def chat_lm() -> OpenAIChatAPI: @pytest.mark.skipif(not is_openai_enabled(), reason="OpenAI API Key is not set") class TestOpenAIChatAPI(BaseLanguageModelTest): - @pytest.fixture() + @pytest.fixture def lm(self) -> LanguageModel: return OpenAIChatAPI( "gpt-4o-mini-2024-07-18", @@ -36,11 +36,11 @@ def lm(self) -> LanguageModel: "Do not provide the answer or any other information.", ) - @pytest.fixture() + @pytest.fixture def chat_lm(self, chat_lm: OpenAIChatAPI) -> LanguageModel: return chat_lm - @pytest.fixture() + @pytest.fixture def chat_lm_for_tool_calling(self, chat_lm: OpenAIChatAPI) -> OpenAIChatAPI: return chat_lm diff --git a/tests/core/language_model/test_openai_batch_api.py 
b/tests/core/language_model/test_openai_batch_api.py index bb06e291..16e7d655 100644 --- a/tests/core/language_model/test_openai_batch_api.py +++ b/tests/core/language_model/test_openai_batch_api.py @@ -20,9 +20,9 @@ def chat_lm() -> OpenAIChatBatchAPI: @pytest.mark.skipif(not is_openai_enabled(), reason="OpenAI API Key is not set") -@pytest.mark.batch_api() +@pytest.mark.batch_api class TestOpenAIChatBatchAPI(BaseLanguageModelTest): - @pytest.fixture() + @pytest.fixture def lm(self) -> LanguageModel: return OpenAIChatBatchAPI( "gpt-4o-mini-2024-07-18", @@ -32,11 +32,11 @@ def lm(self) -> LanguageModel: "Do not provide the answer or any other information.", ) - @pytest.fixture() + @pytest.fixture def chat_lm(self, chat_lm: OpenAIChatBatchAPI) -> LanguageModel: return chat_lm - @pytest.fixture() + @pytest.fixture def chat_lm_for_tool_calling(self, chat_lm: OpenAIChatBatchAPI) -> OpenAIChatBatchAPI: return chat_lm @@ -50,7 +50,7 @@ def test_batch_chat_response_is_not_affected_by_batch(self, chat_lm: LanguageMod @pytest.mark.skipif(not is_openai_enabled(), reason="OpenAI is not installed") -@pytest.mark.batch_api() +@pytest.mark.batch_api def test_create_batch_file(chat_lm: OpenAIChatBatchAPI) -> None: chat_lm.create_batch_file( {str(i): {"messages": [[{"role": "user", "content": "こんにちは。"}]], "tools": None} for i in range(10)}, @@ -63,7 +63,7 @@ def test_create_batch_file(chat_lm: OpenAIChatBatchAPI) -> None: @pytest.mark.skipif(not is_openai_enabled(), reason="OpenAI is not installed") -@pytest.mark.batch_api() +@pytest.mark.batch_api def test_warning_if_conflict_max_new_tokens(caplog: pytest.LogCaptureFixture) -> None: caplog.set_level(logging.WARNING) chat_lm_with_max_new_tokens = OpenAIChatBatchAPI( @@ -77,7 +77,7 @@ def test_warning_if_conflict_max_new_tokens(caplog: pytest.LogCaptureFixture) -> @pytest.mark.skipif(not is_openai_enabled(), reason="OpenAI is not installed") -@pytest.mark.batch_api() +@pytest.mark.batch_api def test_compute_chat_log_probs_for_multi_tokens(chat_lm: OpenAIChatBatchAPI) -> None: prompt = [{"role": "user", "content": "Hello."}] response = {"role": "assistant", "content": "Hello~~~"} @@ -86,7 +86,7 @@ def test_compute_chat_log_probs_for_multi_tokens(chat_lm: OpenAIChatBatchAPI) -> @pytest.mark.skipif(not is_openai_enabled(), reason="OpenAI is not installed") -@pytest.mark.batch_api() +@pytest.mark.batch_api def test_developer_message() -> None: openai_api = OpenAIChatBatchAPI( "gpt-4o-mini-2024-07-18", diff --git a/tests/core/language_model/test_openai_completion_api.py b/tests/core/language_model/test_openai_completion_api.py index 70858609..f63d2ef2 100644 --- a/tests/core/language_model/test_openai_completion_api.py +++ b/tests/core/language_model/test_openai_completion_api.py @@ -20,15 +20,15 @@ def lm() -> OpenAICompletionAPI: @pytest.mark.skipif(not is_openai_enabled(), reason="OpenAI API Key is not set") class TestOpenAICompletionAPI(BaseLanguageModelTest): - @pytest.fixture() + @pytest.fixture def lm(self, lm: OpenAICompletionAPI) -> LanguageModel: return lm - @pytest.fixture() + @pytest.fixture def chat_lm(self, lm: OpenAICompletionAPI) -> LanguageModel: return lm - @pytest.fixture() + @pytest.fixture def chat_lm_for_tool_calling(self, lm: OpenAICompletionAPI) -> LanguageModel: return lm diff --git a/tests/core/language_model/vllm/test_vllm_common.py b/tests/core/language_model/vllm/test_vllm_common.py index d6c32e1b..d96bbcad 100644 --- a/tests/core/language_model/vllm/test_vllm_common.py +++ b/tests/core/language_model/vllm/test_vllm_common.py @@ 
-1,5 +1,5 @@ import functools -from typing import Callable +from collections.abc import Callable import pytest @@ -61,11 +61,11 @@ def chat_lm_with_system_message() -> VLLM: @pytest.mark.skipif(not is_vllm_enabled(), reason="vllm library is not installed") class TestVLLM(BaseLanguageModelTest): - @pytest.fixture() + @pytest.fixture def lm(self, chat_lm: VLLM) -> VLLM: return chat_lm - @pytest.fixture() + @pytest.fixture def chat_lm(self, chat_lm: VLLM) -> VLLM: return chat_lm diff --git a/tests/core/language_model/vllm/test_vllm_serve_lm.py b/tests/core/language_model/vllm/test_vllm_serve_lm.py index b68a3952..d455c43e 100644 --- a/tests/core/language_model/vllm/test_vllm_serve_lm.py +++ b/tests/core/language_model/vllm/test_vllm_serve_lm.py @@ -86,11 +86,11 @@ def test_stop_terminates_process() -> None: @pytest.mark.skipif(not is_vllm_enabled(), reason="vllm library is not installed") class TestVLLMServeLM(BaseLanguageModelTest): - @pytest.fixture() + @pytest.fixture def lm(self, chat_lm: VLLMServeLM) -> VLLMServeLM: return chat_lm - @pytest.fixture() + @pytest.fixture def chat_lm(self, chat_lm: VLLMServeLM) -> VLLMServeLM: return chat_lm diff --git a/tests/core/language_model/vllm/test_vllm_specific.py b/tests/core/language_model/vllm/test_vllm_specific.py index f4ee6a95..56ef9980 100644 --- a/tests/core/language_model/vllm/test_vllm_specific.py +++ b/tests/core/language_model/vllm/test_vllm_specific.py @@ -1,7 +1,8 @@ from __future__ import annotations import logging -from typing import Any, Callable, Generator +from collections.abc import Callable, Generator +from typing import Any from unittest.mock import patch import pytest diff --git a/tests/core/metric/conftest.py b/tests/core/metric/conftest.py index e04f96de..c7ec801a 100644 --- a/tests/core/metric/conftest.py +++ b/tests/core/metric/conftest.py @@ -20,7 +20,7 @@ def as_lm_output(request: pytest.FixtureRequest) -> str: return request.param -@pytest.fixture() +@pytest.fixture def lm_outputs(request: pytest.FixtureRequest, as_lm_output: str) -> list[str] | list[LMOutput]: """ Fixture that converts parameterized string lists to either strings or LMOutput objects. 
diff --git a/tests/core/metric/test_finish_reason.py b/tests/core/metric/test_finish_reason.py index d06572c8..f6bcd343 100644 --- a/tests/core/metric/test_finish_reason.py +++ b/tests/core/metric/test_finish_reason.py @@ -44,8 +44,8 @@ def test_finish_reason_count_functionality( """Test FinishReasonCount metric functionality with various finish reason combinations.""" metric = FinishReasonCount() - lm_outputs = [LMOutput(text=f"Response {i+1}", finish_reason=reason) for i, reason in enumerate(finish_reasons)] - references_list = [[f"ref{i+1}"] for i in range(len(finish_reasons))] + lm_outputs = [LMOutput(text=f"Response {i + 1}", finish_reason=reason) for i, reason in enumerate(finish_reasons)] + references_list = [[f"ref{i + 1}"] for i in range(len(finish_reasons))] result = metric.evaluate(lm_outputs, references_list) diff --git a/tests/core/result_recorder/test_local_recorder.py b/tests/core/result_recorder/test_local_recorder.py index 6ad848b2..2c7f1484 100644 --- a/tests/core/result_recorder/test_local_recorder.py +++ b/tests/core/result_recorder/test_local_recorder.py @@ -12,7 +12,7 @@ ) -@pytest.fixture() +@pytest.fixture def temp_dir() -> None: with tempfile.TemporaryDirectory() as tmp_path: yield tmp_path diff --git a/tests/core/tokenizer/test_mecab.py b/tests/core/tokenizer/test_mecab.py index 688a48b5..481e7225 100644 --- a/tests/core/tokenizer/test_mecab.py +++ b/tests/core/tokenizer/test_mecab.py @@ -5,7 +5,7 @@ from flexeval import MecabTokenizer -@pytest.fixture() +@pytest.fixture def mocked_fugashi_tagger() -> Mock: # Mock fugashi.Tagger as it requires downloading the MeCab dictionary with patch("fugashi.Tagger") as mock_tagger: diff --git a/tests/dummy_modules/pairwise_comparison.py b/tests/dummy_modules/pairwise_comparison.py index becd81ee..349f4e82 100644 --- a/tests/dummy_modules/pairwise_comparison.py +++ b/tests/dummy_modules/pairwise_comparison.py @@ -19,4 +19,4 @@ def compute_scores(self, match_results: list[tuple[str, str, Winner]]) -> dict[s for model1, model2, _ in match_results: all_model_names.add(model1) all_model_names.add(model2) - return {model_name: 1.0 for model_name in all_model_names} + return dict.fromkeys(all_model_names, 1.0) diff --git a/tests/scripts/test_flexeval_lm.py b/tests/scripts/test_flexeval_lm.py index c9ab974c..49ac1d11 100644 --- a/tests/scripts/test_flexeval_lm.py +++ b/tests/scripts/test_flexeval_lm.py @@ -326,7 +326,7 @@ def evaluate( check_if_eval_results_are_correctly_saved(f) -@pytest.fixture() +@pytest.fixture def mock_eval_data() -> dict: return {"setup": "dummy_setup_object", "config": {"task": "test", "metric": "acc"}, "group": "test_group"}
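Most of the library hunks above follow one pattern: abstract container types such as `Sequence`, `Iterable`, `Iterator`, and `Callable` are imported from `collections.abc` instead of `typing` (the `typing` aliases have been deprecated since Python 3.9), and unions are written with the PEP 604 `X | Y` syntax rather than `typing.Union`, which the project's `python = ">=3.10,<3.13"` constraint allows. A minimal sketch of the resulting style, using invented names rather than flexeval classes:

```python
from __future__ import annotations

from collections.abc import Iterable, Sequence  # previously: from typing import Iterable, Sequence
from typing import Any

# Previously: Dataset = Union[list[str], dict[str, Any]]
Dataset = list[str] | dict[str, Any]  # PEP 604 union; this assignment is evaluated at runtime, so it needs 3.10+


def first_items(batches: Iterable[Sequence[Any]]) -> list[Any]:
    """Return the first element of every non-empty batch."""
    return [batch[0] for batch in batches if batch]
```

Note that `from __future__ import annotations` only defers evaluation of annotations; a module-level alias like `Dataset` above (or the `Dataset`/`Instance` aliases in `few_shot_generator/base.py`) is still evaluated eagerly, which is why the PEP 604 form there relies on the 3.10 floor.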
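A handful of hunks also swap in newer standard-library idioms: `zip(..., strict=True)` (Python 3.10+) fails loudly when the iterables have different lengths instead of truncating silently, `dict.fromkeys(keys, value)` replaces `{k: value for k in keys}`, and `str.removeprefix` (Python 3.9+) replaces the startswith-then-slice dance used for `torch_dtype`. Ruff's B905 rule flags bare `zip()` calls; the `pyproject.toml` hunk silences it project-wide but the metric code still opts in explicitly where a length mismatch would be a bug. The snippet below is an illustration with invented data, not code from the repository:

```python
scores = {"model_a": 0.9, "model_b": 0.7}
labels = ["good", "bad"]

# strict=True raises ValueError if scores and labels ever fall out of sync,
# instead of silently dropping the unpaired tail.
for (name, score), label in zip(scores.items(), labels, strict=True):
    print(f"{name}: {score} ({label})")

# Every key maps to the same initial value; fine here because 0 is immutable
# (a shared mutable default would be aliased across keys).
counters = dict.fromkeys(scores, 0)  # {"model_a": 0, "model_b": 0}

# removeprefix is a no-op when the prefix is absent, so no startswith guard is needed.
dtype_name = "torch.bfloat16".removeprefix("torch.")  # "bfloat16"
```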
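In the test suite, `@pytest.fixture()` and `@pytest.mark.batch_api()` lose their empty parentheses, presumably to satisfy the current defaults of ruff's flake8-pytest-style rules (PT001/PT023); both spellings behave identically at runtime. A small self-contained sketch with a hypothetical fixture and marker, not taken from the tests above:

```python
import pytest


@pytest.fixture  # no call parentheses needed when the fixture takes no arguments
def sample_numbers() -> list[int]:
    return [1, 2, 3]


@pytest.mark.slow  # markers without arguments can likewise be applied bare
def test_sum(sample_numbers: list[int]) -> None:
    assert sum(sample_numbers) == 6
```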