88 changes: 85 additions & 3 deletions flexeval/scripts/flexeval_lm.py
@@ -1,5 +1,7 @@
from __future__ import annotations

import copy
import dataclasses
import json
import os
import re
@@ -13,6 +15,7 @@
import _jsonnet
from jsonargparse import ActionConfigFile, ArgumentParser, Namespace
from loguru import logger
from pydantic.types import PositiveInt

from flexeval import EvalSetup, LanguageModel, LocalRecorder, ResultRecorder
from flexeval.utils.module_utils import ConfigNameResolver
@@ -44,6 +47,63 @@ def as_dict(self: Namespace) -> dict[str, Any]:
return dic


def maybe_replace_random_seed(
eval_setup: EvalSetup, config_content: dict, seed_increment: int
) -> tuple[EvalSetup, dict]:
"""
Increments the random seed if the EvalSetup object has a 'random_seed' attribute.

A new EvalSetup instance and a deep-copied config dictionary are returned
with the updated seed value; otherwise, the original objects are returned unchanged.
"""
if hasattr(eval_setup, "random_seed"):
new_random_seed = eval_setup.random_seed + seed_increment
new_eval_setup = dataclasses.replace(eval_setup, random_seed=new_random_seed)
new_config_content = copy.deepcopy(config_content)
new_config_content["init_args"]["random_seed"] = new_random_seed
return new_eval_setup, new_config_content
return eval_setup, config_content


def generate_eval_entries(
eval_setup: EvalSetup, config_content: dict, group: str | None, num_repeats: PositiveInt
) -> list:
"""
Generates a list of evaluation entries based on the repeat count and group name.

If num_repeats is 1, no run-specific metadata (i.e., 'runN') is appended.
If it is 2 or more, each entry is indexed and, where supported, its random
seed is incremented per run.

Args:
eval_setup (EvalSetup): The evaluation setup object.
config_content (dict): The configuration content (dictionary) for the setup.
group (str | None): The group name for the evaluation (e.g., 'aio', 'mt-bench').
This forms the base of the metadata (e.g., 'group/runN').
If None, the metadata is just 'runN' for repeated runs.
num_repeats (PositiveInt): The number of times the evaluation should be repeated.

Returns:
list[tuple[EvalSetup, dict, str | None]]: A list of tuples, each containing:
(EvalSetup object, config dictionary, metadata string | None)

Raises:
ValueError: If num_repeats is not positive.
"""
entries: list[tuple[EvalSetup, dict, str | None]] = []

if num_repeats <= 0:
msg = f"num_repeats must be positive, but {num_repeats} is given"
raise ValueError(msg)

if num_repeats == 1:
metadata = group if group is not None else None
entries.append((eval_setup, config_content, metadata))
else:
for index in range(num_repeats):
metadata = f"{group}/run{index}" if group else f"run{index}"
new_eval_setup, new_config_content = maybe_replace_random_seed(eval_setup, config_content, index)
entries.append((new_eval_setup, new_config_content, metadata))

return entries


def main() -> None: # noqa: C901, PLR0912, PLR0915
parser = ArgumentParser(parser_mode="jsonnet")
parser.add_subclass_arguments(
@@ -98,6 +158,12 @@ def main() -> None: # noqa: C901, PLR0912, PLR0915
default={},
help="Metadata to save in config.json",
)
parser.add_argument(
"--num_repeats",
type=PositiveInt,
default=1,
help="Number of times to repeat the evaluation",
)

config_name_resolver = ConfigNameResolver()
# Resolve the preset name to the path to the config file before parsing the arguments.
@@ -159,12 +225,28 @@ def main() -> None: # noqa: C901, PLR0912, PLR0915

# normalize args.eval_setup or args.eval_setups into a list of tuples,
# which contain (eval_setup, eval_setup_config, group)
eval_setups_and_metadata: list[EvalSetup, dict[str, Any], str | None] = []
eval_setups_and_metadata: list[tuple[EvalSetup, dict[str, Any], str | None]] = []

if args.eval_setup:
eval_setups_and_metadata.append((args.eval_setup, config_dict["eval_setup"], None))
eval_setups_and_metadata.extend(
generate_eval_entries(
eval_setup=args.eval_setup,
config_content=config_dict["eval_setup"],
group=None,
num_repeats=args.num_repeats,
)
)

if args.eval_setups:
for group, eval_setup in args.eval_setups.items():
eval_setups_and_metadata.append((eval_setup, config_dict["eval_setups"][group], group))
eval_setups_and_metadata.extend(
generate_eval_entries(
eval_setup=eval_setup,
config_content=config_dict["eval_setups"][group],
group=group,
num_repeats=args.num_repeats,
)
)

# run evaluation
for eval_setup, eval_setup_config, group in eval_setups_and_metadata:
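
Not part of the diff: a minimal usage sketch of the two new helpers. It uses a hypothetical seed-carrying dataclass (DummySetup) in place of a real EvalSetup subclass such as Generation, which would need datasets and other arguments.

import dataclasses

from flexeval.scripts.flexeval_lm import generate_eval_entries


# Hypothetical stand-in: any dataclass with a 'random_seed' attribute will
# have its seed incremented per run via maybe_replace_random_seed.
@dataclasses.dataclass(frozen=True)
class DummySetup:
    random_seed: int = 42


setup = DummySetup()
config = {"init_args": {"random_seed": 42}}

# Three repeats under a group: seeds become 42, 43, 44 and the metadata
# strings become "aio/run0", "aio/run1", "aio/run2".
for new_setup, new_config, metadata in generate_eval_entries(
    eval_setup=setup, config_content=config, group="aio", num_repeats=3
):
    print(metadata, new_setup.random_seed, new_config["init_args"]["random_seed"])

A setup without a random_seed attribute (e.g., Perplexity) would pass through maybe_replace_random_seed unchanged, as the new unit test below checks.
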
98 changes: 98 additions & 0 deletions tests/scripts/test_flexeval_lm.py
@@ -11,7 +11,10 @@

import pytest

from flexeval import Generation, Perplexity
from flexeval.core.result_recorder.local_recorder import CONFIG_FILE_NAME, METRIC_FILE_NAME, OUTPUTS_FILE_NAME
from flexeval.scripts.flexeval_lm import generate_eval_entries, maybe_replace_random_seed
from tests.dummy_modules import DummyGenerationDataset, DummyTextDataset

# fmt: off
CHAT_RESPONSE_CMD = [
@@ -321,3 +324,98 @@ def evaluate(
assert result.returncode == os.EX_OK

check_if_eval_results_are_correctly_saved(f)


@pytest.fixture()
def mock_eval_data() -> dict:
return {"setup": "dummy_setup_object", "config": {"task": "test", "metric": "acc"}, "group": "test_group"}


def test_no_repeat_with_group(mock_eval_data: dict) -> None:
result = generate_eval_entries(
eval_setup=mock_eval_data["setup"],
config_content=mock_eval_data["config"],
group=mock_eval_data["group"],
num_repeats=1,
)
expected = [("dummy_setup_object", {"task": "test", "metric": "acc"}, "test_group")]
assert result == expected
assert len(result) == 1


def test_no_repeat_no_group(mock_eval_data: dict) -> None:
result = generate_eval_entries(
eval_setup=mock_eval_data["setup"], config_content=mock_eval_data["config"], group=None, num_repeats=1
)
expected = [("dummy_setup_object", {"task": "test", "metric": "acc"}, None)]
assert result == expected
assert len(result) == 1


def test_multiple_repeats_with_group(mock_eval_data: dict) -> None:
num_repeats = 3
result = generate_eval_entries(
eval_setup=mock_eval_data["setup"],
config_content=mock_eval_data["config"],
group=mock_eval_data["group"],
num_repeats=num_repeats,
)
expected_metadatas = ["test_group/run0", "test_group/run1", "test_group/run2"]

assert len(result) == num_repeats
for i, entry in enumerate(result):
assert entry[0] == mock_eval_data["setup"]
assert entry[1] == mock_eval_data["config"]
assert entry[2] == expected_metadatas[i]


def test_non_positive_num_repeats(mock_eval_data: dict) -> None:
with pytest.raises(ValueError):
generate_eval_entries(
eval_setup=mock_eval_data["setup"],
config_content=mock_eval_data["config"],
group=mock_eval_data["group"],
num_repeats=-1,
)


def test_maybe_replace_random_seed() -> None:
# eval setup w/ random seed: should update random seed
generation_eval_setup = Generation(
eval_dataset=DummyGenerationDataset(), gen_kwargs={}, prompt_template="dummy", random_seed=42
)
generation_config = {"init_args": {"random_seed": 42}}
new_generation_eval_setup, new_generation_config = maybe_replace_random_seed(
generation_eval_setup, generation_config, seed_increment=1
)
assert new_generation_eval_setup.random_seed == 43
assert new_generation_config["init_args"]["random_seed"] == 43

# eval setup w/o random seed: should not change anything
perplexity_eval_setup = Perplexity(eval_dataset=DummyTextDataset())
perplexity_config = {}
new_perplexity_eval_setup, new_perplexity_config = maybe_replace_random_seed(
perplexity_eval_setup, perplexity_config, seed_increment=1
)
assert id(new_perplexity_config) == id(perplexity_config)
assert id(new_perplexity_eval_setup) == id(perplexity_eval_setup)


@pytest.mark.parametrize(
"num_repeats",
[1, 3, 5],
)
def test_flexeval_lm_with_num_repeats(num_repeats: int) -> None:
with tempfile.TemporaryDirectory() as f:
# fmt: off
command = [*CHAT_RESPONSE_CMD, "--num_repeats", str(num_repeats), "--save_dir", f]
# fmt: on

result = subprocess.run(command, check=False)
assert result.returncode == os.EX_OK

if num_repeats == 1:
check_if_eval_results_are_correctly_saved(f)
else:
for i in range(num_repeats):
check_if_eval_results_are_correctly_saved(f"{f}/run{i}")
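
Also not part of the PR: once --num_repeats writes per-run outputs under run0..run{N-1}, the metric files could be aggregated afterwards. A rough sketch, assuming METRIC_FILE_NAME (imported above from local_recorder) points to a flat JSON mapping of metric names to numbers; both that file format and the helper average_metrics_over_runs are assumptions, not part of flexeval.

import json
import os
from collections import defaultdict

from flexeval.core.result_recorder.local_recorder import METRIC_FILE_NAME


def average_metrics_over_runs(save_dir: str, num_repeats: int) -> dict[str, float]:
    """Average numeric metrics across the run0..run{num_repeats-1} subdirectories."""
    totals: dict[str, float] = defaultdict(float)
    for i in range(num_repeats):
        metric_path = os.path.join(save_dir, f"run{i}", METRIC_FILE_NAME)
        with open(metric_path) as fp:
            metrics = json.load(fp)  # assumed: flat dict of metric name -> number
        for name, value in metrics.items():
            if isinstance(value, (int, float)):
                totals[name] += value
    return {name: total / num_repeats for name, total in totals.items()}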