2 changes: 1 addition & 1 deletion flexeval/core/language_model/hf_lm.py
@@ -439,7 +439,7 @@ def _batch_compute_log_probs(
         # This is needed to correctly calculate the log probabilities of the first token.
         for i in range(batch_size):
             if prefix_list[i] == "":
-                prefix_list[i] = self.tokenizer.bos_token
+                prefix_list[i] = self.tokenizer.bos_token or self.tokenizer.eos_token
 
         prefix_encoding = tokenize_text_for_lm_prefix(
             prefix_list,
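Note on the change above: tokenizers such as Qwen's typically define no BOS token, so self.tokenizer.bos_token is None and the old code substituted None for an empty prefix. A minimal standalone sketch of the fallback pattern, assuming the Qwen checkpoint used by the new test fixture below (printed values are indicative only, not verified here):

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-0.6B-Base")
print(tokenizer.bos_token)  # None for BOS-less tokenizers
print(tokenizer.eos_token)  # a real string, e.g. "<|endoftext|>"

prefix = ""
if prefix == "":
    # Same fallback as the diff above: prefer BOS, otherwise use EOS.
    prefix = tokenizer.bos_token or tokenizer.eos_token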
6 changes: 4 additions & 2 deletions flexeval/core/language_model/vllm_model.py
@@ -282,7 +282,7 @@ def _batch_compute_log_probs(
         # This is needed to correctly calculate the log probabilities of the first token.
         for i in range(batch_size):
             if prefix_list[i] == "":
-                prefix_list[i] = self.tokenizer.bos_token
+                prefix_list[i] = self.tokenizer.bos_token or self.tokenizer.eos_token
 
         batch_prefix_ids = tokenize_text_for_lm_prefix(
             prefix_list,
@@ -321,7 +321,9 @@ def _batch_compute_log_probs(
             chunk_end = min(chunk_start + max_length, sequence_length)
             chunk_batch_input_ids = [input_ids[chunk_start:chunk_end] for input_ids in batch_input_ids]
             chunk_batch_input_ids = [
-                [self.tokenizer.bos_token_id] if len(chunk_input_ids) == 0 else chunk_input_ids
+                [self.tokenizer.bos_token_id or self.tokenizer.eos_token_id]
+                if len(chunk_input_ids) == 0
+                else chunk_input_ids
                 for chunk_input_ids in chunk_batch_input_ids
             ]
             chunk_batch_outputs: list[RequestOutput] = self.llm.generate(
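The second vllm hunk applies the same fallback at the token-id level: an empty chunk is padded with a single id before being passed to self.llm.generate, so the fallback has to stay an integer id rather than a token string. A self-contained sketch of that list comprehension, with illustrative id values not taken from any real tokenizer:

bos_token_id = None    # e.g. a tokenizer that defines no BOS token
eos_token_id = 151643  # illustrative value only

chunk_batch_input_ids = [[11, 12, 13], [], [14]]
chunk_batch_input_ids = [
    [bos_token_id or eos_token_id] if len(chunk_input_ids) == 0 else chunk_input_ids
    for chunk_input_ids in chunk_batch_input_ids
]
print(chunk_batch_input_ids)  # [[11, 12, 13], [151643], [14]]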
27 changes: 25 additions & 2 deletions tests/core/language_model/vllm/test_vllm_specific.py
@@ -7,7 +7,7 @@
 import pytest
 from transformers import AutoTokenizer
 
-from flexeval.core.language_model import VLLM, HuggingFaceLM, LanguageModel
+from flexeval.core.language_model import VLLM, HuggingFaceLM
 from tests.conftest import is_vllm_enabled
 from tests.dummy_modules.tool_parser import DummyToolParser
 
@@ -31,6 +31,23 @@ def chat_lm() -> Generator[VLLM, None, None]:
     cleanup_dist_env_and_memory()
 
 
+@pytest.fixture(scope="module")
+def chat_lm_qwen() -> Generator[VLLM, None, None]:
+    llm = VLLM(
+        model="Qwen/Qwen3-0.6B-Base",
+        model_kwargs={
+            "seed": 42,
+            "gpu_memory_utilization": 0.1,
+            "enforce_eager": True,
+            "disable_custom_all_reduce": True,
+        },
+    )
+    yield llm
+    from vllm.distributed.parallel_state import cleanup_dist_env_and_memory
+
+    cleanup_dist_env_and_memory()
+
+
 @pytest.fixture(scope="module")
 def chat_lm_for_tool_calling() -> Generator[VLLM, None, None]:
     tool_parser = DummyToolParser()
@@ -60,7 +77,13 @@ def hf_lm(model_name: str = "sbintuitions/tiny-lm-chat") -> HuggingFaceLM:
 
 
 @pytest.mark.skipif(not is_vllm_enabled(), reason="vllm library is not installed")
-def test_batch_compute_log_probs_approximates_hf_lm(chat_lm: LanguageModel, hf_lm: HuggingFaceLM) -> None:
+@pytest.mark.parametrize("chat_lm_name", ["chat_lm", "chat_lm_qwen"])
+def test_batch_compute_log_probs_approximates_hf_lm(
+    request: pytest.FixtureRequest,
+    chat_lm_name: str,
+    hf_lm: HuggingFaceLM,
+) -> None:
+    chat_lm = request.getfixturevalue(chat_lm_name)
     prefix_list = ["それは正しい日本語ですか?"]
     text_list = ["これは正しい日本語です。"]
 
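The parametrized test above selects a model fixture by name and resolves it lazily, so each heavyweight module-scoped VLLM fixture is only instantiated when its parameter actually runs. A minimal, generic sketch of that pytest pattern (fixture names here are made up for illustration):

import pytest


@pytest.fixture()
def fast_backend() -> str:
    return "fast"


@pytest.fixture()
def slow_backend() -> str:
    return "slow"


@pytest.mark.parametrize("backend_fixture", ["fast_backend", "slow_backend"])
def test_backend(request: pytest.FixtureRequest, backend_fixture: str) -> None:
    # getfixturevalue instantiates only the fixture named by the active parameter.
    backend = request.getfixturevalue(backend_fixture)
    assert backend in {"fast", "slow"}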