From 21ec66e59e466ba8bef05478296fabfcb1f94421 Mon Sep 17 00:00:00 2001
From: Lianmin Zheng
Date: Mon, 30 Dec 2024 05:42:08 -0800
Subject: [PATCH] Minor follow-up fixes for the logprob refactor (#2670)

---
 python/sglang/srt/layers/logits_processor.py      | 6 +++---
 python/sglang/srt/layers/sampler.py               | 4 +++-
 python/sglang/srt/model_executor/model_runner.py  | 7 +++----
 python/sglang/srt/sampling/sampling_batch_info.py | 4 +---
 test/srt/test_srt_endpoint.py                     | 2 +-
 5 files changed, 11 insertions(+), 12 deletions(-)

diff --git a/python/sglang/srt/layers/logits_processor.py b/python/sglang/srt/layers/logits_processor.py
index ac3a4a4cc27..f6d43b67136 100644
--- a/python/sglang/srt/layers/logits_processor.py
+++ b/python/sglang/srt/layers/logits_processor.py
@@ -35,21 +35,21 @@

 @dataclasses.dataclass
 class LogitsProcessorOutput:
-    ## First part. This part will be returned by python/sglang/srt/layers/logits_processor.py::LogitsProcessor.
+    ## Part 1: This part will be assigned in python/sglang/srt/layers/logits_processor.py::LogitsProcessor
     # The logits of the next tokens. shape: [#seq, vocab_size]
     next_token_logits: torch.Tensor
     # Used by speculative decoding (EAGLE)
     # The last hidden layers
     hidden_states: Optional[torch.Tensor] = None

-    ## Second part. This part will be returned by python/sglang/srt/layers/sampler.py::Sampler.
+    ## Part 2: This part will be assigned in python/sglang/srt/layers/sampler.py::Sampler
     # The logprobs of the next tokens. shape: [#seq]
     next_token_logprobs: Optional[torch.Tensor] = None
     # The logprobs and ids of the top-k tokens in output positions. shape: [#seq, k]
     next_token_top_logprobs_val: Optional[List] = None
     next_token_top_logprobs_idx: Optional[List] = None

-    ## Third part. This part will be returned by python/sglang/srt/layers/logits_processor.py::LogitsProcessor. Prefill-only.
+    ## Part 3: Prefill-only. This part will be assigned in python/sglang/srt/layers/logits_processor.py::LogitsProcessor
     # The normlaized logprobs of prompts. shape: [#seq]
     normalized_prompt_logprobs: torch.Tensor = None
     # The logprobs of input tokens. shape: [#token]
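Reviewer note on the comment rewording above: `LogitsProcessor` creates the object with Part 1 (and Part 3 during prefill) populated, and `Sampler` later assigns Part 2 on the same instance. A minimal sketch of that two-stage flow, with a trimmed field set and greedy sampling purely for illustration (the real call sites and signatures differ):

```python
import dataclasses
from typing import List, Optional

import torch


@dataclasses.dataclass
class Output:  # trimmed stand-in for LogitsProcessorOutput
    # Part 1: assigned by the logits processor. shape: [#seq, vocab_size]
    next_token_logits: torch.Tensor
    # Part 2: assigned later by the sampler. shape: [#seq]
    next_token_logprobs: Optional[torch.Tensor] = None
    next_token_top_logprobs_val: Optional[List] = None
    next_token_top_logprobs_idx: Optional[List] = None


# Stage 1: the logits processor produces only the logits.
out = Output(next_token_logits=torch.randn(2, 128))

# Stage 2: the sampler fills the logprob fields on the same object
# (greedy token selection here, purely for illustration).
logprobs = torch.log_softmax(out.next_token_logits, dim=-1)
next_ids = logprobs.argmax(dim=-1)
out.next_token_logprobs = logprobs.gather(-1, next_ids.unsqueeze(-1)).squeeze(-1)
top_val, top_idx = logprobs.topk(5, dim=-1)
out.next_token_top_logprobs_val = top_val.tolist()
out.next_token_top_logprobs_idx = top_idx.tolist()
```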
diff --git a/python/sglang/srt/layers/sampler.py b/python/sglang/srt/layers/sampler.py
index bed770e39a5..23037650a31 100644
--- a/python/sglang/srt/layers/sampler.py
+++ b/python/sglang/srt/layers/sampler.py
@@ -56,7 +56,9 @@ def forward(

         if global_server_args_dict["sampling_backend"] == "flashinfer":
             if return_logprob:
-                # NOTE: the top_p_renorm_prob from flashinfer has numerical problems
+                # NOTE: the top_p_renorm_prob from flashinfer has numerical problems,
+                # https://github.com/flashinfer-ai/flashinfer/issues/708
+                # so we use the torch implementation.
                 logprobs = torch.log(
                     top_p_normalize_probs_torch(probs, sampling_info.top_ps)
                 )
diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py
index 67640947a57..786f654ded6 100644
--- a/python/sglang/srt/model_executor/model_runner.py
+++ b/python/sglang/srt/model_executor/model_runner.py
@@ -36,7 +36,7 @@
 from sglang.srt.layers.attention.torch_native_backend import TorchNativeAttnBackend
 from sglang.srt.layers.attention.triton_backend import TritonAttnBackend
 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
-from sglang.srt.layers.sampler import Sampler, get_top_logprobs
+from sglang.srt.layers.sampler import Sampler
 from sglang.srt.layers.torchao_utils import apply_torchao_config_to_model
 from sglang.srt.lora.lora_manager import LoRAManager
 from sglang.srt.managers.schedule_batch import global_server_args_dict
@@ -191,10 +191,9 @@ def init_torch_distributed(self):
         torch.get_device_module(self.device).set_device(self.gpu_id)
         if self.device == "cuda":
             backend = "nccl"
-
-            # TODO(liangan1):Just use gloo to bypass the initilization fail
-            # Need to use xccl for xpu backend in the future
         elif self.device == "xpu":
+            # TODO(liangan1): Just use gloo to bypass the initialization failure.
+            # Need to use xccl for the xpu backend in the future.
             backend = "gloo"
         elif self.device == "hpu":
             backend = "hccl"
diff --git a/python/sglang/srt/sampling/sampling_batch_info.py b/python/sglang/srt/sampling/sampling_batch_info.py
index 5d4aaa41bdb..3a46b220900 100644
--- a/python/sglang/srt/sampling/sampling_batch_info.py
+++ b/python/sglang/srt/sampling/sampling_batch_info.py
@@ -244,7 +244,7 @@ def apply_logits_bias(self, logits: torch.Tensor):

         # repetition
         if self.scaling_penalties is not None:
-            logits = torch.where(
+            logits[:] = torch.where(
                 logits > 0,
                 logits / self.scaling_penalties,
                 logits * self.scaling_penalties,
@@ -253,5 +253,3 @@ def apply_logits_bias(self, logits: torch.Tensor):
         # Apply regex vocab_mask
         if self.vocab_mask is not None:
             self.apply_mask(logits=logits, vocab_mask=self.vocab_mask)
-
-        return logits
diff --git a/test/srt/test_srt_endpoint.py b/test/srt/test_srt_endpoint.py
index e974821b113..0fd71efcb0b 100644
--- a/test/srt/test_srt_endpoint.py
+++ b/test/srt/test_srt_endpoint.py
@@ -227,7 +227,7 @@ def test_logprob_grammar(self):
                 "regex": "( Yes| No)",
             },
             "return_logprob": True,
-            "top_logprobs_num": 5,
+            "top_logprobs_num": 5,  # The grammar constraint allows all prefix tokens, so we need to use a larger top_k.
             "return_text_in_logprobs": True,
         },
     )
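Reviewer note on the sampler.py hunk: the comment now points at flashinfer issue #708 and keeps the pure-torch renormalization. For readers unfamiliar with `top_p_normalize_probs_torch`, the standard top-p renormalization it refers to looks roughly like the sketch below; this illustrates the algorithm and is not the exact sglang implementation.

```python
import torch


def top_p_normalize_probs_sketch(
    probs: torch.Tensor, top_ps: torch.Tensor
) -> torch.Tensor:
    """Illustrative top-p renormalization: probs is [#seq, vocab_size],
    top_ps is [#seq]. Keeps the smallest prefix of tokens whose mass
    reaches top_p, then renormalizes."""
    # Sort probabilities in descending order per sequence.
    probs_sort, probs_idx = probs.sort(dim=-1, descending=True)
    # Zero out the tail whose cumulative mass before the token already
    # exceeds top_p (the first token is always kept).
    cumsum = probs_sort.cumsum(dim=-1)
    probs_sort[(cumsum - probs_sort) > top_ps.unsqueeze(-1)] = 0.0
    # Renormalize and scatter back to the original vocab order.
    probs_sort = probs_sort / probs_sort.sum(dim=-1, keepdim=True)
    return torch.zeros_like(probs).scatter_(-1, probs_idx, probs_sort)
```

Taking `torch.log` of this result gives the logprobs, matching how the diff uses the torch path.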
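Reviewer note on the sampling_batch_info.py hunk: since `apply_logits_bias` no longer returns `logits`, the repetition penalty must be written through the caller's tensor, which is why `logits = torch.where(...)` becomes `logits[:] = torch.where(...)`. A self-contained demonstration of the difference, with hypothetical helper names:

```python
import torch


def rebind(logits: torch.Tensor) -> None:
    # Rebinds the local name only; the caller's tensor is unchanged.
    logits = logits * 2.0


def write_in_place(logits: torch.Tensor) -> None:
    # Writes through the same storage; the caller sees the update.
    logits[:] = logits * 2.0


t = torch.ones(3)
rebind(t)
print(t)  # tensor([1., 1., 1.])
write_in_place(t)
print(t)  # tensor([2., 2., 2.])
```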