diff --git a/flexeval/core/language_model/hf_lm.py b/flexeval/core/language_model/hf_lm.py
index 565edaf3..f1a31c48 100644
--- a/flexeval/core/language_model/hf_lm.py
+++ b/flexeval/core/language_model/hf_lm.py
@@ -439,7 +439,7 @@ def _batch_compute_log_probs(
         # This is needed to correctly calculate the log probabilities of the first token.
         for i in range(batch_size):
             if prefix_list[i] == "":
-                prefix_list[i] = self.tokenizer.bos_token
+                prefix_list[i] = self.tokenizer.bos_token or self.tokenizer.eos_token

         prefix_encoding = tokenize_text_for_lm_prefix(
             prefix_list,
diff --git a/flexeval/core/language_model/vllm_model.py b/flexeval/core/language_model/vllm_model.py
index 1c35300f..a16d307a 100644
--- a/flexeval/core/language_model/vllm_model.py
+++ b/flexeval/core/language_model/vllm_model.py
@@ -282,7 +282,7 @@ def _batch_compute_log_probs(
         # This is needed to correctly calculate the log probabilities of the first token.
         for i in range(batch_size):
             if prefix_list[i] == "":
-                prefix_list[i] = self.tokenizer.bos_token
+                prefix_list[i] = self.tokenizer.bos_token or self.tokenizer.eos_token

         batch_prefix_ids = tokenize_text_for_lm_prefix(
             prefix_list,
@@ -321,7 +321,9 @@ def _batch_compute_log_probs(
             chunk_end = min(chunk_start + max_length, sequence_length)
             chunk_batch_input_ids = [input_ids[chunk_start:chunk_end] for input_ids in batch_input_ids]
             chunk_batch_input_ids = [
-                [self.tokenizer.bos_token_id] if len(chunk_input_ids) == 0 else chunk_input_ids
+                [self.tokenizer.bos_token_id or self.tokenizer.eos_token_id]
+                if len(chunk_input_ids) == 0
+                else chunk_input_ids
                 for chunk_input_ids in chunk_batch_input_ids
             ]
             chunk_batch_outputs: list[RequestOutput] = self.llm.generate(
diff --git a/tests/core/language_model/vllm/test_vllm_specific.py b/tests/core/language_model/vllm/test_vllm_specific.py
index d629be69..b1edca6e 100644
--- a/tests/core/language_model/vllm/test_vllm_specific.py
+++ b/tests/core/language_model/vllm/test_vllm_specific.py
@@ -7,7 +7,7 @@
 import pytest
 from transformers import AutoTokenizer

-from flexeval.core.language_model import VLLM, HuggingFaceLM, LanguageModel
+from flexeval.core.language_model import VLLM, HuggingFaceLM
 from tests.conftest import is_vllm_enabled
 from tests.dummy_modules.tool_parser import DummyToolParser

@@ -31,6 +31,23 @@ def chat_lm() -> Generator[VLLM, None, None]:
     cleanup_dist_env_and_memory()


+@pytest.fixture(scope="module")
+def chat_lm_qwen() -> Generator[VLLM, None, None]:
+    llm = VLLM(
+        model="Qwen/Qwen3-0.6B-Base",
+        model_kwargs={
+            "seed": 42,
+            "gpu_memory_utilization": 0.1,
+            "enforce_eager": True,
+            "disable_custom_all_reduce": True,
+        },
+    )
+    yield llm
+    from vllm.distributed.parallel_state import cleanup_dist_env_and_memory
+
+    cleanup_dist_env_and_memory()
+
+
 @pytest.fixture(scope="module")
 def chat_lm_for_tool_calling() -> Generator[VLLM, None, None]:
     tool_parser = DummyToolParser()
@@ -60,7 +77,13 @@ def hf_lm(model_name: str = "sbintuitions/tiny-lm-chat") -> HuggingFaceLM:


 @pytest.mark.skipif(not is_vllm_enabled(), reason="vllm library is not installed")
-def test_batch_compute_log_probs_approximates_hf_lm(chat_lm: LanguageModel, hf_lm: HuggingFaceLM) -> None:
+@pytest.mark.parametrize("chat_lm_name", ["chat_lm", "chat_lm_qwen"])
+def test_batch_compute_log_probs_approximates_hf_lm(
+    request: pytest.FixtureRequest,
+    chat_lm_name: str,
+    hf_lm: HuggingFaceLM,
+) -> None:
+    chat_lm = request.getfixturevalue(chat_lm_name)
     prefix_list = ["それは正しい日本語ですか?"]
     text_list = ["これは正しい日本語です。"]
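
For context, the motivation for the bos_token-or-eos_token fallback: Qwen-family tokenizers, such as the Qwen/Qwen3-0.6B-Base model added in the new test fixture, define no BOS token, so the old code assigned None as the prefix. The sketch below illustrates the new behaviour under that assumption; it is not flexeval code, and the printed value is illustrative.

# Minimal sketch of the fallback introduced by this patch, assuming a tokenizer
# whose bos_token is None (as with Qwen tokenizers). Not flexeval code.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-0.6B-Base")

prefix_list = ["", "それは正しい日本語ですか?"]
for i in range(len(prefix_list)):
    if prefix_list[i] == "":
        # Previously: prefix_list[i] = tokenizer.bos_token, which is None when no
        # BOS token is defined. The EOS token now stands in as the sequence-start
        # marker for empty prefixes.
        prefix_list[i] = tokenizer.bos_token or tokenizer.eos_token

print(prefix_list)  # e.g. ['<|endoftext|>', 'それは正しい日本語ですか?'] for this tokenizer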