Bounded peak memory in Top-P-Top-K with chunked sorting #11544
Conversation
…ecoding to reduce peak memory (try to solve OOM issue in computing logits in Issue vllm-project#5907) Signed-off-by: Chenghao Yang <[email protected]>
@yangalan123 thanks for the investigation! I changed the … ; we can investigate in the future, if people ever want to directly control the … . cc @robertgshaw2-neuralmagic, I think it can solve most of the issues linked in #5907.
For the record, the memory cost of …
Nice work, will review on my flight today
@@ -186,6 +186,12 @@ def __init__(self):
        # speculative decoding.
        self.include_gpu_probs_tensor = False
        self.should_modify_greedy_probs_inplace = False
        from vllm.config import get_current_vllm_config
I think that we should pipe this parameter down through the constructor rather than using a global config.
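For illustration, a minimal sketch of what constructor injection could look like here; the parameter name logits_sort_chunk_size is an assumption for this sketch, not the PR's actual attribute:

from typing import Optional

import torch.nn as nn

class Sampler(nn.Module):
    def __init__(self, logits_sort_chunk_size: Optional[int] = None):
        super().__init__()
        self.include_gpu_probs_tensor = False
        self.should_modify_greedy_probs_inplace = False
        # Injected by whoever constructs the sampler, instead of being
        # read from a process-wide config inside __init__.
        self.logits_sort_chunk_size = logits_sort_chunk_size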
for logits_chunk in logits_chunks:
    current_chunk_size = logits_chunk.shape[0]
    end_idx = start_idx + current_chunk_size
    output_logits[start_idx:end_idx].copy_(
QQ - why .copy_ rather than just setting the value?
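For context, both forms write into the pre-allocated buffer in place; a tiny standalone comparison (the shapes are made up):

import torch

output_logits = torch.empty(4, 8)
sorted_chunk = torch.randn(2, 8)

# Slice assignment and .copy_() both perform an in-place copy into the
# existing storage; .copy_() additionally exposes non_blocking= for
# asynchronous device-to-device copies.
output_logits[0:2] = sorted_chunk
output_logits[2:4].copy_(sorted_chunk)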
end_idx = start_idx + chunk_size
cmp = x[start_idx:end_idx] > vals[start_idx:end_idx, None]
# cmp.sum(dim=1, dtype=torch.int32) is the peak memory usage.
ranks = cmp.sum(dim=1, dtype=torch.int32).add_(1)
Can you add a comment that explains the logic?
Specifically, something that says:
- select all tokens with logprobs greater than the selected indices' values, as booleans
- summing over the booleans gives the count of Trues
- add 1
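For reference, a tiny self-contained example of that logic (the values are made up):

import torch

logprobs = torch.tensor([[0.1, 2.0, -1.0, 0.5]])  # one row of logprobs
chosen = torch.tensor([0.5])                      # logprob of the sampled token

# Booleans marking tokens with strictly greater logprob than the chosen one.
greater = logprobs > chosen[:, None]
# Summing the booleans counts how many tokens rank above it; add 1 for the 1-based rank.
rank = greater.sum(dim=1, dtype=torch.int32).add_(1)
print(rank)  # tensor([2], dtype=torch.int32)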
@@ -0,0 +1,20 @@
from vllm import LLM, SamplingParams

- Please move this to the logprobs test directory rather than entrypoints
- Do we have a test that this is giving the right answer?
Thanks! Left some minor comments.
    result = (x > vals[:, None])
    del vals
    return result.sum(1).add_(1)
N, M = x.shape
Looks like M is unused, was that intentional?
This pull request has merge conflicts that must be resolved before it can be merged.
Thanks for the review and comments! (And Happy New Year!) As the initial author of this PR, I can provide some initial insights on the correctness of this chunked sorting approach:
# benchmark_runner.py
import torch
import time
import json
import argparse
from pathlib import Path
from tqdm import tqdm


def naive_sort(tensor: torch.Tensor) -> torch.Tensor:
    return torch.sort(tensor, dim=-1)[0]


def chunked_sort(tensor: torch.Tensor, chunk_size: int) -> torch.Tensor:
    chunks = torch.split(tensor, chunk_size, dim=0)
    sorted_chunks = [torch.sort(chunk, dim=-1)[0] for chunk in chunks]
    return torch.cat(sorted_chunks, dim=0)


def run_benchmark(args):
    results = {
        'time': [],
        'accuracy': [],
        'type': 'naive' if args.chunk_size is None else f'chunked_{args.chunk_size}'
    }
    for _ in tqdm(range(args.num_rounds)):
        tensor = torch.randn(args.num_tokens, args.vocab_size, device='cuda')
        start_time = time.perf_counter()
        if args.chunk_size is None:
            result = naive_sort(tensor)
        else:
            result = chunked_sort(tensor, args.chunk_size)
        if args.check_accuracy:
            naive_result = naive_sort(tensor)
            accuracy = float(torch.allclose(naive_result, result, rtol=1e-5))
            results['accuracy'].append(accuracy)
        torch.cuda.synchronize()
        results['time'].append(time.perf_counter() - start_time)
        del tensor, result
        torch.cuda.empty_cache()
        torch.cuda.synchronize()
    Path(args.output_dir).mkdir(exist_ok=True)
    output_file = Path(args.output_dir) / f"results_{results['type']}.json"
    with open(output_file, 'w') as f:
        json.dump(results, f)


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--num_tokens', type=int, default=8192)
    parser.add_argument('--vocab_size', type=int, default=200000)
    parser.add_argument('--chunk_size', type=int, default=None)
    parser.add_argument('--num_rounds', type=int, default=5)
    parser.add_argument('--check_accuracy', action='store_true')
    parser.add_argument('--output_dir', type=str, default='benchmark_results')
    args = parser.parse_args()
    if not torch.cuda.is_available():
        raise RuntimeError("CUDA not available")
    run_benchmark(args)


if __name__ == "__main__":
    main()

I ran it with …
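Based on the flags defined in the script above, a typical pair of runs looks something like the following (the chunk size here is just an example value):

# chunked variant vs. naive baseline (chunk size chosen for illustration)
python benchmark_runner.py --num_tokens 8192 --vocab_size 200000 --num_rounds 5 --check_accuracy --chunk_size 2048
python benchmark_runner.py --num_tokens 8192 --vocab_size 200000 --num_rounds 5 --check_accuracy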
As for other issues with the code: since @youkaichao is already working on polishing my initial commits, perhaps he can provide more thoughts and details. Happy New Year again to everyone, and thanks to all for reviewing and polishing this PR!
This pull request has been automatically marked as stale because it has not had any activity within 90 days. It will be automatically closed if no further activity occurs within 30 days. Leave a comment if you feel this pull request should remain open. Thank you!
This PR, a collaboration with @youkaichao, implements chunked sorting to reduce peak memory (aiming to solve the OOM issue in computing logits reported in Issue #5907).
The issue we want to address is that PyTorch sorting over the logits incurs significant peak memory usage, which turns out to be a major memory bottleneck at decoding time, especially when users request to obtain logits. Our approach: instead of sorting the full logits tensor as a whole, we first split it into chunks and sort each smaller chunk. That way, the intermediate buffers created during sorting are much smaller, and their memory can be recycled promptly.
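As a minimal sketch of this idea applied to top-k/top-p masking (the function name, signature, and default chunk size are assumptions for illustration, not the PR's actual code):

import torch

def chunked_top_k_top_p(logits: torch.Tensor, k: torch.Tensor,
                        p: torch.Tensor, chunk_size: int = 64) -> torch.Tensor:
    # Mask logits outside top-k / top-p one row-chunk at a time, so the
    # sorted copies, softmax, and cumsum temporaries are only
    # (chunk_size, vocab) instead of (num_tokens, vocab).
    out = torch.empty_like(logits)
    num_tokens = logits.shape[0]
    for start in range(0, num_tokens, chunk_size):
        end = min(start + chunk_size, num_tokens)
        sorted_logits, sorted_idx = logits[start:end].sort(dim=-1, descending=False)
        # Top-k: everything below the k-th largest value per row gets -inf.
        kth_idx = (sorted_logits.size(-1) - k[start:end].to(torch.long)).unsqueeze(-1)
        kth_val = sorted_logits.gather(-1, kth_idx)
        sorted_logits.masked_fill_(sorted_logits < kth_val, float("-inf"))
        # Top-p: drop the low-probability tail whose cumulative mass is <= 1 - p.
        probs = sorted_logits.softmax(dim=-1)
        cum_probs = probs.cumsum(dim=-1)
        top_p_mask = cum_probs <= (1 - p[start:end]).unsqueeze(-1)
        top_p_mask[:, -1] = False  # always keep at least the largest token
        sorted_logits.masked_fill_(top_p_mask, float("-inf"))
        # Scatter the masked values back to the original token order.
        out[start:end].scatter_(-1, sorted_idx, sorted_logits)
    return out

With chunk_size equal to num_tokens this reduces to the usual single-pass masking; smaller chunks trade a bit of extra kernel-launch overhead for lower peak memory.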
Effectiveness Proof:
We conducted a simple verification: on a standard server with 4 A40 GPUs, we ran
vllm serve meta-llama/Meta-Llama-3-8B --tensor_parallel_size 4
and monitored the memory usage. We successfully reduced the peak memory usage from 5.0 GB to 4.5 GB on Rank 0 (sampling and sorting only happen on Rank 0) by setting chunk_size=64 (1/4 of max_num_seqs). Considering the roughly 3.7 GB of model weights that cannot be reduced (plus some minor usage by NCCL and CUDA graphs), and the relatively small code edits involved, we do see this peak memory reduction as a promising direction to work on (i.e., reducing ~50% of the peak memory usage).

FIX #5907 (link existing issues this PR will resolve)