diff --git a/src/alignrl/dpo.py b/src/alignrl/dpo.py index 779b631..79a6766 100644 --- a/src/alignrl/dpo.py +++ b/src/alignrl/dpo.py @@ -134,3 +134,4 @@ def load(self, path: Path) -> None: max_seq_length=self.config.max_seq_length, load_in_4bit=self.config.load_in_4bit, ) + ensure_chat_template(self._tokenizer) diff --git a/src/alignrl/grpo.py b/src/alignrl/grpo.py index 3eaf7da..738e110 100644 --- a/src/alignrl/grpo.py +++ b/src/alignrl/grpo.py @@ -173,3 +173,4 @@ def load(self, path: Path) -> None: max_seq_length=self.config.max_seq_length, load_in_4bit=self.config.load_in_4bit, ) + ensure_chat_template(self._tokenizer) diff --git a/src/alignrl/sft.py b/src/alignrl/sft.py index 8e1f003..045ef89 100644 --- a/src/alignrl/sft.py +++ b/src/alignrl/sft.py @@ -141,3 +141,4 @@ def load(self, path: Path) -> None: max_seq_length=self.config.max_seq_length, load_in_4bit=self.config.load_in_4bit, ) + ensure_chat_template(self._tokenizer)