Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Serving] PagedKVCache Quantization #2663

Open
wants to merge 1 commit into
base: main
Choose a base branch
from

Conversation

davidpissarra
Copy link
Member

@davidpissarra davidpissarra commented Jul 16, 2024

KV Cache might be a burden under tight memory constraints, and cache quantization can reduce its memory requirement by roughly 75% (float16 -> int3 KV cache). As a result, this PR intends to introduce initial support for KV Cache quantization, introducing two new quant schemes (q3f16kv and q4f16kv).

On Llama-3 (context_window_size = 8192) kv cache memory usage:

  • 1113 MB (q4f16_1: float16 kv cache)
  • 345 MB (q4f16kv: int4 kv cache)
  • 281 MB (q3f16kv: int3 kv cache)

Depends on: apache/tvm#17159

cc @MasterJH5574 @tqchen

@XJY990705
Copy link

I want to use this pr, and I notice the TVM branch it use is f5f048b, but I got some bugs while compiling mlc-llm. Would you please tell me the TVM version you use in this pr? I'm sure the problems I got is concerned with wrong TVM version

@davidpissarra
Copy link
Member Author

davidpissarra commented Sep 5, 2024

Hi @XJY990705 , it is still actually on f5f048b. You may be able to run it if you build everything from this branch (including tvm). I will rebase it in the meantime.

@XJY990705
Copy link

@davidpissarra thank you for your reply, I will try again

@XJY990705
Copy link

I have already solved this problem by this pr tlc-pack/libflash_attn#8
maybe when I swiched to f5f048b branch, this modification is lost. Anyway, thank you for your help!!

@XJY990705
Copy link

@davidpissarra I noticed 3rdparty/tvm/src/runtime/relax_vm/paged_kv_cache.cc is not changed, and mismatched with python/mlc_llm/nn/kv_cache.py when calling TVM_REGISTER_GLOBAL("vm.builtin.paged_attention_kv_cache_create_reduced"). Is there any commit you forget?

@kazunator
Copy link

Is this still being worked on, or is there already a kv cache quantization implemented?

@davidpissarra
Copy link
Member Author

Is this still being worked on, or is there already a kv cache quantization implemented?

Hi @kazunator. It is implemented indeed. Since it hasn't been merged yet, you should be able to use it by building this branch and then using the q3f16kv or q4f16kv quantization schemes.

@kazunator
Copy link

@davidpissarra do you have example usage I can run quickly? Or just where I can specify the quantization scheme?

@davidpissarra
Copy link
Member Author

@kazunator you should be able to run it by following the typical MLC model compilation flow (the following steps should be enough). For more details, feel free refer to https://llm.mlc.ai/docs/compilation/compile_models.html

# 1. Convert model parameters
mlc_llm convert_weight ./dist/models/Meta-Llama-3-8B/ \
    --quantization q4f16kv \
    -o dist/Meta-Llama-3-8B-q4f16kv-MLC

# 2. gen_config: generate mlc-chat-config.json and process tokenizers
mlc_llm gen_config ./dist/models/Meta-Llama-3-8B/ \
    --quantization q4f16kv --conv-template llama \
    -o dist/Meta-Llama-3-8B-q4f16kv-MLC/

# 3. compile: compile model library with specification in mlc-chat-config.json
mlc_llm compile ./dist/Meta-Llama-3-8B-q4f16kv-MLC/mlc-chat-config.json \
    --device cuda -o dist/libs/Meta-Llama-3-8B-q4f16kv-cuda.so

@kazunator
Copy link

@davidpissarra oh that's cool. Is it possible to use it with just python tho? I'm trying to get it running on collab to test out your branch and I don't have a big GPU with me rn. I can test it on company A100 later in the week

@XJY990705
Copy link

Sorry to bother you again, I tried your method and I got the exact ppl for the same datasets using int3 quantization, int4 quantization and no quantization. And I want to check the kv cache value to see if it has been quantized, but add print lines in tir functions is useless(maybe removed after compilation). Do you have some options to debug and print the kv cache value? I want to print the value before and after quantization for error calculation.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants