Skip to content

Conversation

@hiworldwzj
Copy link
Collaborator

No description provided.

@gemini-code-assist
Copy link
Contributor

Summary of Changes

Hello @hiworldwzj, I'm Gemini Code Assist1! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request significantly enhances the KV cache management system by introducing comprehensive support for quantized KV caches on the CPU. It involves a strategic refactoring of memory manager components into a dedicated package, alongside crucial updates to the GPU-CPU KV cache transfer mechanisms to correctly handle quantization scales. These changes enable dynamic selection of memory managers tailored to specific model architectures and quantization schemes, ultimately leading to more efficient memory utilization and improved performance for models leveraging quantized KV caches.

Highlights

  • Refactored KV Cache Memory Management: All KV cache memory manager classes have been consolidated into a new, dedicated package: lightllm/common/kv_cache_mem_manager/. This improves organization and modularity for different KV cache implementations.
  • Support for Quantized KV Cache on CPU: The core functionality for offloading and loading KV cache between GPU and CPU has been extended to explicitly support quantized KV caches. This includes modifications to Triton kernels to handle quantization scales (gpu_kv_cache_scale, cpu_kv_cache_scale) during data transfer.
  • Dynamic Memory Manager Selection: The system now dynamically selects the appropriate KV cache memory manager based on the model type and configuration. This includes specialized managers for Deepseek2 models and PPLINT8KV/PPLINT4KV, enhancing flexibility and efficiency.
  • Flexible CPU KV Cache Metadata: The CpuKVCacheMeta structure has been updated to accommodate various data types, head dimensions, and scale dimensions, allowing for accurate calculation and management of quantized KV cache sizes and layouts on the CPU.
  • Centralized Data Type Management: A new utility function, get_llm_data_type(), has been introduced to centralize the retrieval of the LLM's data type from environment arguments, promoting consistency across the codebase.
Using Gemini Code Assist

The full guide for Gemini Code Assist can be found on our documentation page, here are some quick tips.

Invoking Gemini

You can request assistance from Gemini at any point by creating a comment using either /gemini <command> or @gemini-code-assist <command>. Below is a summary of the supported commands on the current page.

Feature Command Description
Code Review /gemini review Performs a code review for the current pull request in its current state.
Pull Request Summary /gemini summary Provides a summary of the current pull request in its current state.
Comment @gemini-code-assist Responds in comments when explicitly tagged, both in pull request comments and review comments.
Help /gemini help Displays a list of available commands.

Customization

To customize Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by creating and adding files to a .gemini/ folder in the base of the repository. Detailed instructions can be found here.

Limitations & Feedback

Gemini Code Assist may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counter productive. You can react with 👍 and 👎 on @gemini-code-assist comments. If you're interested in giving your feedback about your experience with Gemini Code Assist for Github and other Google products, sign up here.

You can also get AI-powered code generation, chat, as well as code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension.

Footnotes

  1. Review the Privacy Notices, Generative AI Prohibited Use Policy, Terms of Service, and learn how to configure Gemini Code Assist in GitHub here. Gemini can make mistakes, so double check it and use code with caution.

Copy link
Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request introduces support for quantized key-value (KV) cache on the CPU, which is a significant enhancement. The changes are well-structured, including a major refactoring of memory management modules into a dedicated kv_cache_mem_manager package. The Triton kernels for data transfer between GPU and CPU have been updated to handle quantization scales, and the logic for calculating CPU cache metadata is now more modular and aware of different quantization schemes. Overall, the implementation is solid. I have a few suggestions to improve code clarity and maintainability, particularly regarding the Triton kernels and metadata calculation logic.

Comment on lines +308 to +309
gpu_scale_stride = [0 for _ in range(10)]
cpu_scale_stride = [0 for _ in range(10)]
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The use of [0 for _ in range(10)] to initialize gpu_scale_stride and cpu_scale_stride is not very clear, as the number 10 seems arbitrary. It would be more readable and maintainable to use the exact number of dimensions for each tensor's stride. The GPU scale tensor has 4 dimensions and the CPU scale tensor has 5.

Suggested change
gpu_scale_stride = [0 for _ in range(10)]
cpu_scale_stride = [0 for _ in range(10)]
gpu_scale_stride = (0,) * 4
cpu_scale_stride = (0,) * 5

Comment on lines +658 to +659
gpu_scale_stride = [0 for _ in range(10)]
cpu_scale_stride = [0 for _ in range(10)]
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

Similar to the offload function, the use of [0 for _ in range(10)] here is unclear. Using explicit dimension counts would be better for readability and maintainability. The GPU scale tensor has 4 dimensions and the CPU scale tensor has 5.

Suggested change
gpu_scale_stride = [0 for _ in range(10)]
cpu_scale_stride = [0 for _ in range(10)]
gpu_scale_stride = (0,) * 4
cpu_scale_stride = (0,) * 5

Comment on lines +64 to +111
mem_manager_class = select_mem_manager_class()
if mem_manager_class is Deepseek2MemoryManager:
cpu_cache_meta = CpuKVCacheMeta(
page_num=0,
token_page_size=args.cpu_cache_token_page_size,
layer_num=get_layer_num(args.model_dir),
num_heads=1,
head_dim=512 + 64,
data_type=get_llm_data_type(),
scale_head_dim=0,
scale_data_type=get_llm_data_type(),
)
elif mem_manager_class is Deepseek2FP8KVMemoryManager:
cpu_cache_meta = CpuKVCacheMeta(
page_num=0,
token_page_size=args.cpu_cache_token_page_size,
layer_num=get_layer_num(args.model_dir),
num_heads=1,
head_dim=512 + 64 + 2,
data_type=torch.uint8,
scale_head_dim=0,
scale_data_type=get_llm_data_type(),
)
elif mem_manager_class is MemoryManager:
cpu_cache_meta = CpuKVCacheMeta(
page_num=0,
token_page_size=args.cpu_cache_token_page_size,
layer_num=get_layer_num(args.model_dir),
num_heads=get_num_key_value_heads(args.model_dir) * 2,
head_dim=get_head_dim(args.model_dir),
data_type=get_llm_data_type(),
scale_head_dim=0,
scale_data_type=get_llm_data_type(),
)
elif mem_manager_class is PPLINT8KVMemoryManager:
cpu_cache_meta = CpuKVCacheMeta(
page_num=0,
token_page_size=args.cpu_cache_token_page_size,
layer_num=get_layer_num(args.model_dir),
num_heads=get_num_key_value_heads(args.model_dir) * 2,
head_dim=get_head_dim(args.model_dir),
data_type=torch.int8,
scale_head_dim=get_head_dim(args.model_dir) // 8,
scale_data_type=get_llm_data_type(),
)
else:
item_size = 2
num_key_value_heads = get_num_key_value_heads(args.model_dir) * 2
head_dim = get_head_dim(args.model_dir)
layer_num = get_layer_num(args.model_dir)
logger.error(f"not support mem manager: {mem_manager_class} for cpu kv cache")
raise Exception(f"not support mem manager: {mem_manager_class} for cpu kv cache")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The if/elif chain to determine cpu_cache_meta based on mem_manager_class is not easily extensible. Every time a new memory manager is added, this function needs to be modified. A better design would be to move this logic into the memory manager classes themselves, for example, by adding a static method like get_cpu_cache_meta(args) to each memory manager class. This would make the system more modular and easier to maintain.

Example:

# In mem_manager.py and other manager classes
class MemoryManager:
    @staticmethod
    def get_cpu_cache_meta(args):
        # logic to create CpuKVCacheMeta
        ...

# In kv_cache_utils.py
def calcu_cpu_cache_meta() -> "CpuKVCacheMeta":
    args = get_env_start_args()
    assert args.enable_cpu_cache
    mem_manager_class = select_mem_manager_class()
    cpu_cache_meta = mem_manager_class.get_cpu_cache_meta(args)
    # ... common logic for mtp_mode and page_num calculation
    ...

wangzaijun added 3 commits December 4, 2025 05:30
@hiworldwzj hiworldwzj merged commit a41492e into main Dec 4, 2025
1 check passed
@hiworldwzj hiworldwzj deleted the wzj branch December 4, 2025 06:38
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants