
Bounded peak memory in Top-P-Top-K with chunked sorting #11544

Open · wants to merge 7 commits into main

Conversation

@yangalan123 (Contributor) commented Dec 27, 2024

This PR, developed in collaboration with @youkaichao, implements chunked sorting to reduce peak memory during sampling (aiming to solve the OOM issue in computing logits reported in Issue #5907).

The issue we want to address is that PyTorch sorting on logits incurs significant peak memory usage, which turns out to be a major memory bottleneck at decoding time, especially when users request logits. Our approach: instead of sorting the large logits tensor as a whole, we first split it into chunks along the token dimension and sort each smaller chunk. This way, the intermediate buffers created during sorting are much smaller and their memory can be recycled promptly.
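
For illustration, a minimal, self-contained sketch of the idea (not the actual PR diff; the function and variable names here are placeholders):

import torch

def sort_logits_in_chunks(logits: torch.Tensor, chunk_size: int):
    # Split along the token (row) dimension and sort each chunk on its own.
    # Rows are sorted independently along the vocab dimension, so the result
    # is identical to sorting the whole tensor at once; only the temporary
    # buffers allocated by torch.sort become smaller.
    sorted_chunks, index_chunks = [], []
    for chunk in torch.split(logits, chunk_size, dim=0):
        sorted_chunk, sorted_idx = chunk.sort(dim=-1, descending=True)
        sorted_chunks.append(sorted_chunk)
        index_chunks.append(sorted_idx)
    return torch.cat(sorted_chunks, dim=0), torch.cat(index_chunks, dim=0)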

Effectiveness:
We conducted a simple verification. On a standard 4x A40 GPU server, we ran vllm serve meta-llama/Meta-Llama-3-8B --tensor_parallel_size 4 and monitored memory usage. By setting chunk_size=64 (1/4 of max_num_seqs), we reduced peak memory usage on Rank 0 (where sampling and sorting happen) from 5.0 GB to 4.5 GB. Considering the roughly 3.7 GB of irreducible model weights (plus some minor usage by NCCL and CUDA graphs), and the relatively small code change, we see this peak memory reduction as a promising direction: it cuts roughly 50% of the peak memory beyond the weights and these fixed overheads.
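
As a rough illustration of how such a per-rank peak can be observed (a sketch only, not the exact monitoring we used):

import torch

torch.cuda.reset_peak_memory_stats()
# ... run a decoding step that samples and sorts logits ...
peak_gib = torch.cuda.max_memory_allocated() / 1024 ** 3
print(f"peak allocated on this rank: {peak_gib:.2f} GiB")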

FIX #5907 (link existing issues this PR will resolve)

Commit: …ecoding to reduce peak memory (try to solve OOM issue in computing logits in Issue vllm-project#5907)

Signed-off-by: Chenghao Yang <[email protected]>

👋 Hi! Thank you for contributing to the vLLM project.
Just a reminder: PRs do not trigger a full CI run by default. Instead, only the fastcheck CI runs, covering a small and essential subset of CI tests to quickly catch errors. You can run additional CI tests on top of those by going to your fastcheck build on the Buildkite UI (linked in the PR checks section) and unblocking them. If you do not have permission to unblock, ping simon-mo or khluu to add you to our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can do one of these:

  • Add ready label to the PR
  • Enable auto-merge.

🚀

@youkaichao youkaichao marked this pull request as draft December 27, 2024 04:33
@youkaichao youkaichao marked this pull request as ready for review December 27, 2024 05:29
@youkaichao (Member) commented:

@yangalan123 thanks for the investigation! I changed the chunk_size budget to always be max-num-seqs; this keeps the current behavior and only fixes the OOM issue when people ask for prompt logprobs.

We can investigate in the future whether people ever want to control the chunk_size directly; if so, we can expose a flag to users.

cc @robertgshaw2-neuralmagic I think it can solve most of the issues linked in #5907.
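
For illustration, the budgeting described above amounts to something like this (a minimal sketch with hypothetical names; the actual wiring in the PR may differ):

def pick_chunk_size(num_logit_rows: int, max_num_seqs: int) -> int:
    # Cap the sorting chunk at the max_num_seqs budget: ordinary decode
    # batches (at most max_num_seqs rows of logits) are sorted in a single
    # chunk exactly as before, while oversized batches (e.g. prompt logprobs
    # over a long prompt) get split into several smaller sorts.
    return min(num_logit_rows, max_num_seqs)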

@youkaichao youkaichao changed the title [WIP] Enable Chunked Sorting to Reduce Peak Memory Usage in Top-P-Top-K Sorting Reduce peak memory in Top-P-Top-K with chunked sorting Dec 27, 2024
@youkaichao youkaichao changed the title Reduce peak memory in Top-P-Top-K with chunked sorting Bounded peak memory in Top-P-Top-K with chunked sorting Dec 27, 2024
Signed-off-by: youkaichao <[email protected]>
Signed-off-by: youkaichao <[email protected]>
@mergify mergify bot added the ci/build label Dec 27, 2024
@youkaichao (Member) commented:

For the record, the memory cost of torch.sort seems to be a known issue; see pytorch/pytorch#77049 (comment).

@robertgshaw2-redhat robertgshaw2-redhat self-assigned this Dec 28, 2024
@robertgshaw2-redhat (Collaborator) commented:

Nice work, will review on my flight today

@@ -186,6 +186,12 @@ def __init__(self):
# speculative decoding.
self.include_gpu_probs_tensor = False
self.should_modify_greedy_probs_inplace = False
from vllm.config import get_current_vllm_config
@robertgshaw2-redhat (Collaborator) commented on this diff, Dec 29, 2024:

I think that we should pipe this parameter down through the constructor rather than using a global config
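
Something along these lines, for example (illustrative only; the real Sampler signature and parameter name may differ):

from typing import Optional

class Sampler:
    def __init__(self, logits_sort_chunk_size: Optional[int] = None):
        # The chunk size is wired in from the engine config at construction
        # time instead of calling get_current_vllm_config() inside __init__,
        # which keeps the sampler easy to construct and unit-test in isolation.
        self.logits_sort_chunk_size = logits_sort_chunk_size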

for logits_chunk in logits_chunks:
current_chunk_size = logits_chunk.shape[0]
end_idx = start_idx + current_chunk_size
output_logits[start_idx:end_idx].copy_(
Review comment (Collaborator):

QQ - why .copy_ rather than just setting the value?
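
For reference, the two spellings being compared (sorted_chunk is a stand-in for the per-chunk result; both write into the preallocated output tensor's storage):

# explicit in-place copy into the preallocated slice
output_logits[start_idx:end_idx].copy_(sorted_chunk)
# equivalent slice assignment; this also writes into output_logits in place
output_logits[start_idx:end_idx] = sorted_chunk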

end_idx = start_idx + chunk_size
cmp = x[start_idx:end_idx] > vals[start_idx:end_idx, None]
# cmp.sum(dim=1, dtype=torch.int32) is the peak memory usage.
ranks = cmp.sum(dim=1, dtype=torch.int32).add_(1)
Review comment (Collaborator):

Can you add a comment that explains the logic?

Specifically, something that says:

  • select all tokens whose logprobs are greater than the selected index's logprob, as booleans
  • summing over the booleans gives the count of Trues
  • add 1 (for example, as annotated below)
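
An annotated version of the lines from the diff above (comments paraphrase the bullets; exact variable semantics come from the surrounding code):

# vals[i] is the value (logprob) of the selected token for row i.
# The comparison builds a boolean [chunk, vocab] mask that selects every
# entry strictly greater than the selected token's value.
cmp = x[start_idx:end_idx] > vals[start_idx:end_idx, None]
# Summing the booleans counts how many tokens outrank the selected one;
# adding 1 converts that count into a 1-based rank.
ranks = cmp.sum(dim=1, dtype=torch.int32).add_(1)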

@@ -0,0 +1,20 @@
from vllm import LLM, SamplingParams

Review comment (Collaborator):

  • Please move this to the logprobs test directory rather than entrypoints
  • Do we have a test that this is giving the right answer?

@robertgshaw2-redhat (Collaborator) commented:

Thanks! Left some minor comments.

  • Can you run this through an lm-eval-harness test that uses prompt logprobs as a sanity check for correctness?
  • Is there any impact on speed?

result = (x > vals[:, None])
del vals
return result.sum(1).add_(1)
N, M = x.shape
Review comment (Collaborator):

Looks like M is unused, was that intentional?


mergify bot commented Dec 30, 2024

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @yangalan123.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Dec 30, 2024
@yangalan123 (Contributor, Author) commented Dec 31, 2024

Thanks! Left some minor comments.

  • Can you run this through an lm-eval-harness test that uses prompt logprobs as a sanity check for correctness?
  • Is there any impact on speed?

Thanks for the review and comments (and Happy New Year!). As the initial author of this PR, I can provide some initial insights on the correctness of this chunked sorting approach:

  1. Correctness: I think prompt_logprobs is not a direct test for this PR, even though the PR is motivated by the OOM issues resulting from computing prompt_logprobs. A better and more straightforward correctness test is to directly compare naive torch.sort against our chunked sorting. I ran the following simple benchmarking code to compare the two (partial credit goes to Claude, as I am kind of lazy during the holiday season :-) ):
# benchmark_runner.py
import torch
import time
import json
import argparse
from pathlib import Path
from tqdm import tqdm

def naive_sort(tensor: torch.Tensor) -> torch.Tensor:
    return torch.sort(tensor, dim=-1)[0]

def chunked_sort(tensor: torch.Tensor, chunk_size: int) -> torch.Tensor:
    chunks = torch.split(tensor, chunk_size, dim=0)
    sorted_chunks = [torch.sort(chunk, dim=-1)[0] for chunk in chunks]
    return torch.cat(sorted_chunks, dim=0)

def run_benchmark(args):
    results = {
        'time': [],
        'accuracy': [],
        'type': 'naive' if args.chunk_size is None else f'chunked_{args.chunk_size}'
    }

    for _ in tqdm(range(args.num_rounds)):
        tensor = torch.randn(args.num_tokens, args.vocab_size, device='cuda')
        torch.cuda.synchronize()  # make sure setup is finished before timing
        start_time = time.perf_counter()
        if args.chunk_size is None:
            result = naive_sort(tensor)
        else:
            result = chunked_sort(tensor, args.chunk_size)
        torch.cuda.synchronize()  # wait for the (async) sort kernels to finish
        results['time'].append(time.perf_counter() - start_time)

        # Compare against the naive sort outside the timed region so the
        # accuracy check does not inflate the chunked-sort timing.
        if args.chunk_size is not None and args.check_accuracy:
            naive_result = naive_sort(tensor)
            accuracy = float(torch.allclose(naive_result, result, rtol=1e-5))
            results['accuracy'].append(accuracy)

        del tensor, result
        torch.cuda.empty_cache()
        torch.cuda.synchronize()

    Path(args.output_dir).mkdir(exist_ok=True)
    output_file = Path(args.output_dir) / f"results_{results['type']}.json"
    with open(output_file, 'w') as f:
        json.dump(results, f)

def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--num_tokens', type=int, default=8192)
    parser.add_argument('--vocab_size', type=int, default=200000)
    parser.add_argument('--chunk_size', type=int, default=None)
    parser.add_argument('--num_rounds', type=int, default=5)
    parser.add_argument('--check_accuracy', action='store_true')
    parser.add_argument('--output_dir', type=str, default='benchmark_results')
    args = parser.parse_args()

    if not torch.cuda.is_available():
        raise RuntimeError("CUDA not available")

    run_benchmark(args)

if __name__ == "__main__":
    main()

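For reference, a typical pair of invocations (assuming the script is saved as benchmark_runner.py, matching its header comment; all flags come from the argparse definitions above):

python benchmark_runner.py --num_tokens 8192 --vocab_size 128256 --num_rounds 5
python benchmark_runner.py --num_tokens 8192 --vocab_size 128256 --num_rounds 5 --chunk_size 32 --check_accuracy
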
I ran with num_tokens=8192 and vocab_size=128256 on an A40 GPU to simulate running Llama-3 models. The results show that no matter what chunk_size we choose, after 5 rounds the sorted outputs match the naive baseline exactly (100% accuracy), which verifies the correctness of the implementation. This is expected, because the chunking and merging happen along the token dimension, not the vocabulary dimension.

  2. Efficiency: We definitely incur some overhead here, because we need to do the chunking first, and we may also lose some concurrency: sorting smaller chunks of logits may not fully saturate the GPU. Nevertheless, reusing the code above and discarding extreme outliers (likely due to shared cluster usage), I find that setting chunk_size to 32 only adds around 0.2 seconds (1.21 s vs. 1.02 s, about 18% slower). Note that this is already a stress test using the full input window (num_tokens=8192); in real applications, with smaller logit tensors to sort, the overhead should be even more negligible. For users hitting OOM (e.g., when computing prompt_logprobs), running a bit slower is much better than getting a CUDA OOM error and having the whole run fail.

    Please note that although we could also profile memory here, both the running-time and memory numbers reported by this simple simulation should be taken with a grain of salt, since the distributed environment and setup of the vLLM runtime are not reflected here. I include this simulation only for quick reference.

As for the other code comments: @youkaichao is already polishing my initial commits, so perhaps he can provide more thoughts and details.

Happy New Year again to everyone, and thanks to all for reviewing and polishing this PR!


github-actions bot commented Apr 1, 2025

This pull request has been automatically marked as stale because it has not had any activity within 90 days. It will be automatically closed if no further activity occurs within 30 days. Leave a comment if you feel this pull request should remain open. Thank you!

@github-actions github-actions bot added the stale Over 90 days of inactivity label Apr 1, 2025
Labels: ci/build, needs-rebase, stale (Over 90 days of inactivity)

Successfully merging this pull request may close these issues:

[Bug]: TRACKING ISSUE: CUDA OOM with Logprobs

4 participants