diff --git a/flexeval/core/eval_setups.py b/flexeval/core/eval_setups.py
index e922b5e7..b4822629 100644
--- a/flexeval/core/eval_setups.py
+++ b/flexeval/core/eval_setups.py
@@ -50,6 +50,7 @@ class ChatResponse(EvalSetup):
     metrics: list[Metric] | Metric | None = None
     batch_size: int = 4
     max_instances: int | None = None
+    random_seed: int = 42
 
     def evaluate_lm(
         self,
@@ -60,6 +61,8 @@ def evaluate_lm(
             metrics = [metrics]
         metrics += [FinishReasonCount(), OutputLengthStats()]
 
+        language_model.set_random_seed(self.random_seed)
+
         return evaluate_chat_response(
             language_model=language_model,
             gen_kwargs=self.gen_kwargs,
@@ -85,6 +88,7 @@ class Generation(EvalSetup):
     metrics: list[Metric] | Metric | None = None
     batch_size: int = 4
     max_instances: int | None = None
+    random_seed: int = 42
 
     def __post_init__(self) -> None:
         if isinstance(self.prompt_template, str):
@@ -99,6 +103,8 @@ def evaluate_lm(
             metrics = [metrics]
         metrics += [FinishReasonCount(), OutputLengthStats()]
 
+        language_model.set_random_seed(self.random_seed)
+
         return evaluate_generation(
             language_model=language_model,
             gen_kwargs=self.gen_kwargs,
diff --git a/flexeval/core/language_model/base.py b/flexeval/core/language_model/base.py
index 58fa6a20..0993b418 100644
--- a/flexeval/core/language_model/base.py
+++ b/flexeval/core/language_model/base.py
@@ -66,6 +66,13 @@ def __init__(
         self.string_processors = string_processors
         self.tools = tools
 
+    def set_random_seed(self, seed: int) -> None:
+        """
+        Set the random seed for deterministic behavior.
+        """
+        msg = "set_random_seed is not implemented."
+        raise NotImplementedError(msg)
+
     def _batch_complete_text(
         self,
         text_list: list[str],
diff --git a/flexeval/core/language_model/hf_lm.py b/flexeval/core/language_model/hf_lm.py
index f1a31c48..ae8e09fb 100644
--- a/flexeval/core/language_model/hf_lm.py
+++ b/flexeval/core/language_model/hf_lm.py
@@ -229,6 +229,9 @@ def __init__(
         logger.info(f"random seed: {random_seed}")
         transformers.set_seed(random_seed)
 
+    def set_random_seed(self, seed: int) -> None:
+        transformers.set_seed(seed)
+
     @staticmethod
     def load_model(method: Callable) -> Callable:
         """Decorator to load the model lazily."""
diff --git a/flexeval/core/language_model/litellm_api.py b/flexeval/core/language_model/litellm_api.py
index 887619ed..15301e63 100644
--- a/flexeval/core/language_model/litellm_api.py
+++ b/flexeval/core/language_model/litellm_api.py
@@ -68,6 +68,9 @@ def __init__(
         )
         self.ignore_seed = ignore_seed
 
+    def set_random_seed(self, seed: int) -> None:
+        self.default_gen_kwargs["seed"] = seed
+
     def _batch_complete_text(
         self,
         text_list: list[str],
diff --git a/flexeval/core/language_model/openai_api.py b/flexeval/core/language_model/openai_api.py
index 477a55d7..e957bf56 100644
--- a/flexeval/core/language_model/openai_api.py
+++ b/flexeval/core/language_model/openai_api.py
@@ -119,6 +119,9 @@ def __init__(
         self.first_wait_time = first_wait_time
         self.max_wait_time = max_wait_time
 
+    def set_random_seed(self, seed: int) -> None:
+        self.default_gen_kwargs["seed"] = seed
+
     def _parallel_run_chatgpt(
         self,
         messages_list: list[list[dict[str, Any]]],
@@ -414,6 +417,9 @@ def _parallel_run_chatgpt(
         ]
         return [future.result() for future in futures]
 
+    def set_random_seed(self, seed: int) -> None:
+        self.default_gen_kwargs["seed"] = seed
+
     def _batch_complete_text(
         self,
         text_list: list[str],
diff --git a/flexeval/core/language_model/vllm_model.py b/flexeval/core/language_model/vllm_model.py
index a16d307a..aa9cb245 100644
--- a/flexeval/core/language_model/vllm_model.py
+++ b/flexeval/core/language_model/vllm_model.py
@@ -156,6 +156,9 @@ def wrapper(self: VLLM, *args: tuple, **kwargs: dict) -> Callable:
 
         return wrapper
 
+    def set_random_seed(self, seed: int) -> None:
+        self.default_gen_kwargs["seed"] = seed
+
     @load_model
     def _batch_complete_text(
         self,
diff --git a/tests/core/language_model/test_hf_lm.py b/tests/core/language_model/test_hf_lm.py
index 388041fd..591285fc 100644
--- a/tests/core/language_model/test_hf_lm.py
+++ b/tests/core/language_model/test_hf_lm.py
@@ -232,6 +232,24 @@ def test_if_random_seed_fixes_the_lm_outputs(lm_init_func: Callable[..., HuggingFaceLM]) -> None:
     assert len(completions) > 1
 
 
+def test_if_set_random_seed_fixes_the_lm_outputs(lm: HuggingFaceLM) -> None:
+    # first check that the outputs differ when a different seed is set each time
+    completions = set()
+    for i in range(3):
+        lm.set_random_seed(i)
+        completion = lm.complete_text([""], do_sample=True)[0]
+        completions.add(completion.text)
+    assert len(completions) > 1
+
+    # then check that the outputs are identical when the same seed is set each time
+    completions = set()
+    for _ in range(3):
+        lm.set_random_seed(42)
+        completion = lm.complete_text([""], do_sample=True)[0]
+        completions.add(completion.text)
+    assert len(completions) == 1
+
+
 def test_if_custom_chat_template_is_given(lm_init_func: Callable[..., HuggingFaceLM]) -> None:
     # To verify that the template specified in `custom_chat_template` is passed to `tokenizer.apply_chat_template()`,
     # prepare a template where the model is expected to output "0 0..." for any input.
@@ -494,3 +512,8 @@ def mock_apply_chat_template(messages: list[list[dict[str, Any]]], **kwargs) ->
     finally:
         # Restore original method
         chat_lm_with_system_message.tokenizer.apply_chat_template = original_apply_chat_template
+
+
+def test_set_random_seed(lm: HuggingFaceLM) -> None:
+    # check that the method is implemented
+    assert lm.set_random_seed(42) is None
diff --git a/tests/core/language_model/test_litellm_api.py b/tests/core/language_model/test_litellm_api.py
index 06894d5c..49dde7dc 100644
--- a/tests/core/language_model/test_litellm_api.py
+++ b/tests/core/language_model/test_litellm_api.py
@@ -86,3 +86,9 @@ def test_if_not_ignore_seed() -> None:
     with patch.object(OpenAIChatAPI, "_batch_complete_text", return_value=[LMOutput("ChatGPT.")]) as mock_method:
         chat_lm.complete_text(text, stop_sequences=None, max_new_tokens=None, temperature=0.7, seed=42)
         mock_method.assert_called_once_with([text], None, None, temperature=0.7, seed=42)
+
+
+@pytest.mark.skipif(not is_openai_enabled(), reason="OpenAI is not installed")
+def test_set_random_seed(chat_lm: OpenAIChatAPI) -> None:
+    chat_lm.set_random_seed(42)
+    assert chat_lm.default_gen_kwargs["seed"] == 42
diff --git a/tests/core/language_model/test_openai_api.py b/tests/core/language_model/test_openai_api.py
index 3f3346da..07e0d7cd 100644
--- a/tests/core/language_model/test_openai_api.py
+++ b/tests/core/language_model/test_openai_api.py
@@ -161,3 +161,9 @@ def test_model_limit_new_tokens_complete_text(chat_lm: OpenAIChatAPI, caplog: py
     assert len(caplog.records) >= 1
     assert any(record.msg.startswith("The specified `max_new_tokens` (128) exceeds") for record in caplog.records)
     caplog.clear()
+
+
+@pytest.mark.skipif(not is_openai_enabled(), reason="OpenAI is not installed")
+def test_set_random_seed(chat_lm: OpenAIChatAPI) -> None:
+    chat_lm.set_random_seed(42)
+    assert chat_lm.default_gen_kwargs["seed"] == 42
diff --git a/tests/core/language_model/vllm/test_vllm_serve_lm.py b/tests/core/language_model/vllm/test_vllm_serve_lm.py
index 72a3700b..b68a3952 100644
--- a/tests/core/language_model/vllm/test_vllm_serve_lm.py
+++ b/tests/core/language_model/vllm/test_vllm_serve_lm.py
@@ -115,3 +115,9 @@ def test_generate_chat_response_if_number_of_tools_and_messages_not_equal(
         self, chat_lm_for_tool_calling: LanguageModel
     ) -> None:
         pass
+
+
+@pytest.mark.skipif(not is_vllm_enabled(), reason="vllm library is not installed")
+def test_set_random_seed(chat_lm: VLLMServeLM) -> None:
+    chat_lm.set_random_seed(42)
+    assert chat_lm.default_gen_kwargs["seed"] == 42
diff --git a/tests/core/language_model/vllm/test_vllm_specific.py b/tests/core/language_model/vllm/test_vllm_specific.py
index b1edca6e..f4ee6a95 100644
--- a/tests/core/language_model/vllm/test_vllm_specific.py
+++ b/tests/core/language_model/vllm/test_vllm_specific.py
@@ -232,3 +232,9 @@ def mock_apply_chat_template(messages: list[list[dict[str, Any]]], **kwargs) ->
     finally:
         # Restore original method
         chat_lm_with_system_message.tokenizer.apply_chat_template = original_apply_chat_template
+
+
+@pytest.mark.skipif(not is_vllm_enabled(), reason="vllm library is not installed")
+def test_set_random_seed(chat_lm: VLLM) -> None:
+    chat_lm.set_random_seed(42)
+    assert chat_lm.default_gen_kwargs["seed"] == 42
diff --git a/tests/dummy_modules/lm.py b/tests/dummy_modules/lm.py
index c189ca0d..2790b6d1 100644
--- a/tests/dummy_modules/lm.py
+++ b/tests/dummy_modules/lm.py
@@ -23,6 +23,9 @@ def _batch_compute_log_probs(
     ) -> list[float]:
         return [-1.0] * len(text_list)
 
+    def set_random_seed(self, seed: int) -> None:
+        pass
+
     def _batch_generate_chat_response(
         self,
         chat_messages_list: list[list[dict[str, str]]],
diff --git a/tests/dummy_modules/reward_lm.py b/tests/dummy_modules/reward_lm.py
index d6257d5a..b1cd76c8 100644
--- a/tests/dummy_modules/reward_lm.py
+++ b/tests/dummy_modules/reward_lm.py
@@ -11,6 +11,9 @@ def __init__(self, response: str = "[[A]]") -> None:
         super().__init__()
         self.response = response
 
+    def set_random_seed(self, seed: int) -> None:
+        pass
+
     def _batch_complete_text(self, text_list: list[str], **kwargs) -> list[LMOutput]:
         return [LMOutput(text=self.response, finish_reason="length") for _ in text_list]
 
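
For reference, the standalone sketch below (not part of the patch) illustrates the flow the diff sets up: the eval setups carry a random_seed field and call language_model.set_random_seed(...) once before evaluation, and API-backed models simply record the seed in default_gen_kwargs. ToyApiModel and run_eval are hypothetical stand-ins for illustration only, not flexeval classes.

# Minimal sketch, assuming only the names this patch introduces
# (set_random_seed, default_gen_kwargs, random_seed); ToyApiModel and
# run_eval are hypothetical stand-ins, not part of flexeval.


class ToyApiModel:
    """Stand-in for an API-backed language model (OpenAI/LiteLLM/vLLM style)."""

    def __init__(self) -> None:
        self.default_gen_kwargs: dict[str, int] = {}

    def set_random_seed(self, seed: int) -> None:
        # Mirrors the API-backed implementations in this patch: the seed is
        # stored so that later generation calls pass it along to the backend.
        self.default_gen_kwargs["seed"] = seed


def run_eval(model: ToyApiModel, random_seed: int = 42) -> None:
    # Mirrors ChatResponse.evaluate_lm / Generation.evaluate_lm: the setup's
    # random_seed is pushed into the model before any generation happens.
    model.set_random_seed(random_seed)


if __name__ == "__main__":
    model = ToyApiModel()
    run_eval(model, random_seed=7)
    assert model.default_gen_kwargs["seed"] == 7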