@@ -330,4 +330,97 @@ def get_k_cache(self, layer_idx):
         return self.k_caches[layer_idx]
 
     def get_v_cache(self, layer_idx):
+        return self.v_caches[layer_idx]
+
+
+class KHunYuanCache(nn.Module):
+    """
+    HunYuan-specific cache implementation for GQA with flashinfer compatibility.
+    Handles the KV cache with proper reshaping into the paged attention format.
+    """
+    def __init__(
+        self,
+        config: PretrainedConfig,
+        page_size: int = 256,
+        dtype=torch.bfloat16,
+        device=torch.device("cuda:0"),
+    ):
+        super().__init__()
+        self.config = config
+        self.dtype = dtype
+        self.device = device
+        self.page_size = page_size
+        self.k_caches = []
+        self.v_caches = []
+
+        # HunYuan-specific GQA parameters
+        self.num_heads = config.num_attention_heads
+        self.num_kv_heads = config.num_key_value_heads
+        self.head_dim = config.hidden_size // config.num_attention_heads
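+        # For example: HunYuan uses GQA with 32 attention heads and 8 KV heads;
+        # with an assumed hidden_size of 4096, head_dim = 4096 // 32 = 128, so
+        # one token's per-layer KV entry spans 8 * 128 = 1024 values per tensor.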
+
+    def load(self, inference_context: "sched_ext.InferenceContext"):
+        """
+        Load and reshape KV caches from the inference context to match the
+        flashinfer format. HunYuan uses GQA with 32 attention heads and 8 KV heads.
+        """
+        print(f"Loading HunYuan cache for {self.config.num_hidden_layers} layers")
+
+        for i in range(self.config.num_hidden_layers):
+            k_cache_raw = inference_context.k_cache[0][i]
+            v_cache_raw = inference_context.v_cache[0][i]
+
+            # Check if reshaping is needed based on tensor dimensions.
+            if k_cache_raw.ndim == 2:
+                total_tokens = k_cache_raw.shape[0]
+                num_pages = total_tokens // self.page_size
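+                # e.g. a flat cache of 1024 tokens with page_size=256 yields
+                # 1024 // 256 = 4 pages; the view() below requires total_tokens
+                # to be an exact multiple of page_size.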
+
+                # Reshape k_cache: [total_tokens, kv_dim] -> [num_pages, page_size, num_kv_heads, head_dim]
+                k_cache = k_cache_raw.view(num_pages, self.page_size, self.num_kv_heads, self.head_dim)
+                v_cache = v_cache_raw.view(num_pages, self.page_size, self.num_kv_heads, self.head_dim)
+            elif k_cache_raw.ndim == 3:
+                # [num_pages, page_size, kv_dim] -> [num_pages, page_size, num_kv_heads, head_dim]
+                num_pages = k_cache_raw.shape[0]
+                k_cache = k_cache_raw.view(num_pages, self.page_size, self.num_kv_heads, self.head_dim)
+                v_cache = v_cache_raw.view(num_pages, self.page_size, self.num_kv_heads, self.head_dim)
+            elif k_cache_raw.ndim == 4:
+                # Already in the paged [num_pages, page_size, num_kv_heads, head_dim] layout.
+                k_cache = k_cache_raw
+                v_cache = v_cache_raw
+            else:
+                raise ValueError(f"Unexpected cache dimension: k_cache has {k_cache_raw.ndim} dimensions")
+
+            self.k_caches.append(k_cache)
+            self.v_caches.append(v_cache)
+
+        if len(self.k_caches) > 0:
+            self.max_cache_len = self.k_caches[0].shape[0] * self.k_caches[0].shape[1]
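+            # e.g. 4 pages * 256 tokens/page gives max_cache_len = 1024 tokens.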
+            print(f"Cache loaded: shape {self.k_caches[0].shape}, max_cache_len {self.max_cache_len}")
+
+    def get_page_table(self, cache_position: torch.Tensor, q_indptr: torch.Tensor,
+                       kv_indptr: torch.Tensor, kv_indices: torch.Tensor, bsz_tensors: torch.Tensor):
+        """Get the page table (global page index and in-page offset) for paged attention."""
+        page_offset = cache_position % self.page_size
+        page_idx_local = cache_position // self.page_size
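+        # e.g. with page_size=256, cache_position 600 falls in local page
+        # 600 // 256 = 2 at in-page offset 600 % 256 = 88.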
+        query_ids = torch.zeros_like(cache_position)
+
+        # Assign each token position to the request (query) it belongs to.
+        for i in range(len(q_indptr) - 1):
+            start_idx = q_indptr[i]
+            end_idx = q_indptr[i + 1]
+            query_ids[start_idx:end_idx] = i
+
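+        # Resolve each local page index to a global page id through kv_indices:
+        # request r owns global pages kv_indices[kv_indptr[r]:kv_indptr[r+1]];
+        # e.g. if those are [7, 2, 9], local page 1 of request r is global page 2.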
+        page_idx = torch.zeros_like(page_idx_local)
+        for i in range(bsz_tensors[0]):
+            query_id = query_ids[i]
+            local_block = page_idx_local[i]
+            start_block = kv_indptr[query_id]
+            # Only fill entries whose local page exists in this request's allocation.
+            if local_block < kv_indptr[query_id + 1] - kv_indptr[query_id]:
+                page_idx[i] = kv_indices[start_block + local_block]
+
+        return page_idx, page_offset
+
+    def get_k_cache(self, layer_idx):
+        """Get k_cache for a specific layer."""
+        return self.k_caches[layer_idx]
+
+    def get_v_cache(self, layer_idx):
+        """Get v_cache for a specific layer."""
         return self.v_caches[layer_idx]
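
A minimal usage sketch, assuming a populated sched_ext.InferenceContext and a
HunYuan PretrainedConfig are available (variable names outside the diff are
illustrative):

    cache = KHunYuanCache(config, page_size=256, dtype=torch.bfloat16)
    cache.load(inference_context)   # reshape per-layer KV into the paged layout

    # Translate flat token positions into (global page, in-page offset) pairs.
    page_idx, page_offset = cache.get_page_table(
        cache_position, q_indptr, kv_indptr, kv_indices, bsz_tensors
    )
    k0 = cache.get_k_cache(0)       # [num_pages, page_size, num_kv_heads, head_dim]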