Bounded peak memory in Top-P-Top-K with chunked sorting #11544
base: main
Changes from all commits
cb07a82
3bd4e7d
9ec13ab
88c1a70
9f7ff0b
3448789
69347e5
@@ -0,0 +1,20 @@
+from vllm import LLM, SamplingParams
+
+
+def test_prompt_logprobs():
+    llm = LLM(model="meta-llama/Meta-Llama-3-8B")
+
+    # stress the system by asking for prompt logprobs with a long prompt
+    sampling_params = SamplingParams(top_p=0.9,
+                                     top_k=50,
+                                     temperature=0.8,
+                                     prompt_logprobs=10,
+                                     max_tokens=1)
+    # Right now we use chunked sort and chunked logprobs to reduce peak
+    # memory. This lowers the peak, but it cannot guarantee that runtime
+    # peak memory <= profiling peak memory. To fully solve this issue
+    # (i.e. so that we can use 8192 tokens to test prompt logprobs), the
+    # whole sampling process needs to be chunked.
+    token_ids = list(range(1024))
+    # make sure sorting does not cause OOM
+    llm.generate(prompt_token_ids=token_ids, sampling_params=sampling_params)
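To see why this test is memory-hungry without chunking, a rough estimate helps: top-k/top-p masking and rank computation materialize intermediates over the full (num_prompt_tokens, vocab_size) logits tensor. The numbers below are only a back-of-the-envelope sketch; the 128256 vocabulary size for Meta-Llama-3-8B, float32 logits, and the 256-row chunk size are assumptions, not values taken from this PR.

# Back-of-the-envelope peak-memory estimate (assumed numbers, not measured).
num_prompt_tokens = 1024       # tokens needing prompt logprobs in the test
vocab_size = 128256            # assumed vocab size of Meta-Llama-3-8B
bytes_per_elem = 4             # float32 logits

full = num_prompt_tokens * vocab_size * bytes_per_elem
print(f"one full-size intermediate: {full / 1024**2:.0f} MiB")   # ~501 MiB

# Sorting and masking need several such intermediates at once (sorted values,
# indices, cumulative probabilities, boolean masks), so the transient
# footprint is a small multiple of this. Chunking the token dimension caps
# each intermediate at chunk_rows / num_prompt_tokens of the full size.
chunk_rows = 256               # hypothetical max_num_seqs used as chunk size
chunk = chunk_rows * vocab_size * bytes_per_elem
print(f"per-chunk intermediate:   {chunk / 1024**2:.0f} MiB")    # ~125 MiB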
@@ -186,6 +186,12 @@ def __init__(self):
         # speculative decoding.
         self.include_gpu_probs_tensor = False
         self.should_modify_greedy_probs_inplace = False
+        from vllm.config import get_current_vllm_config
+        vllm_config = get_current_vllm_config()
+        # we sample at most max_num_seqs sequences during profiling,
+        # here we remember the value so that we can limit the number of
+        # sequences for sampling during runtime to control peak memory usage
+        self.max_sampling_tokens = vllm_config.scheduler_config.max_num_seqs
 
     def _init_sampling_tensors(
         self,

Review comment on the get_current_vllm_config() import: I think that we should pipe this parameter down through the constructor rather than using a global config.
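For reference, a minimal sketch of what the reviewer's suggestion could look like, assuming the constructor signature and call site shown here (they are illustrative, not part of this PR or of vLLM's current API):

import torch.nn as nn

# Hypothetical sketch (not this PR's code): pipe the limit through the
# constructor instead of reading a global config inside __init__.
class Sampler(nn.Module):

    def __init__(self, max_sampling_tokens: int):
        super().__init__()
        # we sample at most max_num_seqs sequences during profiling; the
        # caller passes that value in so runtime sampling can be chunked.
        self.max_sampling_tokens = max_sampling_tokens

# The model runner already holds the config, so it could construct the
# sampler as (illustrative call site):
# sampler = Sampler(
#     max_sampling_tokens=vllm_config.scheduler_config.max_num_seqs)

Passing the value explicitly keeps the sampler free of hidden global state and makes the dependency visible where the sampler is built.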
@@ -271,8 +277,32 @@ def forward(
         logits.div_(sampling_tensors.temperatures.unsqueeze(dim=1))
 
         if do_top_p_top_k and flashinfer_top_k_top_p_sampling is None:
-            logits = _apply_top_k_top_p(logits, sampling_tensors.top_ps,
-                                        sampling_tensors.top_ks)
+            if logits.shape[0] <= self.max_sampling_tokens:
+                # for the most common case, every sequence only has one token
+                # for sampling, and the total number of tokens to sample
+                # is less than `self.max_sampling_tokens`
+                logits = _apply_top_k_top_p(logits, sampling_tensors.top_ps,
+                                            sampling_tensors.top_ks)
+            else:
+                # when prompt_logprobs are required, the number of tokens
+                # to sample can be larger than `self.max_sampling_tokens`,
+                # and the memory footprint can be very large. We split the
+                # operation into chunks to control the peak memory usage during
+                # runtime.
+                chunk_size = self.max_sampling_tokens
+                output_logits = torch.empty_like(logits)
+                logits_chunks = torch.split(logits, chunk_size, dim=0)
+                start_idx = 0
+                for logits_chunk in logits_chunks:
+                    current_chunk_size = logits_chunk.shape[0]
+                    end_idx = start_idx + current_chunk_size
+                    output_logits[start_idx:end_idx].copy_(
+                        _apply_top_k_top_p(
+                            logits_chunk,
+                            sampling_tensors.top_ps[start_idx:end_idx],
+                            sampling_tensors.top_ks[start_idx:end_idx]))
+                    start_idx = end_idx
+                logits = output_logits
 
         if do_min_p:
             logits = _apply_min_p(logits, sampling_tensors.min_ps)

Review comment on the output_logits[start_idx:end_idx].copy_(...) line: QQ - why …
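To illustrate the chunking pattern above in isolation: split the logits along the token dimension, process one chunk at a time, and write each result into a preallocated output with copy_, so only one chunk-sized set of intermediates is alive at any moment. The sketch below uses a simplified top-k mask as a stand-in for vLLM's _apply_top_k_top_p (the real function also applies top-p); shapes and values are arbitrary.

import torch


def mask_top_k(logits: torch.Tensor, top_ks: torch.Tensor) -> torch.Tensor:
    """Simplified stand-in: keep the top-k logits per row, mask the rest."""
    sorted_logits, _ = logits.sort(dim=-1, descending=True)
    # per-row threshold: the k-th largest logit (k can differ per row)
    thresholds = sorted_logits.gather(-1, (top_ks - 1).unsqueeze(-1))
    return logits.masked_fill(logits < thresholds, float("-inf"))


def mask_top_k_chunked(logits: torch.Tensor, top_ks: torch.Tensor,
                       chunk_size: int) -> torch.Tensor:
    out = torch.empty_like(logits)
    start = 0
    for chunk in torch.split(logits, chunk_size, dim=0):
        end = start + chunk.shape[0]
        # only this chunk's sort buffers and masks are alive at any moment
        out[start:end].copy_(mask_top_k(chunk, top_ks[start:end]))
        start = end
    return out


logits = torch.randn(1000, 512)
top_ks = torch.randint(1, 50, (1000,))
assert torch.equal(mask_top_k(logits, top_ks),
                   mask_top_k_chunked(logits, top_ks, chunk_size=256))

torch.split takes a chunk size, so the last chunk may be smaller; reading chunk.shape[0] handles that, exactly as the PR does with current_chunk_size.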
@@ -312,7 +342,8 @@ def forward(
         assert not isinstance(maybe_deferred_sample_results,
                               SampleResultArgsType)
         prompt_logprobs, sample_logprobs = get_logprobs(
-            logprobs, sampling_metadata, maybe_deferred_sample_results)
+            logprobs, sampling_metadata, maybe_deferred_sample_results,
+            self.max_sampling_tokens)
 
         return _build_sampler_output(
             maybe_deferred_sample_results,
@@ -858,31 +889,45 @@ def _sample(
     )
 
 
-def _get_ranks(x: torch.Tensor, indices: torch.Tensor) -> torch.Tensor:
+def _get_ranks(x: torch.Tensor, indices: torch.Tensor,
+               max_sampling_tokens: int) -> torch.Tensor:
     """
     This function calculates the ranks of the chosen tokens in a logprob tensor.
 
     Args:
         x (torch.Tensor): 2D logprob tensor of shape (N, M)
                         where N is the no. of tokens and M is the vocab dim.
         indices (torch.Tensor): List of chosen token indices.
+        max_sampling_tokens (int): The maximum number of tokens to calculate
+            in one kernel launch to control the peak memory usage.
 
     Returns:
         torch.Tensor: 1D tensor of shape (N,) where N is the no. of tokens.
                       Each element in the returned tensor represents the rank
                       of the chosen token in the input logprob tensor.
     """
-    vals = x[torch.arange(0, len(x), device=x.device, dtype=indices.dtype),
-             indices]
-    result = (x > vals[:, None])
-    del vals
-    return result.sum(1).add_(1)
+    N, M = x.shape
+    vals = x[torch.arange(0, N, device=x.device, dtype=indices.dtype), indices]
+    final_result = torch.empty(N, device=x.device, dtype=indices.dtype)
+    result_chunks = torch.chunk(final_result, max_sampling_tokens, dim=0)
+    start_idx = 0
+    for chunk in result_chunks:
+        chunk_size = chunk.size(0)
+        end_idx = start_idx + chunk_size
+        cmp = x[start_idx:end_idx] > vals[start_idx:end_idx, None]
+        # cmp.sum(dim=1, dtype=torch.int32) is the peak memory usage.
+        ranks = cmp.sum(dim=1, dtype=torch.int32).add_(1)
+        chunk.copy_(ranks)
+        del cmp, ranks
+        start_idx = end_idx
+    return final_result
 
 
 def get_logprobs(
     logprobs: torch.Tensor,
     sampling_metadata: SamplingMetadata,
     sample_results: SampleResultType,
+    max_sampling_tokens: int,
 ) -> Tuple[List[Optional[PromptLogprobs]], List[SampleLogprobs]]:
     """Return sample logprobs and prompt logprobs.
 

Review comment on N, M = x.shape: Looks like …

Review comment on the chunked rank computation: Can you add a comment that explains the logic? Specifically, something that says: …
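In the spirit of the comment requested above, a standalone sketch of the chunked rank logic (not this PR's code): the rank of a chosen token is the number of vocabulary entries with a strictly greater logprob plus one, and the loop computes it one slice of rows at a time so that only one (chunk_rows, vocab) boolean mask exists at once. Note that this sketch chunks by row count, whereas the PR's _get_ranks passes max_sampling_tokens to torch.chunk, which interprets it as a number of chunks rather than a chunk size.

import torch


def get_ranks_chunked(x: torch.Tensor, indices: torch.Tensor,
                      chunk_rows: int) -> torch.Tensor:
    """Rank of each chosen token = #(entries with strictly larger logprob) + 1."""
    n = x.shape[0]
    # logprob of the chosen token in each row
    vals = x[torch.arange(n, device=x.device), indices]
    out = torch.empty(n, dtype=torch.long, device=x.device)
    for start in range(0, n, chunk_rows):
        end = min(start + chunk_rows, n)
        # boolean mask of shape (end - start, vocab): the only large temporary
        greater = x[start:end] > vals[start:end, None]
        out[start:end] = greater.sum(dim=1) + 1
        del greater
    return out


# sanity check against the unchunked formulation
x = torch.randn(1000, 512)
indices = torch.randint(0, 512, (1000,))
reference = (x > x[torch.arange(1000), indices][:, None]).sum(dim=1) + 1
assert torch.equal(get_ranks_chunked(x, indices, chunk_rows=128), reference)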
@@ -977,6 +1022,7 @@ def get_logprobs(
         ranks = _get_ranks(
             logprobs[query_indices_gpu],
             next_token_ids_gpu,
+            max_sampling_tokens,
         )
         assert selected_logprobs.shape[0] == ranks.shape[0]
 