Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion flexeval/core/chat_dataset/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,4 @@
from .chatbot_bench import ChatbotBench
from .openai_messages import OpenAIMessagesDataset
from .sacrebleu_dataset import SacreBleuChatDataset
from .template_based import HFChatDataset, JsonlChatDataset, TemplateChatDataset
from .template_based import HFChatDataset, JsonlChatDataset, TemplateChatDataset, load_jinja2_template
53 changes: 31 additions & 22 deletions flexeval/core/chat_dataset/template_based.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@

import json
from ast import literal_eval
from os import PathLike
from pathlib import Path
from typing import Any

import datasets
Expand All @@ -13,6 +15,13 @@
from .base import ChatDataset, ChatInstance


def load_jinja2_template(template: str | PathLike[str]) -> Template:
    """Build a Jinja2 template from an inline string or from a template file.

    Args:
        template: Either the template source itself, or a path to a file whose
            contents (decoded as UTF-8) are the template source.

    Returns:
        The template compiled by ``JINJA2_ENV.from_string``.

    Raises:
        FileNotFoundError: If ``template`` is a path-like object (not a ``str``)
            that does not refer to an existing regular file. Such a value
            cannot be an inline template, so failing here gives a clearer
            error than handing a non-string object to Jinja2.
    """
    try:
        # `is_file()` rather than `exists()`: a directory that happens to share
        # its name with a short inline template (e.g. ".") must not be read.
        is_template_file = Path(template).is_file()
    except (OSError, ValueError):
        # Probing an inline template as a path can fail at the OS level
        # (e.g. "File name too long" for a long multi-line template, or an
        # embedded null byte). Such a value is template text, not a path.
        is_template_file = False

    if is_template_file:
        return JINJA2_ENV.from_string(Path(template).read_text(encoding="utf-8"))

    if not isinstance(template, str):
        raise FileNotFoundError(f"Jinja2 template file not found: {template}")

    return JINJA2_ENV.from_string(template)


class TemplateChatDataset(ChatDataset):
"""
This class only supports single-turn chat.
Expand All @@ -22,15 +31,15 @@ class TemplateChatDataset(ChatDataset):
The "tools" key for each item can contain the list of function definitions.
They should be in JSON Schema format as in the OpenAI Chat Completion API.
https://platform.openai.com/docs/guides/function-calling?api-mode=chat#defining-functions
input_template: A Jinja2 template for the user input.
input_template: A Jinja2 template for the user input. Can be template string or path to jinja2 template file.
reference_template: Specify the Jinja2 template to render the reference string
if the dataset has a single reference.
reference_list_template: Specify the Jinja2 template to render a list of reference strings
if the dataset has multiple references.
extra_info_templates: A dictionary of Jinja2 templates for extra information.
system_message_template: A Jinja2 template for the system message.
tools: Default tools to use for all chat instances. Individual items can override this
by including their own "tools" key. Typically in JSON Schema format as in the
by including their own "tools" key. Typically, in JSON Schema format as in the
OpenAI Chat Completion API for function calling.
data_range: The range of data to use.
keep_conditions: A dictionary to indicate the condition to filter certain items.
Expand All @@ -42,11 +51,11 @@ class TemplateChatDataset(ChatDataset):
def __init__(
self,
items: list[dict[str, Any]],
input_template: str,
reference_template: str | None = None,
reference_list_template: str | None = None,
extra_info_templates: dict[str, str] | None = None,
system_message_template: str | None = None,
input_template: str | PathLike[str],
reference_template: str | PathLike[str] | None = None,
reference_list_template: str | PathLike[str] | None = None,
extra_info_templates: dict[str, str | PathLike[str]] | None = None,
system_message_template: str | PathLike[str] | None = None,
tools: list[dict[str, Any]] | None = None,
data_range: tuple[int, int] | None = None,
keep_conditions: dict[str, str] | None = None,
Expand All @@ -72,19 +81,19 @@ def __init__(
self.items = items
self.tools = tools

self.input_template = JINJA2_ENV.from_string(input_template)
self.reference_template = JINJA2_ENV.from_string(reference_template) if reference_template else None
self.input_template = load_jinja2_template(input_template)
self.reference_template = load_jinja2_template(reference_template) if reference_template else None
self.reference_list_template = (
JINJA2_ENV.from_string(reference_list_template) if reference_list_template else None
load_jinja2_template(reference_list_template) if reference_list_template else None
)

extra_info_templates = extra_info_templates or {}
self._extra_info_templates: dict[str, Template] = {
key: JINJA2_ENV.from_string(template) for key, template in extra_info_templates.items()
key: load_jinja2_template(template) for key, template in extra_info_templates.items()
}

self._system_message_template: Template | None = (
JINJA2_ENV.from_string(system_message_template) if system_message_template else None
load_jinja2_template(system_message_template) if system_message_template else None
)

def __len__(self) -> int:
Expand Down Expand Up @@ -143,13 +152,13 @@ def __init__(
self,
path: str,
split: str,
input_template: str,
input_template: str | PathLike[str],
subset: str | None = None,
dataset_kwargs: dict[str, Any] | None = None,
reference_template: str | None = None,
reference_list_template: str | None = None,
extra_info_templates: dict[str, str] | None = None,
system_message_template: str | None = None,
reference_template: str | PathLike[str] | None = None,
reference_list_template: str | PathLike[str] | None = None,
extra_info_templates: dict[str, str | PathLike[str]] | None = None,
system_message_template: str | PathLike[str] | None = None,
tools: list[dict[str, Any]] | None = None,
data_range: tuple[int, int] | None = None,
keep_conditions: dict[str, str] | None = None,
Expand Down Expand Up @@ -184,11 +193,11 @@ class JsonlChatDataset(TemplateChatDataset):
def __init__(
self,
path: str,
input_template: str,
reference_template: str | None = None,
reference_list_template: str | None = None,
extra_info_templates: dict[str, str] | None = None,
system_message_template: str | None = None,
input_template: str | PathLike[str],
reference_template: str | PathLike[str] | None = None,
reference_list_template: str | PathLike[str] | None = None,
extra_info_templates: dict[str, str | PathLike[str]] | None = None,
system_message_template: str | PathLike[str] | None = None,
tools: list[dict[str, Any]] | None = None,
data_range: tuple[int, int] | None = None,
keep_conditions: dict[str, str] | None = None,
Expand Down
24 changes: 23 additions & 1 deletion tests/core/chat_dataset/test_template_based.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
from __future__ import annotations

from pathlib import Path
from typing import Any

import pytest
from jinja2 import Template

from flexeval.core.chat_dataset import HFChatDataset, JsonlChatDataset, TemplateChatDataset
from flexeval.core.chat_dataset import HFChatDataset, JsonlChatDataset, TemplateChatDataset, load_jinja2_template

TOOL_DEFINITION = {
"type": "function",
Expand Down Expand Up @@ -199,3 +201,23 @@ def test_remove_conditions(
assert 0 < len(filtered_dataset) < len(original_dataset)
for item in filtered_dataset:
assert len(item.references) > 1


@pytest.fixture()
def dummy_template_file(tmp_path: Path) -> Path:
    """Write a small Jinja2 template into pytest's temp directory and return its path."""
    target = tmp_path / "dummy.j2"
    target.write_text("Hello {{ name }}!", encoding="utf-8")
    return target


def test_load_jinja2_template(dummy_template_file: Path) -> None:
    """Check that a template loads from a file path and from an inline string alike."""
    expected = "Hello flexeval!"

    # Case 1: the argument is a path to an existing template file.
    loaded_from_file = load_jinja2_template(dummy_template_file)
    rendered_from_file = loaded_from_file.render(name="flexeval")
    assert isinstance(loaded_from_file, Template)
    assert rendered_from_file == expected

    # Case 2: the argument is the template source itself.
    loaded_inline = load_jinja2_template("Hello {{ name }}!")
    rendered_inline = loaded_inline.render(name="flexeval")
    assert isinstance(loaded_inline, Template)
    assert rendered_inline == expected