Skip to content

Commit b59f023

Browse files
committed
Raise a proper error if Jinja is missing
1 parent f65e59c commit b59f023

File tree

1 file changed

+8
-3
lines changed

1 file changed

+8
-3
lines changed

src/transformers/utils/chat_template_utils.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,8 @@
2828

2929
if is_jinja_available():
3030
import jinja2
31+
from jinja2.ext import Extension
32+
from jinja2.sandbox import ImmutableSandboxedEnvironment
3133
else:
3234
jinja2 = None
3335

@@ -360,11 +362,14 @@ def _render_with_assistant_indices(
360362

361363
@lru_cache
362364
def _compile_jinja_template(chat_template):
363-
class AssistantTracker(jinja2.ext.Extension):
365+
if not is_jinja_available():
366+
raise ImportError("apply_chat_template requires jinja2 to be installed. Please install it using `pip install jinja2`.")
367+
368+
class AssistantTracker(Extension):
364369
# This extension is used to track the indices of assistant-generated tokens in the rendered chat
365370
tags = {"generation"}
366371

367-
def __init__(self, environment: jinja2.sandbox.ImmutableSandboxedEnvironment):
372+
def __init__(self, environment: ImmutableSandboxedEnvironment):
368373
# The class is only initiated by jinja.
369374
super().__init__(environment)
370375
environment.extend(activate_tracker=self.activate_tracker)
@@ -418,7 +423,7 @@ def tojson(x, ensure_ascii=False, indent=None, separators=None, sort_keys=False)
418423
def strftime_now(format):
419424
return datetime.now().strftime(format)
420425

421-
jinja_env = jinja2.sandbox.ImmutableSandboxedEnvironment(
426+
jinja_env = ImmutableSandboxedEnvironment(
422427
trim_blocks=True, lstrip_blocks=True, extensions=[AssistantTracker, jinja2.ext.loopcontrols]
423428
)
424429
jinja_env.filters["tojson"] = tojson

0 commit comments

Comments
 (0)