Merged
6 changes: 6 additions & 0 deletions flexeval/core/eval_setups.py
@@ -50,6 +50,7 @@ class ChatResponse(EvalSetup):
metrics: list[Metric] | Metric | None = None
batch_size: int = 4
max_instances: int | None = None
random_seed: int = 42

def evaluate_lm(
self,
@@ -60,6 +61,8 @@ def evaluate_lm(
metrics = [metrics]
metrics += [FinishReasonCount(), OutputLengthStats()]

language_model.set_random_seed(self.random_seed)

return evaluate_chat_response(
language_model=language_model,
gen_kwargs=self.gen_kwargs,
@@ -85,6 +88,7 @@ class Generation(EvalSetup):
metrics: list[Metric] | Metric | None = None
batch_size: int = 4
max_instances: int | None = None
random_seed: int = 42

def __post_init__(self) -> None:
if isinstance(self.prompt_template, str):
@@ -99,6 +103,8 @@ def evaluate_lm(
metrics = [metrics]
metrics += [FinishReasonCount(), OutputLengthStats()]

language_model.set_random_seed(self.random_seed)

return evaluate_generation(
language_model=language_model,
gen_kwargs=self.gen_kwargs,
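In effect, each eval setup now carries the seed and re-seeds the model just before generation. A toy sketch of that pattern (class and variable names here are illustrative, not flexeval's API):

from dataclasses import dataclass


class ToyModel:
    def set_random_seed(self, seed: int) -> None:
        # A real backend would seed its RNG or record the seed in its generation kwargs.
        print(f"seeded with {seed}")


@dataclass
class ToySetup:
    random_seed: int = 42

    def evaluate_lm(self, model: ToyModel) -> None:
        # Mirrors the new call added to ChatResponse.evaluate_lm and Generation.evaluate_lm.
        model.set_random_seed(self.random_seed)
        # ... generation and metric computation would follow here ...


ToySetup(random_seed=7).evaluate_lm(ToyModel())  # prints "seeded with 7"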
7 changes: 7 additions & 0 deletions flexeval/core/language_model/base.py
@@ -66,6 +66,13 @@ def __init__(
self.string_processors = string_processors
self.tools = tools

def set_random_seed(self, seed: int) -> None:
"""
A method to set random seed for deterministic behavior.
"""
msg = "set_random_seed is not implemented."
raise NotImplementedError(msg)

def _batch_complete_text(
self,
text_list: list[str],
3 changes: 3 additions & 0 deletions flexeval/core/language_model/hf_lm.py
@@ -229,6 +229,9 @@ def __init__(
logger.info(f"random seed: {random_seed}")
transformers.set_seed(random_seed)

def set_random_seed(self, seed: int) -> None:
Collaborator:
Could you also add a test case for hf_lm.py?

Collaborator (Author):
Sorry, it was my oversight. Added tests in e1cb5aa.
transformers.set_seed(seed)

@staticmethod
def load_model(method: Callable) -> Callable:
"""Decorator to load the model lazily."""
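For context on why the HuggingFaceLM override is a one-liner: transformers.set_seed re-seeds Python's random, NumPy, and PyTorch in a single call, so re-seeding before each sampling call makes outputs repeatable. A minimal standalone sketch (requires torch; independent of flexeval):

import torch
import transformers

transformers.set_seed(42)
first = torch.rand(3)
transformers.set_seed(42)
second = torch.rand(3)
assert torch.equal(first, second)  # identical draws after re-seeding with the same value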
3 changes: 3 additions & 0 deletions flexeval/core/language_model/litellm_api.py
@@ -68,6 +68,9 @@ def __init__(
)
self.ignore_seed = ignore_seed

def set_random_seed(self, seed: int) -> None:
self.default_gen_kwargs["seed"] = seed

def _batch_complete_text(
self,
text_list: list[str],
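LiteLLM accepts the OpenAI-compatible seed parameter, so a seed stored in default_gen_kwargs simply rides along with each request. A hedged sketch of what that amounts to at call time (the model name and the kwargs merging are illustrative, not flexeval's internal code):

import litellm

default_gen_kwargs = {"seed": 42}  # populated by set_random_seed
response = litellm.completion(
    model="gpt-4o-mini",
    messages=[{"role": "user", "content": "Say hello."}],
    **default_gen_kwargs,
)
print(response.choices[0].message.content)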
6 changes: 6 additions & 0 deletions flexeval/core/language_model/openai_api.py
@@ -119,6 +119,9 @@ def __init__(
self.first_wait_time = first_wait_time
self.max_wait_time = max_wait_time

def set_random_seed(self, seed: int) -> None:
self.default_gen_kwargs["seed"] = seed

def _parallel_run_chatgpt(
self,
messages_list: list[list[dict[str, Any]]],
@@ -414,6 +417,9 @@ def _parallel_run_chatgpt(
]
return [future.result() for future in futures]

def set_random_seed(self, seed: int) -> None:
self.default_gen_kwargs["seed"] = seed

def _batch_complete_text(
self,
text_list: list[str],
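Same idea for the OpenAI backends: the Chat Completions API takes a best-effort seed parameter, so keeping it in default_gen_kwargs means every request carries it. A rough sketch outside flexeval (the merge into the request is illustrative; the model name is an assumption; the client call itself is the standard openai-python API):

from openai import OpenAI

client = OpenAI()  # reads OPENAI_API_KEY from the environment
default_gen_kwargs = {"seed": 42}  # populated by set_random_seed
response = client.chat.completions.create(
    model="gpt-4o-mini",
    messages=[{"role": "user", "content": "Say hello."}],
    **default_gen_kwargs,
)
print(response.choices[0].message.content)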
3 changes: 3 additions & 0 deletions flexeval/core/language_model/vllm_model.py
@@ -156,6 +156,9 @@ def wrapper(self: VLLM, *args: tuple, **kwargs: dict) -> Callable:

return wrapper

def set_random_seed(self, seed: int) -> None:
self.default_gen_kwargs["seed"] = seed

@load_model
def _batch_complete_text(
self,
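In vLLM, a seed kwarg ends up in SamplingParams, which makes per-request sampling reproducible. A standalone sketch under that assumption (the model choice is arbitrary; this is not flexeval's wrapper code):

from vllm import LLM, SamplingParams

llm = LLM(model="facebook/opt-125m")
params = SamplingParams(temperature=0.7, max_tokens=16, seed=42)
first = llm.generate(["Hello"], params)[0].outputs[0].text
second = llm.generate(["Hello"], params)[0].outputs[0].text
print(first == second)  # expected True: same prompt and seed give the same sample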
23 changes: 23 additions & 0 deletions tests/core/language_model/test_hf_lm.py
@@ -232,6 +232,24 @@ def test_if_random_seed_fixes_the_lm_outputs(lm_init_func: Callable[..., Hugging
assert len(completions) > 1


def test_if_set_random_seed_fixes_the_lm_outputs(lm: HuggingFaceLM) -> None:
# first check if the outputs are different without fixing the seed
completions = set()
for i in range(3):
lm.set_random_seed(i)
completion = lm.complete_text(["<s>"], do_sample=True)[0]
completions.add(completion.text)
assert len(completions) > 1

# then check if the outputs are the same with fixing the seed
completions = set()
for _ in range(3):
lm.set_random_seed(42)
completion = lm.complete_text(["<s>"], do_sample=True)[0]
completions.add(completion.text)
assert len(completions) == 1


def test_if_custom_chat_template_is_given(lm_init_func: Callable[..., HuggingFaceLM]) -> None:
# To verify that the template specified in `custom_chat_template` is passed to `tokenizer.apply_chat_template()`,
# prepare a template where the model is expected to output "0 0..." for any input.
@@ -494,3 +512,8 @@ def mock_apply_chat_template(messages: list[list[dict[str, Any]]], **kwargs) ->
finally:
# Restore original method
chat_lm_with_system_message.tokenizer.apply_chat_template = original_apply_chat_template


def test_set_random_seed(lm: HuggingFaceLM) -> None:
# check that method is implemented
assert lm.set_random_seed(42) is None
6 changes: 6 additions & 0 deletions tests/core/language_model/test_litellm_api.py
@@ -86,3 +86,9 @@ def test_if_not_ignore_seed() -> None:
with patch.object(OpenAIChatAPI, "_batch_complete_text", return_value=[LMOutput("ChatGPT.")]) as mock_method:
chat_lm.complete_text(text, stop_sequences=None, max_new_tokens=None, temperature=0.7, seed=42)
mock_method.assert_called_once_with([text], None, None, temperature=0.7, seed=42)


@pytest.mark.skipif(not is_openai_enabled(), reason="OpenAI is not installed")
def test_set_random_seed(chat_lm: OpenAIChatAPI) -> None:
chat_lm.set_random_seed(42)
assert chat_lm.default_gen_kwargs["seed"] == 42
6 changes: 6 additions & 0 deletions tests/core/language_model/test_openai_api.py
@@ -161,3 +161,9 @@ def test_model_limit_new_tokens_complete_text(chat_lm: OpenAIChatAPI, caplog: py
assert len(caplog.records) >= 1
assert any(record.msg.startswith("The specified `max_new_tokens` (128) exceeds") for record in caplog.records)
caplog.clear()


@pytest.mark.skipif(not is_openai_enabled(), reason="OpenAI is not installed")
def test_set_random_seed(chat_lm: OpenAIChatAPI) -> None:
chat_lm.set_random_seed(42)
assert chat_lm.default_gen_kwargs["seed"] == 42
6 changes: 6 additions & 0 deletions tests/core/language_model/vllm/test_vllm_serve_lm.py
@@ -115,3 +115,9 @@ def test_generate_chat_response_if_number_of_tools_and_messages_not_equal(
self, chat_lm_for_tool_calling: LanguageModel
) -> None:
pass


@pytest.mark.skipif(not is_vllm_enabled(), reason="vllm library is not installed")
def test_set_random_seed(chat_lm: VLLMServeLM) -> None:
chat_lm.set_random_seed(42)
assert chat_lm.default_gen_kwargs["seed"] == 42
6 changes: 6 additions & 0 deletions tests/core/language_model/vllm/test_vllm_specific.py
@@ -232,3 +232,9 @@ def mock_apply_chat_template(messages: list[list[dict[str, Any]]], **kwargs) ->
finally:
# Restore original method
chat_lm_with_system_message.tokenizer.apply_chat_template = original_apply_chat_template


@pytest.mark.skipif(not is_vllm_enabled(), reason="vllm library is not installed")
def test_set_random_seed(chat_lm: VLLM) -> None:
chat_lm.set_random_seed(42)
assert chat_lm.default_gen_kwargs["seed"] == 42
3 changes: 3 additions & 0 deletions tests/dummy_modules/lm.py
@@ -23,6 +23,9 @@ def _batch_compute_log_probs(
) -> list[float]:
return [-1.0] * len(text_list)

def set_random_seed(self, seed: int) -> None:
pass

def _batch_generate_chat_response(
self,
chat_messages_list: list[list[dict[str, str]]],
3 changes: 3 additions & 0 deletions tests/dummy_modules/reward_lm.py
@@ -11,6 +11,9 @@ def __init__(self, response: str = "[[A]]") -> None:
super().__init__()
self.response = response

def set_random_seed(self, seed: int) -> None:
pass

def _batch_complete_text(self, text_list: list[str], **kwargs) -> list[LMOutput]:
return [LMOutput(text=self.response, finish_reason="length") for _ in text_list]
