diff --git a/ktransformers/local_chat.py b/ktransformers/local_chat.py index 928de487..973c4769 100644 --- a/ktransformers/local_chat.py +++ b/ktransformers/local_chat.py @@ -25,6 +25,7 @@ from ktransformers.optimize.optimize import optimize_and_load_gguf from ktransformers.models.modeling_deepseek import DeepseekV2ForCausalLM from ktransformers.models.modeling_qwen2_moe import Qwen2MoeForCausalLM +from ktransformers.models.modeling_qwen3_moe import Qwen3MoeForCausalLM from ktransformers.models.modeling_deepseek_v3 import DeepseekV3ForCausalLM from ktransformers.models.modeling_llama import LlamaForCausalLM from ktransformers.models.modeling_mixtral import MixtralForCausalLM @@ -37,6 +38,7 @@ "DeepseekV2ForCausalLM": DeepseekV2ForCausalLM, "DeepseekV3ForCausalLM": DeepseekV3ForCausalLM, "Qwen2MoeForCausalLM": Qwen2MoeForCausalLM, + "Qwen3MoeForCausalLM": Qwen3MoeForCausalLM, "LlamaForCausalLM": LlamaForCausalLM, "MixtralForCausalLM": MixtralForCausalLM, } diff --git a/ktransformers/models/modeling_qwen3_moe.py b/ktransformers/models/modeling_qwen3_moe.py index 175f88c6..100d75fd 100644 --- a/ktransformers/models/modeling_qwen3_moe.py +++ b/ktransformers/models/modeling_qwen3_moe.py @@ -185,9 +185,10 @@ def forward( hidden_states: torch.Tensor, position_embeddings: Tuple[torch.Tensor, torch.Tensor], attention_mask: Optional[torch.Tensor], + position_ids: Optional[torch.LongTensor], past_key_value: Optional[Cache] = None, cache_position: Optional[torch.LongTensor] = None, - # **kwargs: Unpack[FlashAttentionKwargs], + **kwargs, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: input_shape = hidden_states.shape[:-1] hidden_shape = (*input_shape, -1, self.head_dim) @@ -196,7 +197,8 @@ def forward( key_states = self.k_norm(self.k_proj(hidden_states).view(hidden_shape)).transpose(1, 2) value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) - cos, sin = position_embeddings + # cos, sin = position_embeddings + cos, sin = self.rotary_emb(value_states, position_ids) query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) if past_key_value is not None: diff --git a/ktransformers/operators/experts.py b/ktransformers/operators/experts.py index 34f0af09..163468ef 100644 --- a/ktransformers/operators/experts.py +++ b/ktransformers/operators/experts.py @@ -1494,4 +1494,108 @@ def moe_infer(self, x, topk_ids, topk_weight): .sum(dim=1) .type(new_x.dtype) ) - return final_out \ No newline at end of file + return final_out + +class KQwen3MoeSparseMoeBlock(BaseInjectedModule, Qwen3MoeSparseMoeBlock): + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + """ """ + orig_shape = hidden_states.shape + batch_size, sequence_length, hidden_dim = hidden_states.shape + hidden_states = hidden_states.view(-1, hidden_dim) + # router_logits: (batch * sequence_length, n_experts) + router_logits = self.gate(hidden_states) + + routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float) + routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1) + if self.norm_topk_prob: + routing_weights /= routing_weights.sum(dim=-1, keepdim=True) + # we cast back to the input dtype + routing_weights = routing_weights.to(hidden_states.dtype) + + if sequence_length == 1 and hasattr(self.experts.generate_experts, "submit_for_one_decode"): + self.experts.generate_experts.submit_for_one_decode(hidden_states[0], selected_experts[0], routing_weights[0]) + # shared_expert_output = self.shared_expert(hidden_states) + # 
shared_expert_output = F.sigmoid(self.shared_expert_gate(hidden_states)) * shared_expert_output
+            y = self.experts.generate_experts.sync_for_one_decode().unsqueeze(0)
+            # y += shared_expert_output
+            y.resize_(*orig_shape)
+            return y, router_logits
+
+        hidden_states_expert = hidden_states.to(self.experts.device) if isinstance(self.experts, KExpertsBase) else hidden_states.cpu()
+        selected_experts_expert = selected_experts.to(self.experts.device) if isinstance(self.experts, KExpertsBase) else selected_experts.cpu()
+        routing_weights_expert = routing_weights.to(self.experts.device) if isinstance(self.experts, KExpertsBase) else routing_weights.cpu()
+
+        # shared_expert_output = self.shared_expert(hidden_states)
+        # shared_expert_output = (
+        #     F.sigmoid(self.shared_expert_gate(hidden_states)) * shared_expert_output
+        # )
+
+        if isinstance(self.experts, KExpertsBase):
+            y = (
+                self.moe_kexperts(
+                    hidden_states_expert, selected_experts_expert, routing_weights_expert
+                )
+                .view(*orig_shape)
+                .to(device=hidden_states.device)
+            )
+        elif hidden_states_expert.size(0) > 10:
+            y = self.moe_infer(
+                hidden_states_expert, selected_experts_expert, routing_weights_expert, orig_shape
+            ).to(device=hidden_states.device)
+        else:
+            y = self.moe_infer_simple(
+                hidden_states_expert, selected_experts_expert, routing_weights_expert
+            ).to(device=hidden_states.device)
+        # y += shared_expert_output
+        y.resize_(*orig_shape)
+        return y, router_logits
+
+    @torch.no_grad()
+    def moe_kexperts(self, x: torch.Tensor, topk_ids: torch.Tensor, topk_weight: torch.Tensor) -> torch.Tensor:
+        outs = self.experts(x, topk_ids, topk_weight)
+        return outs
+
+    @torch.no_grad()
+    # TODO: there may be bugs here
+    def moe_infer_simple(self, hidden_states_cpu: torch.Tensor, selected_experts_cpu: torch.Tensor, routing_weights_cpu: torch.Tensor) -> torch.Tensor:
+        '''
+        hidden_states_cpu: [num_tokens, hidden_size]
+        topk_ids, topk_weight: [num_tokens, num_selected_experts]
+        '''
+        outs = torch.zeros_like(hidden_states_cpu)
+        for token_idx in range(selected_experts_cpu.size(0)):
+            for expert_idx in range(selected_experts_cpu.size(1)):
+                expert = self.experts[selected_experts_cpu[token_idx, expert_idx]]
+                outs[token_idx] += expert.forward(hidden_states_cpu[token_idx]) * routing_weights_cpu[token_idx, expert_idx]
+        return outs
+
+    @torch.no_grad()
+    # TODO: there may be bugs here
+    def moe_infer(self, hidden_states_cpu: torch.Tensor, selected_experts_cpu: torch.Tensor, routing_weights_cpu: torch.Tensor, orig_shape: tuple) -> torch.Tensor:
+
+        batch_size, sequence_length, hidden_dim = orig_shape
+
+        final_hidden_states = torch.zeros(
+            (batch_size * sequence_length, hidden_dim), dtype=hidden_states_cpu.dtype, device=hidden_states_cpu.device
+        )
+
+        # One hot encode the selected experts to create an expert mask
+        # this will be used to easily index which expert is going to be solicited
+        expert_mask = torch.nn.functional.one_hot(selected_experts_cpu, num_classes=self.num_experts).permute(2, 1, 0)
+
+        # Loop over all available experts in the model and perform the computation on each expert
+        for expert_idx in range(self.num_experts):
+            expert_layer = self.experts[expert_idx]
+            idx, top_x = torch.where(expert_mask[expert_idx])
+
+            # Index the correct hidden states and compute the expert hidden state for
+            # the current expert. We need to make sure to multiply the output hidden
+            # states by `routing_weights` on the corresponding tokens (top-1 and top-2)
+            current_state = hidden_states_cpu[None, top_x].reshape(-1, hidden_dim)
+            current_hidden_states = expert_layer.forward(current_state) * routing_weights_cpu[top_x, idx, None]
+
+            # However `index_add_` only supports torch tensors for indexing so we'll use
+            # the `top_x` tensor here.
+            final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states_cpu.dtype))
+
+        return final_hidden_states
\ No newline at end of file
diff --git a/ktransformers/operators/models.py b/ktransformers/operators/models.py
index bbac29a3..f13387fd 100644
--- a/ktransformers/operators/models.py
+++ b/ktransformers/operators/models.py
@@ -65,7 +65,7 @@
     LlamaRMSNorm,
     LlamaRotaryEmbedding,
 )
-
+from ktransformers.models.modeling_qwen3_moe import Qwen3MoeRotaryEmbedding
 if is_flash_attn_2_available():
     from flash_attn import flash_attn_func, flash_attn_varlen_func
     from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input  # noqa
@@ -1350,3 +1350,289 @@ def _update_causal_mask(
         )
 
         return causal_mask
+
+class KQwen3MoeModel(BaseInjectedModule):
+    """
+    Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`Qwen2MoeDecoderLayer`]
+
+    Args:
+        config: Qwen2MoeConfig
+    """
+
+    def __init__(
+        self,
+        key: str,
+        gguf_loader: GGUFLoader,
+        config: PretrainedConfig,
+        orig_module: nn.Module,
+        device: str = "cuda",
+        per_layer_prefill_intput_threshold: int = 30000,  # if None, no per-layer prefill
+        transfer_map: dict = None,
+        **kwargs,
+    ):
+        BaseInjectedModule.__init__(
+            self, key, gguf_loader, config, orig_module, device, **kwargs
+        )
+        self.per_layer_prefill_intput_threshold = per_layer_prefill_intput_threshold
+        self.transfer_map = transfer_map
+        self.stream_device_map = dict()
+        self.rotary_emb = Qwen3MoeRotaryEmbedding(config=config)
+
+
+    @add_start_docstrings_to_model_forward(QWEN2MOE_INPUTS_DOCSTRING)
+    def forward(
+        self,
+        input_ids: torch.LongTensor = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        past_key_values: Optional[List[torch.FloatTensor]] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        output_router_logits: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+        cache_position: Optional[torch.LongTensor] = None,
+        per_layer_prefill_intput_threshold: (
+            int | None
+        ) = None,  # if None or 0, disable per-layer prefill
+    ) -> Union[Tuple, MoeModelOutputWithPast]:
+        # print(f'Total length of input_ids: {input_ids.size(1)}, {input_ids.size()}')
+
+        if per_layer_prefill_intput_threshold is None:
+            per_layer_prefill_intput_threshold = self.per_layer_prefill_intput_threshold
+        per_layer_prefill_flag = False
+        seq_length = (
+            inputs_embeds.size(1) if inputs_embeds is not None else input_ids.size(1)
+        )
+        if (
+            per_layer_prefill_intput_threshold
+            and per_layer_prefill_intput_threshold < seq_length
+        ):
+            per_layer_prefill_flag = True
+            for layer in self.layers:
+                self.load_layer_to(layer, InferenceState.UNLOAD)
+        else:
+            pass
+        output_attentions = (
+            output_attentions
+            if output_attentions is not None
+            else self.config.output_attentions
+        )
+        output_router_logits = (
+            output_router_logits
+            if output_router_logits is not None
+            else self.config.output_router_logits
+        )
+        output_hidden_states = (
output_hidden_states + if output_hidden_states is not None + else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + + return_dict = ( + return_dict if return_dict is not None else self.config.use_return_dict + ) + + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError( + "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one" + ) + + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." + ) + use_cache = False + + if use_cache and past_key_values is None: + past_key_values = DynamicCache() + # use_legacy_cache = False + # if use_cache and not isinstance(past_key_values, Cache): + # use_legacy_cache = True + # past_key_values = DynamicCache.from_legacy_cache(past_key_values) + # logger.warning_once( + # "We detected that you are passing `past_key_values` as a tuple and this is deprecated and will be removed in v4.43. " + # "Please use an appropriate `Cache` class (https://huggingface.co/docs/transformers/v4.41.3/en/internal/generation_utils#transformers.Cache)" + # ) + + if inputs_embeds is None: + input_ids = input_ids.to("cpu") + inputs_embeds = self.embed_tokens(input_ids) + inputs_embeds = inputs_embeds.to("cuda") + + if cache_position is None: + past_seen_tokens = ( + past_key_values.get_seq_length() if past_key_values is not None else 0 + ) + cache_position = torch.arange( + past_seen_tokens, + past_seen_tokens + inputs_embeds.shape[1], + device=inputs_embeds.device, + ) + if position_ids is None: + position_ids = cache_position.unsqueeze(0) + + causal_mask = self._update_causal_mask( + attention_mask, + inputs_embeds, + cache_position, + past_key_values, + output_attentions, + ) + + hidden_states = inputs_embeds + + # position_embeddings = self.rotary_emb(hidden_states, position_ids) + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + all_router_logits = () if output_router_logits else None + # next_decoder_cache = None + + for i, decoder_layer in enumerate(self.layers): + # if self.transfer_map is not None and i in self.transfer_map: + # prev_stream = torch.cuda.current_stream() + # cur_device = self.transfer_map[i] + # if cur_device not in self.stream_device_map: + # self.stream_device_map[cur_device] = torch.cuda.Stream(cur_device) + # torch.cuda.set_device(cur_device) + # self.stream_device_map[cur_device].wait_stream(prev_stream) + # torch.cuda.set_stream(self.stream_device_map[cur_device]) + # hidden_states = hidden_states.to( + # self.transfer_map[i], non_blocking=True + # ) + # causal_mask = ( + # causal_mask.to(self.transfer_map[i], non_blocking=True) + # if causal_mask is not None + # else None + # ) + # position_ids = ( + # position_ids.to(self.transfer_map[i], non_blocking=True) + # if position_ids is not None + # else None + # ) + # cache_position = ( + # cache_position.to(self.transfer_map[i], non_blocking=True) + # if cache_position is not None + # else None + # ) + + if output_hidden_states: + all_hidden_states += (hidden_states,) + + if self.gradient_checkpointing and self.training: + layer_outputs = self._gradient_checkpointing_func( + decoder_layer.__call__, + hidden_states, + causal_mask, + position_ids, + past_key_values, + output_attentions, + output_router_logits, + use_cache, + cache_position, + # position_embeddings, + ) + 
else: + if per_layer_prefill_flag: + # print(f"to gpu") + self.load_layer_to(decoder_layer, InferenceState.PREFILL) + torch.cuda.empty_cache() + layer_outputs = decoder_layer( + hidden_states, + attention_mask=causal_mask, + position_ids=position_ids, + past_key_value=past_key_values, + output_attentions=output_attentions, + output_router_logits=output_router_logits, + use_cache=use_cache, + cache_position=cache_position, + # position_embeddings=position_embeddings, + ) + if per_layer_prefill_flag: + # print(f"to cpu") + self.load_layer_to(decoder_layer, InferenceState.UNLOAD) + torch.cuda.empty_cache() + hidden_states = layer_outputs[0] + # use_cache=False + # if use_cache: + # next_decoder_cache = layer_outputs[2 if output_attentions else 1] + + if output_attentions: + all_self_attns += (layer_outputs[1],) + + if output_router_logits and layer_outputs[-1] is not None: + all_router_logits += (layer_outputs[-1],) + + hidden_states = self.norm(hidden_states) + + if per_layer_prefill_flag: + per_layer_prefill_flag = False + for layer in self.layers: + self.load_layer_to(layer, InferenceState.GENERATE) + if output_hidden_states: + all_hidden_states += (hidden_states,) + + # next_cache = None + # if use_cache: + # next_cache = ( + # next_decoder_cache.to_legacy_cache() + # if use_legacy_cache + # else next_decoder_cache + # ) + + if not return_dict: + return tuple( + v + for v in [ + hidden_states, + past_key_values, + all_hidden_states, + all_self_attns, + all_router_logits, + ] + if v is not None + ) + return MoeModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=past_key_values, + hidden_states=all_hidden_states, + attentions=all_self_attns, + router_logits=all_router_logits, + ) + + def load_layer_to(self, layer: Qwen2MoeDecoderLayer, target: InferenceState): + assert isinstance( + layer, Qwen2MoeDecoderLayer + ), "module should be nn.ModuleList of decoder layers" + + # TODO Support restore to original device, not only cuda + device = "cpu" if target == InferenceState.UNLOAD else "cuda" + + # attn + layer.self_attn.q_proj.set_inference_mode(target) + layer.self_attn.k_proj.set_inference_mode(target) + layer.self_attn.v_proj.set_inference_mode(target) + layer.self_attn.o_proj.set_inference_mode(target) + layer.self_attn.rotary_emb = layer.self_attn.rotary_emb.to(device) + + # mlp + if isinstance(layer.mlp, Qwen2MoeSparseMoeBlock): + layer.mlp.gate.set_inference_mode(target) + layer.mlp.experts.set_inference_mode(target) + layer.mlp.shared_expert.gate_proj.set_inference_mode(target) + layer.mlp.shared_expert.up_proj.set_inference_mode(target) + layer.mlp.shared_expert.down_proj.set_inference_mode(target) + layer.mlp.shared_expert.act_fn.to(device) + layer.mlp.shared_expert_gate.to(device) + else: + layer.mlp.gate_proj.set_inference_mode(target) + layer.mlp.up_proj.set_inference_mode(target) + layer.mlp.down_proj.set_inference_mode(target) + layer.mlp.act_fn.to(device) + # layer norm + layer.input_layernorm.to(device) + layer.post_attention_layernorm.to(device) \ No newline at end of file diff --git a/ktransformers/optimize/optimize_rules/Qwen3-30B-A3B.yaml b/ktransformers/optimize/optimize_rules/Qwen3-30B-A3B.yaml new file mode 100644 index 00000000..3ab73c72 --- /dev/null +++ b/ktransformers/optimize/optimize_rules/Qwen3-30B-A3B.yaml @@ -0,0 +1,83 @@ +- match: + class: ktransformers.models.modeling_qwen2_moe.Qwen2MoeRotaryEmbedding + replace: + class: ktransformers.operators.RoPE.RotaryEmbedding + kwargs: + generate_device: "cuda" + prefill_device: "cuda" + +# - match: 
+#     name: "^model\\.layers\\..*$"  # regular expression
+#     class: torch.nn.Linear  # only match modules matching name and class simultaneously
+#   replace:
+#     class: ktransformers.operators.linear.KTransformersLinear  # optimized Kernel on quantized data types
+#     kwargs:
+#       generate_device: "cuda"
+#       prefill_device: "cuda"
+#       generate_op: "KLinearMarlin"
+#       prefill_op: "KLinearTorch"
+
+- match:
+    name: "^lm_head$"  # regular expression
+    class: torch.nn.Linear  # only match modules matching name and class simultaneously
+  replace:
+    class: ktransformers.operators.linear.KTransformersLinear  # optimized Kernel on quantized data types
+    kwargs:
+      generate_device: "cuda"
+      prefill_device: "cuda"
+      generate_op: "VLinearMarlin"
+      prefill_op: "KLinearTorch"
+
+- match:
+    name: "^model\\.layers\\.(?!.*mlp\\.shared_expert_gate).*$"  # regular expression
+    class: torch.nn.Linear  # only match modules matching name and class simultaneously
+  replace:
+    class: ktransformers.operators.linear.KTransformersLinear  # optimized Kernel on quantized data types
+    kwargs:
+      generate_device: "cuda"
+      prefill_device: "cuda"
+      generate_op: "KLinearMarlin"
+      prefill_op: "KLinearTorch"
+
+- match:
+    name: "^model\\.layers\\..*\\.mlp$"
+    class: ktransformers.models.modeling_qwen3_moe.Qwen3MoeSparseMoeBlock
+  replace:
+    class: ktransformers.operators.experts.KQwen3MoeSparseMoeBlock  # mlp module with custom forward function
+    kwargs:
+      generate_device: "cuda"
+      prefill_device: "cuda"
+
+- match:
+    name: "^model\\.layers\\..*\\.mlp\\.experts$"
+  replace:
+    class: ktransformers.operators.experts.KTransformersExperts  # custom MoE Kernel with expert parallelism
+    kwargs:
+      prefill_device: "cuda"
+      prefill_op: "KExpertsTorch"
+      generate_device: "cpu"
+      generate_op: "KExpertsCPU"
+      out_device: "cuda"
+  recursive: False  # don't recursively inject submodules of this module
+
+# - match:
+#     name: "^model$"
+#   replace:
+#     class: "ktransformers.operators.models.KQwen3MoeModel"
+#     kwargs:
+#       per_layer_prefill_intput_threshold: 0  # 0 disables layer-wise prefill
+
+- match:
+    name: "^model.embed_tokens"
+  replace:
+    class: "default"
+    kwargs:
+      generate_device: "cpu"
+      prefill_device: "cpu"
+# - match:
+#     name: "^model\\.layers\\..*\\."
+#   replace:
+#     class: "default"
+#     kwargs:
+#       generate_device: "cuda"
+#       prefill_device: "cuda"
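
Reviewer note: the new `Qwen3-30B-A3B.yaml` rules select modules by a regular expression over their dotted names (plus a `class` constraint, ignored in this sketch). The snippet below checks, against hypothetical module names, which modules each name pattern would claim; the names are illustrative, not taken from the actual checkpoint.

```python
import re

# Hypothetical module names in named_modules() form; real names come from
# the Qwen3-30B-A3B checkpoint and may differ.
names = [
    "lm_head",
    "model.embed_tokens",
    "model.layers.0.self_attn.q_proj",
    "model.layers.0.mlp.gate",
    "model.layers.0.mlp.experts.7.down_proj",
    "model.layers.0.mlp.shared_expert_gate",
]

# Name patterns from the YAML rules (class constraints omitted).
rules = {
    "lm_head -> VLinearMarlin": r"^lm_head$",
    "layer linears -> KLinearMarlin": r"^model\.layers\.(?!.*mlp\.shared_expert_gate).*$",
    "embed_tokens -> CPU default": r"^model.embed_tokens",
}

for label, pattern in rules.items():
    matched = [n for n in names if re.match(pattern, n)]
    print(f"{label}: {matched}")
# The layer-linear pattern matches everything under model.layers except names
# containing mlp.shared_expert_gate, which the negative lookahead filters out.
```

The `shared_expert_gate` exclusion appears to be carried over from the Qwen2-MoE rules; the Qwen3 MoE block in this patch has its shared-expert path commented out, so the exclusion is harmless.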
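Reviewer note: the routing path in `KQwen3MoeSparseMoeBlock` (softmax over router logits, top-k selection, optional renormalization, then the expert-mask dispatch in `moe_infer`) can be sanity-checked outside the injection framework. The sketch below is a minimal, self-contained approximation with toy `nn.Linear` experts, made-up dimensions, and a hypothetical `moe_forward` helper; it is not the ktransformers `KTransformersExperts` path, only the same dispatch arithmetic.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

# Toy stand-ins for the routed experts; the real module dispatches to
# KTransformersExperts kernels, not plain nn.Linear layers.
hidden_dim, num_experts, top_k = 16, 8, 2
experts = nn.ModuleList(nn.Linear(hidden_dim, hidden_dim, bias=False) for _ in range(num_experts))
gate = nn.Linear(hidden_dim, num_experts, bias=False)

def moe_forward(hidden_states: torch.Tensor) -> torch.Tensor:
    batch, seq, _ = hidden_states.shape
    x = hidden_states.view(-1, hidden_dim)  # [tokens, hidden]

    # Routing: softmax over experts, keep top-k, renormalize (the norm_topk_prob case).
    router_logits = gate(x)
    routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
    routing_weights, selected_experts = torch.topk(routing_weights, top_k, dim=-1)
    routing_weights = routing_weights / routing_weights.sum(dim=-1, keepdim=True)
    routing_weights = routing_weights.to(x.dtype)

    # Expert-mask dispatch, mirroring moe_infer: one_hot + permute gives a
    # [num_experts, top_k, tokens] mask, so each expert sees only its tokens.
    out = torch.zeros_like(x)
    expert_mask = F.one_hot(selected_experts, num_classes=num_experts).permute(2, 1, 0)
    for expert_idx in range(num_experts):
        idx, top_x = torch.where(expert_mask[expert_idx])  # (slot index, token index)
        if top_x.numel() == 0:
            continue
        current = x[top_x]
        out.index_add_(0, top_x, experts[expert_idx](current) * routing_weights[top_x, idx, None])
    return out.view(batch, seq, hidden_dim)

print(moe_forward(torch.randn(1, 4, hidden_dim)).shape)  # torch.Size([1, 4, 16])
```

The `one_hot(...).permute(2, 1, 0)` trick plus `torch.where` gathers each expert's assigned tokens, and `index_add_` scatters the weighted expert outputs back in token order; this is the same bookkeeping `moe_infer` performs on CPU tensors for long prompts.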