Skip to content
Open
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 26 additions & 0 deletions specforge/data/parse.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,9 @@ def parse(
messages, tokenize=False, add_generation_prompt=False, **kwargs
)

if self.chat_template.ignored_token:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

Note: the existing truthiness check `if self.chat_template.ignored_token:` already evaluates to False for both None and the empty string, so the additional `!= ""` comparison suggested below is redundant and can be dropped.

if self.chat_template.ignored_token and self.chat_template.ignored_token != "":
    conversation = conversation.replace(self.chat_template.ignored_token, "")
Suggested change
if self.chat_template.ignored_token:
if self.chat_template.ignored_token and self.chat_template.ignored_token != "":
conversation = conversation.replace(self.chat_template.ignored_token, "")

conversation = conversation.replace(self.chat_template.ignored_token, "")

if not self.tokenizer.pad_token_id:
self.tokenizer.pad_token_id = self.tokenizer.unk_token_id

Expand Down Expand Up @@ -122,6 +125,29 @@ def parse(
return input_ids, loss_mask


class Qwen3ThinkingParser(GeneralParser):
    """Parser for Qwen3 chat templates that toggles ``<think>`` handling.

    When ``enable_thinking=True`` is passed through ``kwargs``, the assistant
    turn is opened with a ``<think>\\n`` tag and think content is kept
    (``ignored_token`` is cleared). Otherwise the empty
    ``<think>\\n\\n</think>\\n\\n`` block that Qwen3 emits in non-thinking mode
    is stripped by ``GeneralParser.parse`` via ``chat_template.ignored_token``.
    """

    def __init__(self, tokenizer: PreTrainedTokenizer, chat_template: ChatTemplate):
        super().__init__(tokenizer, chat_template)
        # Cache the template's configured ignored_token so parse() can restore
        # it instead of duplicating the literal; fall back to the Qwen3 empty
        # think block for templates that did not define one (preserves the
        # previous hard-coded behavior exactly).
        self._default_ignored_token = (
            chat_template.ignored_token
            if chat_template.ignored_token is not None
            else "<think>\n\n</think>\n\n"
        )

    def parse(
        self,
        conversation: "Conversation",
        max_length: int,
        preformatted: bool = False,
        **kwargs,
    ) -> "tuple[torch.Tensor, torch.Tensor]":
        """Tokenize *conversation*, adjusting separators for thinking mode.

        Returns ``(input_ids, loss_mask)`` exactly as produced by
        ``GeneralParser.parse`` — the previous ``Dict[str, List[torch.Tensor]]``
        annotation did not match the actual tuple return value.
        """
        if kwargs.get("enable_thinking", False):
            # Thinking mode: open the assistant turn with a <think> tag and do
            # not strip any think content from the rendered conversation.
            self.assistant_message_separator = (
                f"{self.chat_template.end_of_turn_token}<|im_start|>assistant\n<think>\n"
            )
            self.chat_template.ignored_token = None
        else:
            self.assistant_message_separator = (
                f"{self.chat_template.end_of_turn_token}<|im_start|>assistant\n"
            )
            # Restore the template's own ignored_token (the empty think block)
            # so it gets stripped again after a thinking-mode call.
            self.chat_template.ignored_token = self._default_ignored_token
        return super().parse(conversation, max_length, preformatted, **kwargs)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

The super().parse method returns input_ids and loss_mask. However, the return type annotation Dict[str, List[torch.Tensor]] suggests a dictionary is expected. This discrepancy could lead to confusion or errors if the caller expects a dictionary. Consider updating the return type annotation or modifying the return statement to return a dictionary.

Suggestion:

return {"input_ids": input_ids, "loss_mask": loss_mask}

Alternatively, update the return type annotation to Tuple[torch.Tensor, torch.Tensor] to match the actual return type.

Suggested change
return super().parse(conversation, max_length, preformatted, **kwargs)
return {"input_ids": input_ids, "loss_mask": loss_mask}


class HarmonyParser(Parser):

def build_single_turn_prompt(
Expand Down
4 changes: 3 additions & 1 deletion specforge/data/preprocessing.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@

from specforge.utils import padding

from .parse import GeneralParser, HarmonyParser
from .parse import GeneralParser, HarmonyParser, Qwen3ThinkingParser
from .template import TEMPLATE_REGISTRY, ChatTemplate

# define a type called conversation
Expand Down Expand Up @@ -141,6 +141,8 @@ def preprocess_conversations(

if chat_template.parser_type == "general":
parser = GeneralParser(tokenizer, chat_template)
elif chat_template.parser_type == "qwen3-thinking":
parser = Qwen3ThinkingParser(tokenizer, chat_template)
elif chat_template.parser_type == "openai-harmony":
parser = HarmonyParser(tokenizer, chat_template)
else:
Expand Down
13 changes: 13 additions & 0 deletions specforge/data/template.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ class ChatTemplate(BaseModel):
system_prompt: str | None
end_of_turn_token: str | None
parser_type: str = "general"
ignored_token: str | None = None


class TemplateRegistry:
Expand Down Expand Up @@ -115,6 +116,18 @@ def get_all_template_names(self) -> List[str]:
),
)

# Qwen3 "thinking" template: parser_type routes preprocessing to
# Qwen3ThinkingParser, and ignored_token is the empty think block that Qwen3
# emits when thinking is disabled, so the parser can strip it from the
# rendered conversation before tokenization.
TEMPLATE_REGISTRY.register(
    name="qwen3-thinking",
    template=ChatTemplate(
        assistant_header="<|im_start|>assistant\n",
        user_header="<|im_start|>user\n",
        system_prompt="You are a helpful assistant.",
        end_of_turn_token="<|im_end|>\n",
        parser_type="qwen3-thinking",  # selects Qwen3ThinkingParser in preprocessing
        ignored_token="<think>\n\n</think>\n\n",  # empty think block to strip
    ),
)

TEMPLATE_REGISTRY.register(
name="qwen2-vl",
template=ChatTemplate(
Expand Down
Loading