From 65793c7d5a8c045f97d141d4055f98e5c56ccc22 Mon Sep 17 00:00:00 2001 From: Yikai Zhu Date: Tue, 21 Oct 2025 08:00:35 +0000 Subject: [PATCH 1/2] [Feature]Add Parser for Qwen3 think model --- specforge/data/parse.py | 26 ++++++++++++++++++++++++++ specforge/data/preprocessing.py | 4 +++- specforge/data/template.py | 13 +++++++++++++ 3 files changed, 42 insertions(+), 1 deletion(-) diff --git a/specforge/data/parse.py b/specforge/data/parse.py index 9b17c90a..c46e15e4 100644 --- a/specforge/data/parse.py +++ b/specforge/data/parse.py @@ -84,6 +84,9 @@ def parse( messages, tokenize=False, add_generation_prompt=False, **kwargs ) + if self.chat_template.ignored_token: + conversation = conversation.replace(self.chat_template.ignored_token, "") + if not self.tokenizer.pad_token_id: self.tokenizer.pad_token_id = self.tokenizer.unk_token_id @@ -122,6 +125,29 @@ def parse( return input_ids, loss_mask +class Qwen3ThinkingParser(GeneralParser): + def __init__(self, tokenizer: PreTrainedTokenizer, chat_template: ChatTemplate): + super().__init__(tokenizer, chat_template) + + def parse( + self, + conversation: "Conversation", + max_length: int, + preformatted: bool = False, + **kwargs, + ) -> Dict[str, List[torch.Tensor]]: + if kwargs.get("enable_thinking", False): + self.assistant_message_separator = ( + f"{self.chat_template.end_of_turn_token}<|im_start|>assistant\n\n" + ) + self.chat_template.ignored_token = None + else: + self.assistant_message_separator = ( + f"{self.chat_template.end_of_turn_token}<|im_start|>assistant\n" + ) + self.chat_template.ignored_token = "\n\n\n\n" + return super().parse(conversation, max_length, preformatted, **kwargs) + class HarmonyParser(Parser): def build_single_turn_prompt( diff --git a/specforge/data/preprocessing.py b/specforge/data/preprocessing.py index 4f35d950..b4ae1a83 100644 --- a/specforge/data/preprocessing.py +++ b/specforge/data/preprocessing.py @@ -41,7 +41,7 @@ from specforge.utils import padding -from .parse import GeneralParser, HarmonyParser +from .parse import GeneralParser, HarmonyParser, Qwen3ThinkingParser from .template import TEMPLATE_REGISTRY, ChatTemplate # define a type called conversation @@ -141,6 +141,8 @@ def preprocess_conversations( if chat_template.parser_type == "general": parser = GeneralParser(tokenizer, chat_template) + elif chat_template.parser_type == "qwen3-thinking": + parser = Qwen3ThinkingParser(tokenizer, chat_template) elif chat_template.parser_type == "openai-harmony": parser = HarmonyParser(tokenizer, chat_template) else: diff --git a/specforge/data/template.py b/specforge/data/template.py index 2368048b..ddd6b9e6 100644 --- a/specforge/data/template.py +++ b/specforge/data/template.py @@ -20,6 +20,7 @@ class ChatTemplate(BaseModel): system_prompt: str | None end_of_turn_token: str | None parser_type: str = "general" + ignored_token: str | None = None class TemplateRegistry: @@ -115,6 +116,18 @@ def get_all_template_names(self) -> List[str]: ), ) +TEMPLATE_REGISTRY.register( + name="qwen3-thinking", + template=ChatTemplate( + assistant_header="<|im_start|>assistant\n", + user_header="<|im_start|>user\n", + system_prompt="You are a helpful assistant.", + end_of_turn_token="<|im_end|>\n", + parser_type="qwen3-thinking", + ignored_token="\n\n\n\n", + ), +) + TEMPLATE_REGISTRY.register( name="qwen2-vl", template=ChatTemplate( From be27c1718fed9682e23a06b4c87fc9ca2594e09f Mon Sep 17 00:00:00 2001 From: Yikai Zhu Date: Tue, 21 Oct 2025 08:10:19 +0000 Subject: [PATCH 2/2] use lint --- scripts/train_eagle3_offline.py | 23 ++++++++++++++++------- specforge/data/parse.py | 11 ++++++----- specforge/distributed.py | 2 ++ specforge/utils.py | 3 ++- 4 files changed, 26 insertions(+), 13 deletions(-) diff --git a/scripts/train_eagle3_offline.py b/scripts/train_eagle3_offline.py index 5977b866..27602a1c 100644 --- a/scripts/train_eagle3_offline.py +++ b/scripts/train_eagle3_offline.py @@ -9,7 +9,7 @@ import torch.distributed as dist from accelerate.utils import set_seed from datasets import load_dataset -from torch.distributed.fsdp import fully_shard, MixedPrecisionPolicy +from torch.distributed.fsdp import MixedPrecisionPolicy, fully_shard from tqdm import tqdm from transformers import AutoTokenizer @@ -20,17 +20,22 @@ generate_vocab_mapping_file, prepare_dp_dataloaders, ) -from specforge.distributed import destroy_distributed, get_dp_group, init_distributed, get_dp_device_mesh +from specforge.distributed import ( + destroy_distributed, + get_dp_device_mesh, + get_dp_group, + init_distributed, +) from specforge.modeling.target.target_head import TargetHead from specforge.optimizer import BF16Optimizer from specforge.tracker import create_tracker, get_tracker_class from specforge.utils import ( create_draft_config_from_target, + get_full_optimizer_state, get_last_checkpoint, print_on_rank0, print_with_rank, rank_0_priority, - get_full_optimizer_state, ) @@ -340,7 +345,9 @@ def main(): length=args.ttt_length, attention_backend=args.draft_attention_backend, ) - mp_policy = MixedPrecisionPolicy(param_dtype=torch.bfloat16, reduce_dtype=torch.float32) + mp_policy = MixedPrecisionPolicy( + param_dtype=torch.bfloat16, reduce_dtype=torch.float32 + ) fsdp_config = {"mesh": get_dp_device_mesh(), "mp_policy": mp_policy} fully_shard(eagle3_model, **fsdp_config) @@ -527,10 +534,12 @@ def main(): "epoch": epoch, "args": args, } - + optimizer_state_dict = optimizer.state_dict() - optimizer_state_dict["optimizer_state_dict"] = get_full_optimizer_state(optimizer_state_dict["optimizer_state_dict"]) - + optimizer_state_dict["optimizer_state_dict"] = get_full_optimizer_state( + optimizer_state_dict["optimizer_state_dict"] + ) + state_to_save.update(optimizer_state_dict) draft_model_state_dict = { diff --git a/specforge/data/parse.py b/specforge/data/parse.py index c46e15e4..7780e8b9 100644 --- a/specforge/data/parse.py +++ b/specforge/data/parse.py @@ -85,7 +85,9 @@ def parse( ) if self.chat_template.ignored_token: - conversation = conversation.replace(self.chat_template.ignored_token, "") + conversation = conversation.replace( + self.chat_template.ignored_token, "" + ) if not self.tokenizer.pad_token_id: self.tokenizer.pad_token_id = self.tokenizer.unk_token_id @@ -128,7 +130,7 @@ def parse( class Qwen3ThinkingParser(GeneralParser): def __init__(self, tokenizer: PreTrainedTokenizer, chat_template: ChatTemplate): super().__init__(tokenizer, chat_template) - + def parse( self, conversation: "Conversation", @@ -137,9 +139,7 @@ def parse( **kwargs, ) -> Dict[str, List[torch.Tensor]]: if kwargs.get("enable_thinking", False): - self.assistant_message_separator = ( - f"{self.chat_template.end_of_turn_token}<|im_start|>assistant\n\n" - ) + self.assistant_message_separator = f"{self.chat_template.end_of_turn_token}<|im_start|>assistant\n\n" self.chat_template.ignored_token = None else: self.assistant_message_separator = ( @@ -148,6 +148,7 @@ def parse( self.chat_template.ignored_token = "\n\n\n\n" return super().parse(conversation, max_length, preformatted, **kwargs) + class HarmonyParser(Parser): def build_single_turn_prompt( diff --git a/specforge/distributed.py b/specforge/distributed.py index d61a7ae2..80713012 100644 --- a/specforge/distributed.py +++ b/specforge/distributed.py @@ -11,6 +11,7 @@ _DP_DEVICE_MESH = None _DP_GROUP = None + def get_tp_group(): global _TP_GROUP return _TP_GROUP @@ -30,6 +31,7 @@ def get_tp_device_mesh(): global _TP_DEVICE_MESH return _TP_DEVICE_MESH + def get_dp_device_mesh(): global _DP_DEVICE_MESH return _DP_DEVICE_MESH diff --git a/specforge/utils.py b/specforge/utils.py index ea6df17e..7dd5c60b 100644 --- a/specforge/utils.py +++ b/specforge/utils.py @@ -222,6 +222,7 @@ def create_draft_config_from_target( return output_path + def get_full_optimizer_state(optimizer_state_dict: dict): """ Convert optimizer state dict with DTensor to full tensors for saving @@ -246,4 +247,4 @@ def get_full_optimizer_state(optimizer_state_dict: dict): } for param_id, param_state in optimizer_state_dict["state"].items() } - return full_optimizer_state_dict \ No newline at end of file + return full_optimizer_state_dict