Bounded peak memory in Top-P-Top-K with chunked sorting #11544

Open · wants to merge 7 commits into main
3 changes: 2 additions & 1 deletion .buildkite/test-pipeline.yaml
@@ -107,7 +107,8 @@ steps:
- vllm/
commands:
- pip install -e ./plugins/vllm_add_dummy_model
- pytest -v -s entrypoints/llm --ignore=entrypoints/llm/test_lazy_outlines.py --ignore=entrypoints/llm/test_generate.py --ignore=entrypoints/llm/test_generate_multiple_loras.py --ignore=entrypoints/llm/test_guided_generate.py
- pytest -v -s entrypoints/llm --ignore=entrypoints/llm/test_lazy_outlines.py --ignore=entrypoints/llm/test_generate.py --ignore=entrypoints/llm/test_generate_multiple_loras.py --ignore=entrypoints/llm/test_guided_generate.py --ignore=entrypoints/llm/test_prompt_logprobs.py
- pytest -v -s entrypoints/llm/test_prompt_logprobs.py # it needs a clean process
- pytest -v -s entrypoints/llm/test_lazy_outlines.py # it needs a clean process
- pytest -v -s entrypoints/llm/test_generate.py # it needs a clean process
- pytest -v -s entrypoints/llm/test_generate_multiple_loras.py # it needs a clean process
20 changes: 20 additions & 0 deletions tests/entrypoints/llm/test_prompt_logprobs.py
@@ -0,0 +1,20 @@
from vllm import LLM, SamplingParams

Collaborator:
  • Please move this to the logprobs test directory rather than entrypoints
  • Do we have a test that this is giving the right answer?


def test_prompt_logprobs():
llm = LLM(model="meta-llama/Meta-Llama-3-8B")

# stress the system by asking for prompt logprobs with a long prompt
sampling_params = SamplingParams(top_p=0.9,
top_k=50,
temperature=0.8,
prompt_logprobs=10,
max_tokens=1)
# right now we use chunked sort and chunked logprobs to reduce the
# peak memory. This lowers the peak, but it cannot guarantee that
# runtime peak memory <= profiling peak memory.
# To fully solve this issue (i.e. so that we could use 8192 tokens to
# test prompt logprobs), the whole sampling process needs to be chunked.
token_ids = list(range(1024))
# make sure sorting does not cause OOM
llm.generate(prompt_token_ids=token_ids, sampling_params=sampling_params)
64 changes: 55 additions & 9 deletions vllm/model_executor/layers/sampler.py
@@ -186,6 +186,12 @@ def __init__(self):
# speculative decoding.
self.include_gpu_probs_tensor = False
self.should_modify_greedy_probs_inplace = False
from vllm.config import get_current_vllm_config
Collaborator @robertgshaw2-redhat (Dec 29, 2024):
I think that we should pipe this parameter down through the constructor rather than using a global config.

vllm_config = get_current_vllm_config()
# we sample at most max_num_seqs sequences during profiling;
# remember that value here so we can limit the number of sequences
# sampled at runtime and keep peak memory usage under control
self.max_sampling_tokens = vllm_config.scheduler_config.max_num_seqs

def _init_sampling_tensors(
self,
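A minimal sketch of the constructor-injection alternative suggested in the review comment above (hypothetical; the class outline and call site are illustrations, not part of this diff):

# Hypothetical sketch: pass the bound through the constructor instead of
# reading the global config inside __init__.
class Sampler:

    def __init__(self, max_sampling_tokens: int):
        # Upper bound on the number of tokens sampled at once, used to
        # keep runtime peak memory within the profiled peak.
        self.max_sampling_tokens = max_sampling_tokens

# The code that constructs the sampler already holds the vLLM config,
# so it could pass the value down explicitly, e.g.:
#   sampler = Sampler(vllm_config.scheduler_config.max_num_seqs)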
@@ -271,8 +277,32 @@ def forward(
logits.div_(sampling_tensors.temperatures.unsqueeze(dim=1))

if do_top_p_top_k and flashinfer_top_k_top_p_sampling is None:
logits = _apply_top_k_top_p(logits, sampling_tensors.top_ps,
sampling_tensors.top_ks)
if logits.shape[0] <= self.max_sampling_tokens:
# for the most common case, every sequence has only one token to
# sample, so the total number of tokens to sample does not exceed
# `self.max_sampling_tokens`
logits = _apply_top_k_top_p(logits, sampling_tensors.top_ps,
sampling_tensors.top_ks)
else:
# when prompt_logprobs are required, the number of tokens
# to sample can be larger than `self.max_sampling_tokens`,
# and the memory footprint can be very large. We split the
# operation into chunks to control the peak memory usage during
# runtime.
chunk_size = self.max_sampling_tokens
output_logits = torch.empty_like(logits)
logits_chunks = torch.split(logits, chunk_size, dim=0)
start_idx = 0
for logits_chunk in logits_chunks:
current_chunk_size = logits_chunk.shape[0]
end_idx = start_idx + current_chunk_size
output_logits[start_idx:end_idx].copy_(
Collaborator:
QQ - why .copy_ rather than just setting the value?

_apply_top_k_top_p(
logits_chunk,
sampling_tensors.top_ps[start_idx:end_idx],
sampling_tensors.top_ks[start_idx:end_idx]))
start_idx = end_idx
logits = output_logits

if do_min_p:
logits = _apply_min_p(logits, sampling_tensors.min_ps)
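To illustrate the chunking pattern above outside of vLLM, here is a minimal standalone sketch (the median-based masking is only a stand-in for _apply_top_k_top_p, and apply_in_chunks is an invented helper name). It also shows, regarding the .copy_ question in the review comment above, that plain slice assignment writes into the same preallocated buffer:

import torch


def apply_in_chunks(logits: torch.Tensor, chunk_size: int) -> torch.Tensor:
    # Preallocate the output so only one chunk-sized intermediate
    # (inside the masking op) is alive at a time.
    out = torch.empty_like(logits)
    start = 0
    for chunk in torch.split(logits, chunk_size, dim=0):
        end = start + chunk.shape[0]
        # Stand-in for _apply_top_k_top_p: mask everything below the
        # per-row median to -inf.
        masked = chunk.masked_fill(
            chunk < chunk.median(dim=-1, keepdim=True).values, float("-inf"))
        # Equivalent to out[start:end].copy_(masked): slice assignment
        # also copies into the preallocated buffer.
        out[start:end] = masked
        start = end
    return out


logits = torch.randn(8, 16)
print(apply_in_chunks(logits, chunk_size=3).shape)  # torch.Size([8, 16])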
@@ -312,7 +342,8 @@ def forward(
assert not isinstance(maybe_deferred_sample_results,
SampleResultArgsType)
prompt_logprobs, sample_logprobs = get_logprobs(
logprobs, sampling_metadata, maybe_deferred_sample_results)
logprobs, sampling_metadata, maybe_deferred_sample_results,
self.max_sampling_tokens)

return _build_sampler_output(
maybe_deferred_sample_results,
@@ -858,31 +889,45 @@ def _sample(
)


def _get_ranks(x: torch.Tensor, indices: torch.Tensor) -> torch.Tensor:
def _get_ranks(x: torch.Tensor, indices: torch.Tensor,
max_sampling_tokens: int) -> torch.Tensor:
"""
This function calculates the ranks of the chosen tokens in a logprob tensor.

Args:
x (torch.Tensor): 2D logprob tensor of shape (N, M)
where N is the no. of tokens and M is the vocab dim.
indices (torch.Tensor): List of chosen token indices.
max_sampling_tokens (int): The maximum number of tokens to calculate
in one kernel launch to control the peak memory usage.

Returns:
torch.Tensor: 1D tensor of shape (N,) where N is the no. of tokens.
Each element in the returned tensor represents the rank
of the chosen token in the input logprob tensor.
"""
vals = x[torch.arange(0, len(x), device=x.device, dtype=indices.dtype),
indices]
result = (x > vals[:, None])
del vals
return result.sum(1).add_(1)
N, M = x.shape
Collaborator:
Looks like M is unused; was that intentional?

vals = x[torch.arange(0, N, device=x.device, dtype=indices.dtype), indices]
final_result = torch.empty(N, device=x.device, dtype=indices.dtype)
result_chunks = torch.chunk(final_result, max_sampling_tokens, dim=0)
start_idx = 0
for chunk in result_chunks:
chunk_size = chunk.size(0)
end_idx = start_idx + chunk_size
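# Booleans mark every vocab entry whose logprob is strictly greater
# than the chosen token's logprob.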
cmp = x[start_idx:end_idx] > vals[start_idx:end_idx, None]
# Summing the booleans counts the entries that outrank the chosen
# token; add_(1) turns that count into a 1-based rank. This
# cmp.sum(dim=1, dtype=torch.int32) reduction is where the peak
# memory usage occurs.
ranks = cmp.sum(dim=1, dtype=torch.int32).add_(1)
Collaborator:
Can you add a comment that explains the logic?

Specifically, something that says:

  • select all tokens whose logprobs are greater than the selected indices' logprobs, using booleans
  • summing over the booleans gives the count of Trues
  • add 1

chunk.copy_(ranks)
del cmp, ranks
start_idx = end_idx
return final_result
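To make the rank computation concrete, a tiny self-contained example of the unchunked logic (independent of vLLM; the numbers are made up):

import torch

# Two rows of logprobs over a 4-token vocab, and the chosen token ids.
x = torch.tensor([[0.1, 0.7, 0.2, 0.0],
                  [0.9, 0.3, 0.5, 0.4]])
indices = torch.tensor([2, 2])

# Logprob of the chosen token in each row: tensor([0.2, 0.5]).
vals = x[torch.arange(x.shape[0]), indices]
# Booleans mark the entries that beat the chosen token; summing them
# counts those entries, and adding 1 turns the count into a 1-based rank.
ranks = (x > vals[:, None]).sum(dim=1, dtype=torch.int32) + 1
print(ranks)  # tensor([2, 2], dtype=torch.int32)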


def get_logprobs(
logprobs: torch.Tensor,
sampling_metadata: SamplingMetadata,
sample_results: SampleResultType,
max_sampling_tokens: int,
) -> Tuple[List[Optional[PromptLogprobs]], List[SampleLogprobs]]:
"""Return sample logprobs and prompt logprobs.

@@ -977,6 +1022,7 @@ def get_logprobs(
ranks = _get_ranks(
logprobs[query_indices_gpu],
next_token_ids_gpu,
max_sampling_tokens,
)
assert selected_logprobs.shape[0] == ranks.shape[0]
