diff --git a/flexeval/core/chat_dataset/__init__.py b/flexeval/core/chat_dataset/__init__.py index a1904f51..be31bd26 100644 --- a/flexeval/core/chat_dataset/__init__.py +++ b/flexeval/core/chat_dataset/__init__.py @@ -2,4 +2,4 @@ from .chatbot_bench import ChatbotBench from .openai_messages import OpenAIMessagesDataset from .sacrebleu_dataset import SacreBleuChatDataset -from .template_based import HFChatDataset, JsonlChatDataset, TemplateChatDataset +from .template_based import HFChatDataset, JsonlChatDataset, TemplateChatDataset, load_jinja2_template diff --git a/flexeval/core/chat_dataset/template_based.py b/flexeval/core/chat_dataset/template_based.py index 768d8b42..fa5d8189 100644 --- a/flexeval/core/chat_dataset/template_based.py +++ b/flexeval/core/chat_dataset/template_based.py @@ -2,6 +2,8 @@ import json from ast import literal_eval +from os import PathLike +from pathlib import Path from typing import Any import datasets @@ -13,6 +15,13 @@ from .base import ChatDataset, ChatInstance +def load_jinja2_template(template: str | PathLike[str]) -> Template: + path = Path(template) + if path.exists(): + return JINJA2_ENV.from_string(path.read_text(encoding="utf-8")) + return JINJA2_ENV.from_string(template) + + class TemplateChatDataset(ChatDataset): """ This class only supports single-turn chat. @@ -22,7 +31,7 @@ class TemplateChatDataset(ChatDataset): The "tools" key for each item can contain the list of function definitions. They should be in JSON Schema format as in the OpenAI Chat Completion API. https://platform.openai.com/docs/guides/function-calling?api-mode=chat#defining-functions - input_template: A Jinja2 template for the user input. + input_template: A Jinja2 template for the user input. Can be template string or path to jinja2 template file. reference_template: Specify the Jinja2 template to render the reference string if the dataset has a single reference. reference_list_template: Specify the Jinja2 template to render a list of reference strings @@ -30,7 +39,7 @@ class TemplateChatDataset(ChatDataset): extra_info_templates: A dictionary of Jinja2 templates for extra information. system_message_template: A Jinja2 template for the system message. tools: Default tools to use for all chat instances. Individual items can override this - by including their own "tools" key. Typically in JSON Schema format as in the + by including their own "tools" key. Typically, in JSON Schema format as in the OpenAI Chat Completion API for function calling. data_range: The range of data to use. keep_conditions: A dictionary to indicate the condition to filter certain items. @@ -42,11 +51,11 @@ class TemplateChatDataset(ChatDataset): def __init__( self, items: list[dict[str, Any]], - input_template: str, - reference_template: str | None = None, - reference_list_template: str | None = None, - extra_info_templates: dict[str, str] | None = None, - system_message_template: str | None = None, + input_template: str | PathLike[str], + reference_template: str | PathLike[str] | None = None, + reference_list_template: str | PathLike[str] | None = None, + extra_info_templates: dict[str, str | PathLike[str]] | None = None, + system_message_template: str | PathLike[str] | None = None, tools: list[dict[str, Any]] | None = None, data_range: tuple[int, int] | None = None, keep_conditions: dict[str, str] | None = None, @@ -72,19 +81,19 @@ def __init__( self.items = items self.tools = tools - self.input_template = JINJA2_ENV.from_string(input_template) - self.reference_template = JINJA2_ENV.from_string(reference_template) if reference_template else None + self.input_template = load_jinja2_template(input_template) + self.reference_template = load_jinja2_template(reference_template) if reference_template else None self.reference_list_template = ( - JINJA2_ENV.from_string(reference_list_template) if reference_list_template else None + load_jinja2_template(reference_list_template) if reference_list_template else None ) extra_info_templates = extra_info_templates or {} self._extra_info_templates: dict[str, Template] = { - key: JINJA2_ENV.from_string(template) for key, template in extra_info_templates.items() + key: load_jinja2_template(template) for key, template in extra_info_templates.items() } self._system_message_template: Template | None = ( - JINJA2_ENV.from_string(system_message_template) if system_message_template else None + load_jinja2_template(system_message_template) if system_message_template else None ) def __len__(self) -> int: @@ -143,13 +152,13 @@ def __init__( self, path: str, split: str, - input_template: str, + input_template: str | PathLike[str], subset: str | None = None, dataset_kwargs: dict[str, Any] | None = None, - reference_template: str | None = None, - reference_list_template: str | None = None, - extra_info_templates: dict[str, str] | None = None, - system_message_template: str | None = None, + reference_template: str | PathLike[str] | None = None, + reference_list_template: str | PathLike[str] | None = None, + extra_info_templates: dict[str, str | PathLike[str]] | None = None, + system_message_template: str | PathLike[str] | None = None, tools: list[dict[str, Any]] | None = None, data_range: tuple[int, int] | None = None, keep_conditions: dict[str, str] | None = None, @@ -184,11 +193,11 @@ class JsonlChatDataset(TemplateChatDataset): def __init__( self, path: str, - input_template: str, - reference_template: str | None = None, - reference_list_template: str | None = None, - extra_info_templates: dict[str, str] | None = None, - system_message_template: str | None = None, + input_template: str | PathLike[str], + reference_template: str | PathLike[str] | None = None, + reference_list_template: str | PathLike[str] | None = None, + extra_info_templates: dict[str, str | PathLike[str]] | None = None, + system_message_template: str | PathLike[str] | None = None, tools: list[dict[str, Any]] | None = None, data_range: tuple[int, int] | None = None, keep_conditions: dict[str, str] | None = None, diff --git a/tests/core/chat_dataset/test_template_based.py b/tests/core/chat_dataset/test_template_based.py index 4f823c68..ba02f266 100644 --- a/tests/core/chat_dataset/test_template_based.py +++ b/tests/core/chat_dataset/test_template_based.py @@ -1,10 +1,12 @@ from __future__ import annotations +from pathlib import Path from typing import Any import pytest +from jinja2 import Template -from flexeval.core.chat_dataset import HFChatDataset, JsonlChatDataset, TemplateChatDataset +from flexeval.core.chat_dataset import HFChatDataset, JsonlChatDataset, TemplateChatDataset, load_jinja2_template TOOL_DEFINITION = { "type": "function", @@ -199,3 +201,23 @@ def test_remove_conditions( assert 0 < len(filtered_dataset) < len(original_dataset) for item in filtered_dataset: assert len(item.references) > 1 + + +@pytest.fixture() +def dummy_template_file(tmp_path: Path) -> Path: + template_content = "Hello {{ name }}!" + template_file = tmp_path / "dummy.j2" + template_file.write_text(template_content, encoding="utf-8") + return template_file + + +def test_load_jinja2_template(dummy_template_file: Path) -> None: + template_from_path = load_jinja2_template(dummy_template_file) + embed_result = template_from_path.render(name="flexeval") + assert isinstance(template_from_path, Template) + assert embed_result == "Hello flexeval!" + + template_from_string = load_jinja2_template("Hello {{ name }}!") + embed_result = template_from_string.render(name="flexeval") + assert isinstance(template_from_string, Template) + assert embed_result == "Hello flexeval!"