cpu kv cache support quanted kv. #1133
Conversation
Summary of Changes
Hello @hiworldwzj, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed! This pull request significantly enhances the KV cache management system by introducing comprehensive support for quantized KV caches on the CPU. It involves a strategic refactoring of memory manager components into a dedicated package, alongside crucial updates to the GPU-CPU KV cache transfer mechanisms to correctly handle quantization scales. These changes enable dynamic selection of memory managers tailored to specific model architectures and quantization schemes, ultimately leading to more efficient memory utilization and improved performance for models leveraging quantized KV caches.
Code Review
This pull request introduces support for quantized key-value (KV) cache on the CPU, which is a significant enhancement. The changes are well-structured, including a major refactoring of memory management modules into a dedicated kv_cache_mem_manager package. The Triton kernels for data transfer between GPU and CPU have been updated to handle quantization scales, and the logic for calculating CPU cache metadata is now more modular and aware of different quantization schemes. Overall, the implementation is solid. I have a few suggestions to improve code clarity and maintainability, particularly regarding the Triton kernels and metadata calculation logic.
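To make the "transfer mechanisms handle quantization scales" point concrete, here is a minimal PyTorch-level sketch of an offload step. It is an illustration under assumed tensor layouts, not the PR's actual Triton kernels: the names offload_page, gpu_kv, gpu_scale, cpu_kv, cpu_scale, and all shapes below are hypothetical.

import torch

# Assumed layouts (illustrative only):
#   gpu_kv:    [layer_num, gpu_token_num, num_heads, head_dim]          quantized payload
#   gpu_scale: [layer_num, gpu_token_num, num_heads, scale_head_dim]    4-D scale tensor
#   cpu_kv:    [page_num, layer_num, token_page_size, num_heads, head_dim]
#   cpu_scale: [page_num, layer_num, token_page_size, num_heads, scale_head_dim]  5-D
def offload_page(gpu_kv, gpu_scale, cpu_kv, cpu_scale, token_ids, page_idx):
    # Copy the payload of the selected tokens into one pinned CPU cache page.
    cpu_kv[page_idx].copy_(gpu_kv[:, token_ids], non_blocking=True)
    # For quantized KV caches the per-token scales must travel with the payload,
    # otherwise the data cannot be dequantized after it is loaded back onto the GPU.
    if gpu_scale is not None:
        cpu_scale[page_idx].copy_(gpu_scale[:, token_ids], non_blocking=True)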
gpu_scale_stride = [0 for _ in range(10)]
cpu_scale_stride = [0 for _ in range(10)]
The use of [0 for _ in range(10)] to initialize gpu_scale_stride and cpu_scale_stride is not very clear, as the number 10 seems arbitrary. It would be more readable and maintainable to use the exact number of dimensions for each tensor's stride. The GPU scale tensor has 4 dimensions and the CPU scale tensor has 5.
Suggested change:
- gpu_scale_stride = [0 for _ in range(10)]
- cpu_scale_stride = [0 for _ in range(10)]
+ gpu_scale_stride = (0,) * 4
+ cpu_scale_stride = (0,) * 5
gpu_scale_stride = [0 for _ in range(10)]
cpu_scale_stride = [0 for _ in range(10)]
Similar to the offload function, the use of [0 for _ in range(10)] here is unclear. Using explicit dimension counts would be better for readability and maintainability. The GPU scale tensor has 4 dimensions and the CPU scale tensor has 5.
Suggested change:
- gpu_scale_stride = [0 for _ in range(10)]
- cpu_scale_stride = [0 for _ in range(10)]
+ gpu_scale_stride = (0,) * 4
+ cpu_scale_stride = (0,) * 5
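If the hard-coded length of 10 exists to match a fixed kernel argument list, another option worth considering is to derive the stride list from the tensor itself and pad it, so the real dimension count is never written by hand. This is only a sketch: padded_stride and the fallback behaviour when no scale tensors exist are assumptions, not code from this PR.

def padded_stride(tensor, length=10):
    # torch.Tensor.stride() returns one entry per dimension; pad the remainder with
    # zeros so the kernel always receives a fixed-length argument list.
    stride = list(tensor.stride())
    return stride + [0] * (length - len(stride))

# gpu_scale is 4-D and cpu_scale is 5-D; when there are no scale tensors,
# fall back to all-zero strides as the current code does.
gpu_scale_stride = padded_stride(gpu_scale) if gpu_scale is not None else [0] * 10
cpu_scale_stride = padded_stride(cpu_scale) if cpu_scale is not None else [0] * 10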
mem_manager_class = select_mem_manager_class()
if mem_manager_class is Deepseek2MemoryManager:
    cpu_cache_meta = CpuKVCacheMeta(
        page_num=0,
        token_page_size=args.cpu_cache_token_page_size,
        layer_num=get_layer_num(args.model_dir),
        num_heads=1,
        head_dim=512 + 64,
        data_type=get_llm_data_type(),
        scale_head_dim=0,
        scale_data_type=get_llm_data_type(),
    )
elif mem_manager_class is Deepseek2FP8KVMemoryManager:
    cpu_cache_meta = CpuKVCacheMeta(
        page_num=0,
        token_page_size=args.cpu_cache_token_page_size,
        layer_num=get_layer_num(args.model_dir),
        num_heads=1,
        head_dim=512 + 64 + 2,
        data_type=torch.uint8,
        scale_head_dim=0,
        scale_data_type=get_llm_data_type(),
    )
elif mem_manager_class is MemoryManager:
    cpu_cache_meta = CpuKVCacheMeta(
        page_num=0,
        token_page_size=args.cpu_cache_token_page_size,
        layer_num=get_layer_num(args.model_dir),
        num_heads=get_num_key_value_heads(args.model_dir) * 2,
        head_dim=get_head_dim(args.model_dir),
        data_type=get_llm_data_type(),
        scale_head_dim=0,
        scale_data_type=get_llm_data_type(),
    )
elif mem_manager_class is PPLINT8KVMemoryManager:
    cpu_cache_meta = CpuKVCacheMeta(
        page_num=0,
        token_page_size=args.cpu_cache_token_page_size,
        layer_num=get_layer_num(args.model_dir),
        num_heads=get_num_key_value_heads(args.model_dir) * 2,
        head_dim=get_head_dim(args.model_dir),
        data_type=torch.int8,
        scale_head_dim=get_head_dim(args.model_dir) // 8,
        scale_data_type=get_llm_data_type(),
    )
else:
    item_size = 2
    num_key_value_heads = get_num_key_value_heads(args.model_dir) * 2
    head_dim = get_head_dim(args.model_dir)
    layer_num = get_layer_num(args.model_dir)
    logger.error(f"not support mem manager: {mem_manager_class} for cpu kv cache")
    raise Exception(f"not support mem manager: {mem_manager_class} for cpu kv cache")
The if/elif chain to determine cpu_cache_meta based on mem_manager_class is not easily extensible. Every time a new memory manager is added, this function needs to be modified. A better design would be to move this logic into the memory manager classes themselves, for example, by adding a static method like get_cpu_cache_meta(args) to each memory manager class. This would make the system more modular and easier to maintain.
Example:
# In mem_manager.py and other manager classes
class MemoryManager:
    @staticmethod
    def get_cpu_cache_meta(args):
        # logic to create CpuKVCacheMeta
        ...

# In kv_cache_utils.py
def calcu_cpu_cache_meta() -> "CpuKVCacheMeta":
    args = get_env_start_args()
    assert args.enable_cpu_cache
    mem_manager_class = select_mem_manager_class()
    cpu_cache_meta = mem_manager_class.get_cpu_cache_meta(args)
    # ... common logic for mtp_mode and page_num calculation
    ...
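To make the suggestion concrete, here is a sketch of how one branch could move into its manager class; the method name get_cpu_cache_meta, the file name, and the inheritance shown are assumptions, while the field values are simply copied from the existing PPLINT8KVMemoryManager branch above. With this in place, calcu_cpu_cache_meta collapses to the single dispatch call shown in the previous snippet.

# In ppl_int8kv_mem_manager.py (illustrative sketch, not code from this PR)
class PPLINT8KVMemoryManager(MemoryManager):
    @staticmethod
    def get_cpu_cache_meta(args):
        # Same values as the current elif branch, now owned by the manager class itself.
        return CpuKVCacheMeta(
            page_num=0,
            token_page_size=args.cpu_cache_token_page_size,
            layer_num=get_layer_num(args.model_dir),
            num_heads=get_num_key_value_heads(args.model_dir) * 2,
            head_dim=get_head_dim(args.model_dir),
            data_type=torch.int8,
            scale_head_dim=get_head_dim(args.model_dir) // 8,
            scale_data_type=get_llm_data_type(),
        )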