
Commit 8fcbaa3

Add support for Hunyuan model
1 parent 1677e90 commit 8fcbaa3

File tree

19 files changed: +3401 −26 lines changed


ktransformers/local_chat.py

Lines changed: 5 additions & 0 deletions
@@ -28,6 +28,7 @@
 from ktransformers.models.modeling_deepseek_v3 import DeepseekV3ForCausalLM
 from ktransformers.models.modeling_llama import LlamaForCausalLM
 from ktransformers.models.modeling_mixtral import MixtralForCausalLM
+from ktransformers.models.modeling_hunyuan import HunYuanMoEV1ForCausalLM
 from ktransformers.util.utils import prefill_and_generate, get_compute_capability, xpu_fp16_model
 from ktransformers.server.config.config import Config
 from ktransformers.operators.flashinfer_wrapper import flashinfer_enabled
@@ -39,6 +40,7 @@
     "Qwen2MoeForCausalLM": Qwen2MoeForCausalLM,
     "LlamaForCausalLM": LlamaForCausalLM,
     "MixtralForCausalLM": MixtralForCausalLM,
+    "HunYuanMoEV1ForCausalLM": HunYuanMoEV1ForCausalLM,
 }
 
 ktransformer_rules_dir = (
@@ -50,6 +52,7 @@
     "Qwen2MoeForCausalLM": ktransformer_rules_dir + "Qwen2-57B-A14B-Instruct.yaml",
     "LlamaForCausalLM": ktransformer_rules_dir + "Internlm2_5-7b-Chat-1m.yaml",
     "MixtralForCausalLM": ktransformer_rules_dir + "Mixtral.yaml",
+    "HunYuanMoEV1ForCausalLM": ktransformer_rules_dir + "Hunyuan-serve.yaml",
 }
 
 
@@ -96,6 +99,8 @@ def local_chat(
         config._attn_implementation = "eager"
     if "Mixtral" in config.architectures[0]:
         config._attn_implementation = "flash_attention_2"
+    if "HunYuan" in config.architectures[0]:
+        config._attn_implementation = "flash_attention_2"
     if torch.xpu.is_available():
         config._attn_implementation = "eager"
     model = custom_models[config.architectures[0]](config)
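The hunks above wire HunYuan into the existing architecture dispatch: the model class and its optimize-rules YAML are keyed by config.architectures[0], and HunYuan checkpoints are forced onto flash_attention_2 just like Mixtral. A minimal sketch of that dispatch path, assuming custom_models is importable at module level (as the diff context suggests) and using an illustrative checkpoint path:

# Hedged sketch, not the repo's CLI entry point: how the extended mapping is consulted.
from transformers import AutoConfig

from ktransformers.local_chat import custom_models  # mapping extended by this commit

config = AutoConfig.from_pretrained(
    "/path/to/Hunyuan-MoE-checkpoint",  # illustrative path
    trust_remote_code=True,             # HunYuan checkpoints may ship custom config code
)

arch = config.architectures[0]          # expected: "HunYuanMoEV1ForCausalLM"
if "HunYuan" in arch:
    # Same backend the new branch in local_chat() selects.
    config._attn_implementation = "flash_attention_2"

model = custom_models[arch](config)     # instantiates HunYuanMoEV1ForCausalLM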

ktransformers/models/configuration_hunyuan.py

Lines changed: 336 additions & 0 deletions
Large diffs are not rendered by default.
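The 336-line configuration file is collapsed in this view. Judging from the fields the cache code below reads (hidden_size, num_hidden_layers, num_attention_heads, num_key_value_heads), it presumably follows the usual transformers PretrainedConfig pattern; the sketch below is only an assumption of its rough shape, and the class name HunYuanConfig plus every default value are placeholders, not values taken from the commit.

# Hedged sketch only; the real file is not rendered in this diff view.
from transformers.configuration_utils import PretrainedConfig


class HunYuanConfig(PretrainedConfig):  # class name assumed
    model_type = "hunyuan"

    def __init__(
        self,
        hidden_size=4096,          # placeholder default
        num_hidden_layers=32,      # placeholder default
        num_attention_heads=32,    # GQA query heads, per the cache docstring below
        num_key_value_heads=8,     # GQA KV heads, per the cache docstring below
        **kwargs,
    ):
        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.num_key_value_heads = num_key_value_heads
        super().__init__(**kwargs)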

ktransformers/models/custom_cache.py

Lines changed: 93 additions & 0 deletions
@@ -330,4 +330,97 @@ def get_k_cache(self, layer_idx):
         return self.k_caches[layer_idx]
 
     def get_v_cache(self, layer_idx):
+        return self.v_caches[layer_idx]
+
+
+class KHunYuanCache(nn.Module):
+    """
+    HunYuan-specific cache implementation for GQA with flashinfer compatibility.
+    Handles KV cache with proper reshaping for paged attention format.
+    """
+    def __init__(
+        self,
+        config: PretrainedConfig,
+        page_size: int = 256,
+        dtype=torch.bfloat16,
+        device=torch.device("cuda:0"),
+    ):
+        super().__init__()
+        self.config = config
+        self.dtype = dtype
+        self.device = device
+        self.page_size = page_size
+        self.k_caches = []
+        self.v_caches = []
+
+        # HunYuan specific parameters
+        self.num_heads = config.num_attention_heads
+        self.num_kv_heads = config.num_key_value_heads
+        self.head_dim = config.hidden_size // config.num_attention_heads
+
+    def load(self, inference_context: "sched_ext.InferenceContext"):
+        """
+        Load and reshape KV caches from inference context to match flashinfer format.
+        HunYuan uses GQA with 32 attention heads and 8 KV heads.
+        """
+        print(f"Loading HunYuan cache for {self.config.num_hidden_layers} layers")
+
+        for i in range(self.config.num_hidden_layers):
+            k_cache_raw = inference_context.k_cache[0][i]
+            v_cache_raw = inference_context.v_cache[0][i]
+
+            # Check if reshaping is needed based on tensor dimensions
+            if k_cache_raw.ndim == 2:
+                total_tokens = k_cache_raw.shape[0]
+                num_pages = total_tokens // self.page_size
+
+                # Reshape k_cache: [total_tokens, kv_dim] -> [num_pages, page_size, num_kv_heads, head_dim]
+                k_cache = k_cache_raw.view(num_pages, self.page_size, self.num_kv_heads, self.head_dim)
+                v_cache = v_cache_raw.view(num_pages, self.page_size, self.num_kv_heads, self.head_dim)
+            elif k_cache_raw.ndim == 3:
+                num_pages = k_cache_raw.shape[0]
+                k_cache = k_cache_raw.view(num_pages, self.page_size, self.num_kv_heads, self.head_dim)
+                v_cache = v_cache_raw.view(num_pages, self.page_size, self.num_kv_heads, self.head_dim)
+            elif k_cache_raw.ndim == 4:
+                k_cache = k_cache_raw
+                v_cache = v_cache_raw
+            else:
+                raise ValueError(f"Unexpected cache dimension: k_cache has {k_cache_raw.ndim} dimensions")
+
+            self.k_caches.append(k_cache)
+            self.v_caches.append(v_cache)
+
+        if len(self.k_caches) > 0:
+            self.max_cache_len = self.k_caches[0].shape[0] * self.k_caches[0].shape[1]
+            print(f"Cache loaded: shape {self.k_caches[0].shape}, max_cache_len {self.max_cache_len}")
+
+    def get_page_table(self, cache_position: torch.Tensor, q_indptr: torch.Tensor,
+                       kv_indptr: torch.Tensor, kv_indices: torch.Tensor, bsz_tensors: torch.tensor):
+        """Get page table for paged attention."""
+        page_offset = cache_position % self.page_size
+        page_idx_local = cache_position // self.page_size
+        query_ids = torch.zeros_like(cache_position)
+
+        for i in range(len(q_indptr) - 1):
+            start_idx = q_indptr[i]
+            end_idx = q_indptr[i + 1]
+            query_ids[start_idx:end_idx] = i
+
+        page_idx = torch.zeros_like(page_idx_local)
+        for i in range(bsz_tensors[0]):
+            query_id = query_ids[i]
+            local_block = page_idx_local[i]
+            start_block = kv_indptr[query_id]
+            if local_block < kv_indptr[query_id + 1] - kv_indptr[query_id]:
+                page_idx[i] = kv_indices[start_block + local_block]
+
+        return page_idx, page_offset
+
+    def get_k_cache(self, layer_idx):
+        """Get k_cache for specific layer."""
+        return self.k_caches[layer_idx]
+
+    def get_v_cache(self, layer_idx):
+        """Get v_cache for specific layer."""
         return self.v_caches[layer_idx]
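A hedged usage sketch of the new cache class, exercising the ndim == 2 reshape branch and the page-table math: with page_size 256, token position 300 lands on local page 1 (offset 300 % 256 = 44), which get_page_table translates through kv_indices into a physical page id. The HunYuanConfig import and the SimpleNamespace stand-in for sched_ext.InferenceContext are assumptions; the real context object comes from the ktransformers scheduler.

# Hedged sketch; the mock context only mimics the fields load() reads.
import torch
from types import SimpleNamespace

from ktransformers.models.configuration_hunyuan import HunYuanConfig  # class name assumed
from ktransformers.models.custom_cache import KHunYuanCache

config = HunYuanConfig(num_hidden_layers=2)   # small layer count to keep the sketch light
page_size, num_pages = 256, 4
head_dim = config.hidden_size // config.num_attention_heads
kv_dim = config.num_key_value_heads * head_dim

# One flat [total_tokens, kv_dim] tensor per layer, matching the ndim == 2 branch.
make_flat = lambda: torch.zeros(num_pages * page_size, kv_dim)
ctx = SimpleNamespace(
    k_cache=[[make_flat() for _ in range(config.num_hidden_layers)]],
    v_cache=[[make_flat() for _ in range(config.num_hidden_layers)]],
)

cache = KHunYuanCache(config, page_size=page_size)
cache.load(ctx)   # each layer is viewed as [num_pages, 256, num_kv_heads, head_dim]

# Single sequence whose logical pages are backed by physical pages [7, 3, 9]:
cache_position = torch.tensor([300])      # 300 // 256 = local page 1, 300 % 256 = offset 44
q_indptr = torch.tensor([0, 1])
kv_indptr = torch.tensor([0, 3])
kv_indices = torch.tensor([7, 3, 9])
page_idx, page_offset = cache.get_page_table(
    cache_position, q_indptr, kv_indptr, kv_indices, torch.tensor([1])
)
# page_idx == tensor([3]) and page_offset == tensor([44])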
