[Feature] (Willing to PR) Avoid KV cache occupying GPU memory when not used #2542
Comments
Thanks for pointing this out. I am working on similar things in OpenRLHF: https://github.com/OpenRLHF/OpenRLHF

Off-loading the KV cache (in practice you would actually shut down the engine and runtime) has a clear trade-off:

Pros: it saves a lot of VRAM.
Cons: shutting down and restarting the engine costs time.
In the current design of OpenRLHF, they choose not to offload in order to save time. In that case, the weights can be broadcast directly from the training engine to the inference engine, as shown in:

api: sglang/python/sglang/srt/server.py, lines 214 to 239 (at 21e9e63)
test case / usage: https://github.com/sgl-project/sglang/blob/main/test/srt/test_update_weights_from_distributed.py

If you are interested, feel free to continue discussing this with me on GitHub or on our SGLang Slack. |
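For readers less familiar with that path, here is a minimal sketch of the underlying idea. This is not the SGLang API itself, just the generic torch.distributed broadcast pattern; process-group setup is omitted and the function name is illustrative:

```python
# Hypothetical sketch: broadcast weights from a training process (rank 0)
# to an inference process that joined the same torch.distributed group.
import torch
import torch.distributed as dist

def broadcast_weights(model: torch.nn.Module, group, src_rank: int = 0):
    # Both the trainer (src) and the inference side call this with identically
    # named and shaped parameters; rank 0 sends, the others receive in place.
    for name, param in model.named_parameters():
        dist.broadcast(param.data, src=src_rank, group=group)
```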
@zhaochenyang20 Hi, thank you so much! My current thought is that maybe we do not shut down the engine at all, so we do not need to slowly restart it. Instead, some naive thoughts:

1. As for the KV cache, temporarily delete its tensors while keeping the rest of the engine alive, and re-create them when generation is needed again (see the sketch after this list).
2a. As for model weights, is it even possible to keep exactly one copy of the weights in memory, shared by both SGLang and the Transformers model (used by TRL / OpenRLHF to do weight updates)?
2b. Or, for model weights, can we do the same thing as proposal 1, i.e. temporarily delete the tensors but keep all other parts, and later re-create the tensors from OpenRLHF's new weights? |
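To make idea 1 concrete, here is a minimal toy sketch in plain PyTorch. It is not SGLang code; the class, sizes, and method names are made up for illustration:

```python
import torch

class ToyKVPool:
    """Toy stand-in for a KV cache pool that can drop and re-create its buffers."""

    def __init__(self, num_tokens: int, head_dim: int, device: str = "cuda"):
        self.num_tokens, self.head_dim, self.device = num_tokens, head_dim, device
        self._create_buffers()

    def _create_buffers(self):
        # Allocate the large tensors that dominate VRAM usage.
        self.k_buffer = torch.empty(self.num_tokens, self.head_dim, device=self.device)
        self.v_buffer = torch.empty(self.num_tokens, self.head_dim, device=self.device)

    def _clear_buffers(self):
        # Drop the references so the allocator can reclaim the memory...
        del self.k_buffer, self.v_buffer
        torch.cuda.empty_cache()  # ...and return cached blocks to the driver.

pool = ToyKVPool(num_tokens=1 << 20, head_dim=128)
pool._clear_buffers()   # "pause": VRAM is released while the trainer runs backward
pool._create_buffers()  # "resume": re-allocate before the next generation phase
```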
In my experience, I do not think 2 is possible. But maybe 1 is okay. Let me discuss this with my teammates. @fzyzcjy |
@zhaochenyang20 Thank you! I am happy to PR and try to hack it as well. |
Btw, I see the PR series OpenRLHF/OpenRLHF#614 and it looks great :) I am mostly interested in PPO/REINFORCE training with a fast inference engine. |
@zhaochenyang20 I quickly glanced at the code, and it seems calling _clear_buffers (and calling _create_buffers again when we want to use sglang later) may be relevant.
I will do more experiments later :) |
Quick experiments about devices:
Thus, for point 2b, maybe we do not need to really delete the tensors, but only need to move them between devices (e.g. a .to('cpu') and later a .to('cuda')). For point 1, doing this instead of deleting the whole tensor may also be an alternative way. |
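A minimal sketch of that device-move variant in plain PyTorch (the buffer size is arbitrary; as noted later in this thread, moving is slower than simply dropping and re-creating the buffers):

```python
import torch

buf = torch.empty(1 << 28, device="cuda")              # ~1 GiB fp32 buffer
print(torch.cuda.memory_allocated() // 2**20, "MiB")   # allocated with buffer on GPU

buf = buf.to("cpu")          # offload: the GPU copy is freed once unreferenced
torch.cuda.empty_cache()     # return cached blocks to the driver
print(torch.cuda.memory_allocated() // 2**20, "MiB")   # should drop by ~1 GiB

buf = buf.to("cuda")         # bring it back before the next generation phase
```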
@fzyzcjy Hey there. I will reach out to you through WeChat so we can have a quick discussion. |
I made a quick hack and it seems to work.

Code

Note: I hacked the function get_weights_by_name in the scheduler. Change:

```python
def get_weights_by_name(self, recv_req: GetWeightsByNameReqInput):
    print(f'hi hacky change get_weights_by_name!!! {recv_req=}')
    match recv_req.name:
        case 'hack_pause':
            # Drop cached requests and free the KV pool buffers.
            self.flush_cache()
            self.token_to_kv_pool._clear_buffers()
            torch.cuda.empty_cache()
        case 'hack_resume':
            # Re-allocate the KV pool buffers so generation can continue.
            self.token_to_kv_pool._create_buffers()
        case _:
            raise NotImplementedError
    return None
```

Test code

```python
import time

import sglang as sgl
llm = sgl.Engine(
    model_path="meta-llama/Llama-3.2-1B-Instruct",
    # model_path='Qwen/Qwen2.5-0.5B-Instruct',
    enable_torch_compile=True,
    disable_cuda_graph=True,
)
print(llm)
prompts = [
    "1+1=",
]
sampling_params = {"temperature": 0}
print('llm.generate #1')
outputs = llm.generate(prompts, sampling_params)
print(outputs)
print('pause')
llm.get_weights_by_name('hack_pause')
print('sleep for seconds...')
time.sleep(3)
print('resume')
llm.get_weights_by_name('hack_resume')
print('llm.generate #2')
outputs = llm.generate(prompts, sampling_params)
print(outputs)
```

Result

GPU memory (figure: red indicates occupied GPU memory): as we can see, during the 3-second sleep the GPU memory usage is much lower, so we seem to have successfully freed the KV cache pool.

Logs:
Caveats
I have not tested whether this "enable compile, disable CUDA graph" configuration causes a speed slowdown. |
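One way to check that caveat would be a rough timing comparison, running the same workload once per setting of disable_cuda_graph. A sketch only; the model path, batch size, and token count are arbitrary:

```python
import time

import sglang as sgl

# Run this script twice: once with disable_cuda_graph=True and once with False,
# then compare the measured generation time.
llm = sgl.Engine(
    model_path="meta-llama/Llama-3.2-1B-Instruct",
    enable_torch_compile=True,
    disable_cuda_graph=True,  # flip this between the two runs
)

prompts = ["1+1="] * 64
sampling_params = {"temperature": 0, "max_new_tokens": 128}

llm.generate(prompts, sampling_params)   # warm-up (compile / graph capture)
start = time.perf_counter()
llm.generate(prompts, sampling_params)
print(f"generation took {time.perf_counter() - start:.2f}s")
```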
Another proposal: #2569 |
@fzyzcjy Hey, amazing experiments! For reference, here are the weight-update API and its test again:

api: sglang/python/sglang/srt/server.py, lines 214 to 239 (at 21e9e63)
test case / usage: https://github.com/sgl-project/sglang/blob/main/test/srt/test_update_weights_from_distributed.py |
Also, regarding your proposed method, in a real implementation we should definitely have some dedicated functions, e.g. explicit release/resume APIs. |
Sure! We should have separate function names like that instead of hacking an unrelated function. I am happy to PR :) |
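As a rough sketch, such dedicated entry points could simply factor out the calls from the hack above. The method names here are hypothetical, not the final API:

```python
# Hypothetical scheduler methods (names are illustrative only), factored out of
# the get_weights_by_name hack above.
import torch

class SchedulerMemoryMixin:
    def pause_memory_occupation(self):
        # Drop radix/KV cache state and free the big pool tensors.
        self.flush_cache()
        self.token_to_kv_pool._clear_buffers()
        torch.cuda.empty_cache()

    def resume_memory_occupation(self):
        # Re-allocate the pool before serving the next generation request.
        self.token_to_kv_pool._create_buffers()
```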
@fzyzcjy We will PR this into OpenRLHF. Do you think TRL needs it as well? We (SGLang and lmsys.org) are also willing to collaborate with HuggingFace 😂 |
I guess it is up to your (SGLang and lmsys) choice. My personal thought is that many people are using TRL, so it may be great to PR to it as well. Btw, #2569 is not restricted to OpenRLHF and TRL; it is also useful for Transformers, since the proposal is about speeding up PreTrainedModel.generate using SGLang. |
Pretty good! Let's discuss this later. @fzyzcjy |
The major concern with this approach is that CUDA graph will be disabled. |
```bash
ray job submit --address="172.31.59.18:4567" \
--runtime-env-json='{"working_dir": "/opt/dlami/nvme/chenyang/rlhf-ckpt"}' \
-- python3 -m openrlhf.cli.train_ppo_ray \
--ref_num_nodes 1 \
--ref_num_gpus_per_node 1 \
--reward_num_nodes 1 \
--reward_num_gpus_per_node 1 \
--critic_num_nodes 1 \
--critic_num_gpus_per_node 1 \
--actor_num_nodes 1 \
--actor_num_gpus_per_node 1 \
--vllm_num_engines 1 \
--vllm_tensor_parallel_size 1 \
--colocate_critic_reward \
--colocate_actor_ref \
--pretrain OpenRLHF/Llama-3-8b-sft-mixture \
--reward_pretrain OpenRLHF/Llama-3-8b-rm-mixture \
--save_path /opt/dlami/nvme/chenyang/rlhf-ckpt/examples/checkpoint/llama3-8b-rlhf \
--save_steps 100 \
--micro_train_batch_size 16 \
--train_batch_size 128 \
--micro_rollout_batch_size 32 \
--rollout_batch_size 128 \
--max_samples 512 \
--max_epochs 1 \
--prompt_max_len 1024 \
--generate_max_len 1024 \
--zero_stage 3 \
--bf16 \
--actor_learning_rate 5e-7 \
--critic_learning_rate 9e-6 \
--init_kl_coef 0.01 \
--prompt_data OpenRLHF/prompt-collection-v0.1 \
--input_key context_messages \
--apply_chat_template \
--packing_samples \
--normalize_reward \
--adam_offload \
--flash_attn \
--gradient_checkpointing
```

@fzyzcjy For your reference.

When sampling with vLLM:

```
[5] NVIDIA H100 80GB HBM3 | 29°C, 0 % | 37858 / 81559 MB | chenyang(21218M) chenyang(16530M)
[6] NVIDIA H100 80GB HBM3 | 32°C, 0 % | 47336 / 81559 MB | chenyang(31702M) chenyang(15528M)
[7] NVIDIA H100 80GB HBM3 | 29°C, 0 % | 74861 / 81559 MB | chenyang(74844M)
```

When doing weight updates:

```
[5] NVIDIA H100 80GB HBM3 | 30°C, 0 % | 57074 / 81559 MB | chenyang(40434M) chenyang(16530M)
[6] NVIDIA H100 80GB HBM3 | 33°C, 0 % | 45832 / 81559 MB | chenyang(30196M) chenyang(15530M)
[7] NVIDIA H100 80GB HBM3 | 29°C, 0 % | 74861 / 81559 MB | chenyang(74844M)
```
Ah, I see. CUDA graph does not seem to be supported because our re-created tensors have different addresses (see the standalone example after this comment). Some naive thoughts:
|
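For background on why the addresses matter, here is a standalone PyTorch CUDA-graph example (unrelated to SGLang internals): replay reuses exactly the buffers that were live at capture time, so freeing and re-creating them invalidates the captured graph.

```python
import torch

static_in = torch.zeros(4, device="cuda")
static_out = torch.zeros(4, device="cuda")

# Warm up once on a side stream (recommended before capture).
s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s):
    static_out.copy_(static_in * 2)
torch.cuda.current_stream().wait_stream(s)

# Capture: the recorded kernels are bound to static_in/static_out's addresses.
g = torch.cuda.CUDAGraph()
with torch.cuda.graph(g):
    static_out.copy_(static_in * 2)

static_in.fill_(3.0)
g.replay()              # reuses the exact buffers captured above
print(static_out)       # tensor([6., 6., 6., 6.])

# If static_in/static_out were deleted and re-allocated (new addresses), the
# captured graph would still point at the old memory, so the buffers must
# either stay alive or the graph must be re-captured.
```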
(The comment above is updated adding the "brainstorm" point just now) |
Oh man I am saying almost the same thing as @merrymercy in #2588 (review)! |
More brainstorms:
Just quick and naive brainstorms, I have not done experiments yet. |
@fzyzcjy Sorry for the late reply; I have had some things to deal with these days. If I do not reply within a day, please ping me again on WeChat. Have a good day. |
@zhaochenyang20 No worries! I spent a bit more time digging deeper into the approach above, and it seems to work well for SGLang: the KV cache is released and CUDA graph is still enabled. I will do more experiments to check further when I have a bit more time. Also, theoretically speaking, this approach should make "release model weights" (again, with CUDA graph usable at the same time) very easy to implement (only two lines of code to add). But I will need to implement an "update model weights from the same card" feature to test that. |
Find another bit of time to submit the PR: #2630. |
Summary: when in paused mode, the KV cache memory is released while CUDA graph stays usable.
And a (simple) correctness check is done: https://github.com/sgl-project/sglang/pull/2630/files#diff-4f475f1badc32fc2578207bded162aac1c915c5f9d28a8e281c1c7d20cb6dd87 - after release/resume, the generated content is correct. (The numbers above are from my 4090D card, but the idea is similar; I will do more experiments later.) |
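That kind of check can also be reproduced with the earlier hack, assuming the patched get_weights_by_name and the test script above, and assuming each generate output dict exposes a "text" field:

```python
# Correctness check sketch, continuing the earlier test script (llm, prompts,
# and sampling_params already defined; get_weights_by_name patched as above).
before = llm.generate(prompts, sampling_params)

llm.get_weights_by_name('hack_pause')    # release KV cache memory
llm.get_weights_by_name('hack_resume')   # re-create the buffers

after = llm.generate(prompts, sampling_params)
assert [o["text"] for o in before] == [o["text"] for o in after], "output changed after resume"
```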
Btw, some scenarios come to mind where this may be especially powerful with SGLang: when the generation phase in RL is especially heavy, for example doing a lot of rollouts or MCTS searches during generation. |
This is pretty cool! If possible, we can offload both weights and the KV cache. |
@merrymercy Yes, that has already been done! (EDIT: I mean just throwing them away instead of moving them to CPU, since moving is slower.) |
Quick update: on a standard 3xH100 setting (1xH100 for actor+ref, 1xH100 for critic+reward, 1xH100 for SGLang), SGLang's memory can be released when paused, dropping to 2.5GB (instead of 72.2GB). I will check the 2xH100 setting (not giving SGLang an extra GPU) as well as comparison tests later. Metrics for very short runs are as follows (again, correctness has not been verified): |
Great job. Ping me on Slack if needed! |
Motivation
Hi, thank you for the library! The use case is that, when doing online PPO, I hope to use SGLang to generate LLM completions and then use RL to do gradient descent on those completions.
The problem is that, to do this on a single GPU, the timeline is "SGLang generate - torch backward - repeat". So while torch is doing backprop, I hope SGLang can free its KV cache memory; otherwise torch will not have enough memory.
Thanks for any suggestions!