diff --git a/flexeval/scripts/flexeval_lm.py b/flexeval/scripts/flexeval_lm.py
index c0783663..32d47cab 100644
--- a/flexeval/scripts/flexeval_lm.py
+++ b/flexeval/scripts/flexeval_lm.py
@@ -1,5 +1,7 @@
 from __future__ import annotations
 
+import copy
+import dataclasses
 import json
 import os
 import re
@@ -13,6 +15,7 @@
 import _jsonnet
 from jsonargparse import ActionConfigFile, ArgumentParser, Namespace
 from loguru import logger
+from pydantic.types import PositiveInt
 
 from flexeval import EvalSetup, LanguageModel, LocalRecorder, ResultRecorder
 from flexeval.utils.module_utils import ConfigNameResolver
@@ -44,6 +47,63 @@ def as_dict(self: Namespace) -> dict[str, Any]:
     return dic
 
 
+def maybe_replace_random_seed(
+    eval_setup: EvalSetup, config_content: dict, seed_increment: int
+) -> tuple[EvalSetup, dict]:
+    """
+    Increments the random seed if the EvalSetup object has a 'random_seed' attribute.
+
+    A new EvalSetup instance and a deep-copied config dictionary are returned
+    with the updated seed value.
+    """
+    if hasattr(eval_setup, "random_seed"):
+        new_random_seed = eval_setup.random_seed + seed_increment
+        new_eval_setup = dataclasses.replace(eval_setup, random_seed=new_random_seed)
+        new_config_content = copy.deepcopy(config_content)
+        new_config_content["init_args"]["random_seed"] = new_random_seed
+        return new_eval_setup, new_config_content
+    return eval_setup, config_content
+
+
+def generate_eval_entries(
+    eval_setup: EvalSetup, config_content: dict, group: str | None, num_repeats: PositiveInt
+) -> list:
+    """
+    Generates a list of evaluation entries based on the repeat count and group name.
+
+    If num_repeats is 1, no run-specific metadata (i.e., 'runN') is appended.
+    If it is 2 or more, each entry is indexed with 'runN'.
+
+    Args:
+        eval_setup (EvalSetup): The evaluation setup object.
+        config_content (dict): The configuration content (dictionary) for the setup.
+        group (str | None): The group name for the evaluation, e.g., aio or mt-bench.
+            This forms the base of the metadata (e.g., 'group/runN').
+            If None, the metadata will only be 'runN' for repeated runs.
+        num_repeats (int): The number of times the evaluation should be repeated.
+
+    Returns:
+        list[tuple[EvalSetup, dict, str | None]]: A list of tuples, each containing
+            (EvalSetup object, config dictionary, metadata string | None).
+    """
+    entries: list[tuple[EvalSetup, dict, str | None]] = []
+
+    if num_repeats <= 0:
+        msg = f"num_repeats must be positive, but {num_repeats} was given"
+        raise ValueError(msg)
+
+    if num_repeats == 1:
+        metadata = group if group is not None else None
+        entries.append((eval_setup, config_content, metadata))
+    else:
+        for index in range(num_repeats):
+            metadata = f"{group}/run{index}" if group else f"run{index}"
+            new_eval_setup, new_config_content = maybe_replace_random_seed(eval_setup, config_content, index)
+            entries.append((new_eval_setup, new_config_content, metadata))
+
+    return entries
+
+
 def main() -> None:  # noqa: C901, PLR0912, PLR0915
     parser = ArgumentParser(parser_mode="jsonnet")
     parser.add_subclass_arguments(
@@ -98,6 +158,12 @@ def main() -> None:  # noqa: C901, PLR0912, PLR0915
         default={},
         help="Metadata to save in config.json",
     )
+    parser.add_argument(
+        "--num_repeats",
+        type=PositiveInt,
+        default=1,
+        help="Number of times to repeat the evaluation",
+    )
 
     config_name_resolver = ConfigNameResolver()
     # Resolve the preset name to the path to the config file before parsing the arguments.
@@ -159,12 +225,28 @@ def main() -> None:  # noqa: C901, PLR0912, PLR0915
 
     # normalize args.eval_setup or args.eval_setups into a list of tuples,
     # which contain (eval_setup, eval_setup_config, group)
-    eval_setups_and_metadata: list[EvalSetup, dict[str, Any], str | None] = []
+    eval_setups_and_metadata: list[tuple[EvalSetup, dict[str, Any], str | None]] = []
+
     if args.eval_setup:
-        eval_setups_and_metadata.append((args.eval_setup, config_dict["eval_setup"], None))
+        eval_setups_and_metadata.extend(
+            generate_eval_entries(
+                eval_setup=args.eval_setup,
+                config_content=config_dict["eval_setup"],
+                group=None,
+                num_repeats=args.num_repeats,
+            )
+        )
+
     if args.eval_setups:
         for group, eval_setup in args.eval_setups.items():
-            eval_setups_and_metadata.append((eval_setup, config_dict["eval_setups"][group], group))
+            eval_setups_and_metadata.extend(
+                generate_eval_entries(
+                    eval_setup=eval_setup,
+                    config_content=config_dict["eval_setups"][group],
+                    group=group,
+                    num_repeats=args.num_repeats,
+                )
+            )
 
     # run evaluation
     for eval_setup, eval_setup_config, group in eval_setups_and_metadata:
diff --git a/tests/scripts/test_flexeval_lm.py b/tests/scripts/test_flexeval_lm.py
index a1e7374e..eab734e8 100644
--- a/tests/scripts/test_flexeval_lm.py
+++ b/tests/scripts/test_flexeval_lm.py
@@ -11,7 +11,10 @@
 
 import pytest
 
+from flexeval import Generation, Perplexity
 from flexeval.core.result_recorder.local_recorder import CONFIG_FILE_NAME, METRIC_FILE_NAME, OUTPUTS_FILE_NAME
+from flexeval.scripts.flexeval_lm import generate_eval_entries, maybe_replace_random_seed
+from tests.dummy_modules import DummyGenerationDataset, DummyTextDataset
 
 # fmt: off
 CHAT_RESPONSE_CMD = [
@@ -321,3 +324,98 @@ def evaluate(
         assert result.returncode == os.EX_OK
 
         check_if_eval_results_are_correctly_saved(f)
+
+
+@pytest.fixture()
+def mock_eval_data() -> dict:
+    return {"setup": "dummy_setup_object", "config": {"task": "test", "metric": "acc"}, "group": "test_group"}
+
+
+def test_no_repeat_with_group(mock_eval_data: dict) -> None:
+    result = generate_eval_entries(
+        eval_setup=mock_eval_data["setup"],
+        config_content=mock_eval_data["config"],
+        group=mock_eval_data["group"],
+        num_repeats=1,
+    )
+    expected = [("dummy_setup_object", {"task": "test", "metric": "acc"}, "test_group")]
+    assert result == expected
+    assert len(result) == 1
+
+
+def test_no_repeat_no_group(mock_eval_data: dict) -> None:
+    result = generate_eval_entries(
+        eval_setup=mock_eval_data["setup"], config_content=mock_eval_data["config"], group=None, num_repeats=1
+    )
+    expected = [("dummy_setup_object", {"task": "test", "metric": "acc"}, None)]
+    assert result == expected
+    assert len(result) == 1
+
+
+def test_multiple_repeats_with_group(mock_eval_data: dict) -> None:
+    num_repeats = 3
+    result = generate_eval_entries(
+        eval_setup=mock_eval_data["setup"],
+        config_content=mock_eval_data["config"],
+        group=mock_eval_data["group"],
+        num_repeats=num_repeats,
+    )
+    expected_metadatas = ["test_group/run0", "test_group/run1", "test_group/run2"]
+
+    assert len(result) == num_repeats
+    for i, entry in enumerate(result):
+        assert entry[0] == mock_eval_data["setup"]
+        assert entry[1] == mock_eval_data["config"]
+        assert entry[2] == expected_metadatas[i]
+
+
+def test_non_positive_num_repeats(mock_eval_data: dict) -> None:
+    with pytest.raises(ValueError):
+        generate_eval_entries(
+            eval_setup=mock_eval_data["setup"],
+            config_content=mock_eval_data["config"],
+            group=mock_eval_data["group"],
+            num_repeats=-1,
+        )
+
+
+def test_maybe_replace_random_seed() -> None:
+    # eval setup w/ random seed: should update random seed
+    generation_eval_setup = Generation(
+        eval_dataset=DummyGenerationDataset(), gen_kwargs={}, prompt_template="dummy", random_seed=42
+    )
+    generation_config = {"init_args": {"random_seed": 42}}
+    new_generation_eval_setup, new_generation_config = maybe_replace_random_seed(
+        generation_eval_setup, generation_config, seed_increment=1
+    )
+    assert new_generation_eval_setup.random_seed == 43
+    assert new_generation_config["init_args"]["random_seed"] == 43
+
+    # eval setup w/o random seed: should not change anything
+    perplexity_eval_setup = Perplexity(eval_dataset=DummyTextDataset())
+    perplexity_config = {}
+    new_perplexity_eval_setup, new_perplexity_config = maybe_replace_random_seed(
+        perplexity_eval_setup, perplexity_config, seed_increment=1
+    )
+    assert id(new_perplexity_config) == id(perplexity_config)
+    assert id(new_perplexity_eval_setup) == id(perplexity_eval_setup)
+
+
+@pytest.mark.parametrize(
+    "num_repeats",
+    [1, 3, 5],
+)
+def test_flexeval_lm_with_num_repeats(num_repeats: int) -> None:
+    with tempfile.TemporaryDirectory() as f:
+        # fmt: off
+        command = [*CHAT_RESPONSE_CMD, "--num_repeats", str(num_repeats), "--save_dir", f]
+        # fmt: on
+
+        result = subprocess.run(command, check=False)
+        assert result.returncode == os.EX_OK
+
+        if num_repeats == 1:
+            check_if_eval_results_are_correctly_saved(f)
+        else:
+            for i in range(num_repeats):
+                check_if_eval_results_are_correctly_saved(f"{f}/run{i}")
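
Reviewer note (not part of the diff): a minimal sketch of how the new generate_eval_entries helper behaves, mirroring the unit tests above. As in those tests, the plain string "dummy_setup" and the group name "aio" are stand-ins, not real flexeval objects.

# Standalone illustration of generate_eval_entries (same import the tests use).
# "dummy_setup" is a hypothetical stand-in for a real EvalSetup object.
from flexeval.scripts.flexeval_lm import generate_eval_entries

entries = generate_eval_entries(
    eval_setup="dummy_setup",
    config_content={"init_args": {"random_seed": 0}},
    group="aio",
    num_repeats=3,
)
for _, config, metadata in entries:
    # Prints "aio/run0 0", "aio/run1 0", "aio/run2 0".
    # The seed stays at 0 here because a plain string has no random_seed attribute;
    # with a setup such as Generation, maybe_replace_random_seed bumps it per run.
    print(metadata, config["init_args"]["random_seed"])

With num_repeats=1 the single entry keeps the plain "aio" metadata. The new --num_repeats CLI flag drives the same code path, so repeated runs are saved under run0 ... runN-1 subdirectories of --save_dir, as checked by test_flexeval_lm_with_num_repeats.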