Skip to content

Commit

Permalink
Raise a proper error if Jinja is missing
Browse files Browse the repository at this point in the history
  • Loading branch information
Rocketknight1 committed Jan 6, 2025
1 parent f65e59c commit b59f023
Showing 1 changed file with 8 additions and 3 deletions.
11 changes: 8 additions & 3 deletions src/transformers/utils/chat_template_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@

if is_jinja_available():
import jinja2
from jinja2.ext import Extension
from jinja2.sandbox import ImmutableSandboxedEnvironment
else:
jinja2 = None

Expand Down Expand Up @@ -360,11 +362,14 @@ def _render_with_assistant_indices(

@lru_cache
def _compile_jinja_template(chat_template):
class AssistantTracker(jinja2.ext.Extension):
if not is_jinja_available():
raise ImportError("apply_chat_template requires jinja2 to be installed. Please install it using `pip install jinja2`.")

class AssistantTracker(Extension):
# This extension is used to track the indices of assistant-generated tokens in the rendered chat
tags = {"generation"}

def __init__(self, environment: jinja2.sandbox.ImmutableSandboxedEnvironment):
def __init__(self, environment: ImmutableSandboxedEnvironment):
# The class is only initiated by jinja.
super().__init__(environment)
environment.extend(activate_tracker=self.activate_tracker)
Expand Down Expand Up @@ -418,7 +423,7 @@ def tojson(x, ensure_ascii=False, indent=None, separators=None, sort_keys=False)
def strftime_now(format):
return datetime.now().strftime(format)

jinja_env = jinja2.sandbox.ImmutableSandboxedEnvironment(
jinja_env = ImmutableSandboxedEnvironment(
trim_blocks=True, lstrip_blocks=True, extensions=[AssistantTracker, jinja2.ext.loopcontrols]
)
jinja_env.filters["tojson"] = tojson
Expand Down

0 comments on commit b59f023

Please sign in to comment.