diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml
index b563c96343f..d34adde8352 100644
--- a/.buildkite/test-pipeline.yaml
+++ b/.buildkite/test-pipeline.yaml
@@ -107,7 +107,8 @@ steps:
   - vllm/
   commands:
   - pip install -e ./plugins/vllm_add_dummy_model
-  - pytest -v -s entrypoints/llm --ignore=entrypoints/llm/test_lazy_outlines.py --ignore=entrypoints/llm/test_generate.py --ignore=entrypoints/llm/test_generate_multiple_loras.py --ignore=entrypoints/llm/test_guided_generate.py
+  - pytest -v -s entrypoints/llm --ignore=entrypoints/llm/test_lazy_outlines.py --ignore=entrypoints/llm/test_generate.py --ignore=entrypoints/llm/test_generate_multiple_loras.py --ignore=entrypoints/llm/test_guided_generate.py --ignore=entrypoints/llm/test_prompt_logprobs.py
+  - pytest -v -s entrypoints/llm/test_prompt_logprobs.py # it needs a clean process
   - pytest -v -s entrypoints/llm/test_lazy_outlines.py # it needs a clean process
   - pytest -v -s entrypoints/llm/test_generate.py # it needs a clean process
   - pytest -v -s entrypoints/llm/test_generate_multiple_loras.py # it needs a clean process
diff --git a/tests/entrypoints/llm/test_prompt_logprobs.py b/tests/entrypoints/llm/test_prompt_logprobs.py
new file mode 100644
index 00000000000..964ffcccabc
--- /dev/null
+++ b/tests/entrypoints/llm/test_prompt_logprobs.py
@@ -0,0 +1,20 @@
+from vllm import LLM, SamplingParams
+
+
+def test_prompt_logprobs():
+    llm = LLM(model="meta-llama/Meta-Llama-3-8B")
+
+    # stress the system by asking for prompt logprobs with a long prompt
+    sampling_params = SamplingParams(top_p=0.9,
+                                     top_k=50,
+                                     temperature=0.8,
+                                     prompt_logprobs=10,
+                                     max_tokens=1)
+    # Right now we use chunked sort and chunked logprobs to reduce the
+    # peak memory, but they cannot guarantee that the runtime peak memory
+    # stays within the profiling peak memory.
+    # To fully solve this issue (i.e. to test prompt logprobs with 8192
+    # tokens), we need to make sure the whole sampling process is chunked.
+    token_ids = list(range(1024))
+    # make sure sorting does not cause OOM
+    llm.generate(prompt_token_ids=token_ids, sampling_params=sampling_params)
diff --git a/vllm/model_executor/layers/sampler.py b/vllm/model_executor/layers/sampler.py
index c2d12c466ba..83fdedc2521 100644
--- a/vllm/model_executor/layers/sampler.py
+++ b/vllm/model_executor/layers/sampler.py
@@ -186,6 +186,12 @@ def __init__(self):
         # speculative decoding.
         self.include_gpu_probs_tensor = False
         self.should_modify_greedy_probs_inplace = False
+        from vllm.config import get_current_vllm_config
+        vllm_config = get_current_vllm_config()
+        # We sample at most max_num_seqs sequences during profiling.
+        # Remember that value here so that we can limit the number of
+        # sequences sampled at runtime and control the peak memory usage.
+        self.max_sampling_tokens = vllm_config.scheduler_config.max_num_seqs
 
     def _init_sampling_tensors(
         self,
@@ -271,8 +277,32 @@ def forward(
             logits.div_(sampling_tensors.temperatures.unsqueeze(dim=1))
 
         if do_top_p_top_k and flashinfer_top_k_top_p_sampling is None:
-            logits = _apply_top_k_top_p(logits, sampling_tensors.top_ps,
-                                        sampling_tensors.top_ks)
+            if logits.shape[0] <= self.max_sampling_tokens:
+                # In the most common case, each sequence has only one token
+                # to sample, so the total number of tokens to sample
+                # does not exceed `self.max_sampling_tokens`.
+                logits = _apply_top_k_top_p(logits, sampling_tensors.top_ps,
+                                            sampling_tensors.top_ks)
+            else:
+                # When prompt_logprobs are required, the number of tokens
+                # to sample can be larger than `self.max_sampling_tokens`,
+                # and the memory footprint can be very large. We split the
+                # operation into chunks to control the peak memory usage during
+                # runtime.
+                chunk_size = self.max_sampling_tokens
+                output_logits = torch.empty_like(logits)
+                logits_chunks = torch.split(logits, chunk_size, dim=0)
+                start_idx = 0
+                for logits_chunk in logits_chunks:
+                    current_chunk_size = logits_chunk.shape[0]
+                    end_idx = start_idx + current_chunk_size
+                    output_logits[start_idx:end_idx].copy_(
+                        _apply_top_k_top_p(
+                            logits_chunk,
+                            sampling_tensors.top_ps[start_idx:end_idx],
+                            sampling_tensors.top_ks[start_idx:end_idx]))
+                    start_idx = end_idx
+                logits = output_logits
 
         if do_min_p:
             logits = _apply_min_p(logits, sampling_tensors.min_ps)
@@ -312,7 +342,8 @@ def forward(
             assert not isinstance(maybe_deferred_sample_results,
                                   SampleResultArgsType)
             prompt_logprobs, sample_logprobs = get_logprobs(
-                logprobs, sampling_metadata, maybe_deferred_sample_results)
+                logprobs, sampling_metadata, maybe_deferred_sample_results,
+                self.max_sampling_tokens)
 
         return _build_sampler_output(
             maybe_deferred_sample_results,
@@ -858,7 +889,8 @@ def _sample(
     )
 
 
-def _get_ranks(x: torch.Tensor, indices: torch.Tensor) -> torch.Tensor:
+def _get_ranks(x: torch.Tensor, indices: torch.Tensor,
+               max_sampling_tokens: int) -> torch.Tensor:
     """
     This function calculates the ranks of the chosen tokens in a logprob
     tensor.
@@ -866,23 +898,36 @@ def _get_ranks(x: torch.Tensor, indices: torch.Tensor) -> torch.Tensor:
         x (torch.Tensor): 2D logprob tensor of shape (N, M)
                         where N is the no. of tokens and M is the vocab dim.
         indices (torch.Tensor): List of chosen token indices.
+        max_sampling_tokens (int): The maximum number of tokens to calculate
+            in one kernel launch to control the peak memory usage.
 
     Returns:
         torch.Tensor: 1D tensor of shape (N,) where N is the no. of tokens.
                     Each element in the returned tensor represents the rank
                     of the chosen token in the input logprob tensor.
""" - vals = x[torch.arange(0, len(x), device=x.device, dtype=indices.dtype), - indices] - result = (x > vals[:, None]) - del vals - return result.sum(1).add_(1) + N, M = x.shape + vals = x[torch.arange(0, N, device=x.device, dtype=indices.dtype), indices] + final_result = torch.empty(N, device=x.device, dtype=indices.dtype) + result_chunks = torch.chunk(final_result, max_sampling_tokens, dim=0) + start_idx = 0 + for chunk in result_chunks: + chunk_size = chunk.size(0) + end_idx = start_idx + chunk_size + cmp = x[start_idx:end_idx] > vals[start_idx:end_idx, None] + # cmp.sum(dim=1, dtype=torch.int32) is the peak memory usage. + ranks = cmp.sum(dim=1, dtype=torch.int32).add_(1) + chunk.copy_(ranks) + del cmp, ranks + start_idx = end_idx + return final_result def get_logprobs( logprobs: torch.Tensor, sampling_metadata: SamplingMetadata, sample_results: SampleResultType, + max_sampling_tokens: int, ) -> Tuple[List[Optional[PromptLogprobs]], List[SampleLogprobs]]: """Return sample logprobs and prompt logprobs. @@ -977,6 +1022,7 @@ def get_logprobs( ranks = _get_ranks( logprobs[query_indices_gpu], next_token_ids_gpu, + max_sampling_tokens, ) assert selected_logprobs.shape[0] == ranks.shape[0]