Commit ca11e11

Merge branch 'sgl-project:main' into main

BruceXcluding authored Dec 30, 2024
2 parents: 3dddac3 + 21ec66e
Showing 5 changed files with 11 additions and 12 deletions.
6 changes: 3 additions & 3 deletions python/sglang/srt/layers/logits_processor.py
@@ -35,21 +35,21 @@

 @dataclasses.dataclass
 class LogitsProcessorOutput:
-    ## First part. This part will be returned by python/sglang/srt/layers/logits_processor.py::LogitsProcessor.
+    ## Part 1: This part will be assigned in python/sglang/srt/layers/logits_processor.py::LogitsProcessor
     # The logits of the next tokens. shape: [#seq, vocab_size]
     next_token_logits: torch.Tensor
     # Used by speculative decoding (EAGLE)
     # The last hidden layers
     hidden_states: Optional[torch.Tensor] = None

-    ## Second part. This part will be returned by python/sglang/srt/layers/sampler.py::Sampler.
+    ## Part 2: This part will be assigned in python/sglang/srt/layers/sampler.py::Sampler
     # The logprobs of the next tokens. shape: [#seq]
     next_token_logprobs: Optional[torch.Tensor] = None
     # The logprobs and ids of the top-k tokens in output positions. shape: [#seq, k]
     next_token_top_logprobs_val: Optional[List] = None
     next_token_top_logprobs_idx: Optional[List] = None

-    ## Third part. This part will be returned by python/sglang/srt/layers/logits_processor.py::LogitsProcessor. Prefill-only.
+    ## Part 3: Prefill-only. This part will be assigned in python/sglang/srt/layers/logits_processor.py::LogitsProcessor
     # The normlaized logprobs of prompts. shape: [#seq]
     normalized_prompt_logprobs: torch.Tensor = None
     # The logprobs of input tokens. shape: [#token]
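For context on the relabeled comments: the logits processor constructs this object and fills Part 1 (plus Part 3 during prefill), and the sampler later assigns Part 2 on the same instance. A minimal sketch of that two-stage flow, with toy shapes and a greedy stand-in for sampling (not sglang's actual code):

import dataclasses
from typing import List, Optional

import torch


@dataclasses.dataclass
class Output:  # stand-in for LogitsProcessorOutput
    # Part 1: assigned by the logits processor.
    next_token_logits: torch.Tensor
    # Part 2: assigned later by the sampler on the same object.
    next_token_logprobs: Optional[torch.Tensor] = None
    next_token_top_logprobs_val: Optional[List] = None
    next_token_top_logprobs_idx: Optional[List] = None


def sampler_stage(out: Output, k: int = 5) -> None:
    logprobs = torch.log_softmax(out.next_token_logits, dim=-1)
    # Greedy stand-in: record the logprob of the argmax token per sequence.
    out.next_token_logprobs = logprobs.max(dim=-1).values
    top_val, top_idx = logprobs.topk(k, dim=-1)
    out.next_token_top_logprobs_val = top_val.tolist()
    out.next_token_top_logprobs_idx = top_idx.tolist()


out = Output(next_token_logits=torch.randn(2, 128))
sampler_stage(out)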
4 changes: 3 additions & 1 deletion python/sglang/srt/layers/sampler.py
@@ -56,7 +56,9 @@ def forward(

         if global_server_args_dict["sampling_backend"] == "flashinfer":
             if return_logprob:
-                # NOTE: the top_p_renorm_prob from flashinfer has numerical problems
+                # NOTE: the top_p_renorm_prob from flashinfer has numerical problems,
+                # https://github.com/flashinfer-ai/flashinfer/issues/708
+                # so we use the torch implementation.
                 logprobs = torch.log(
                     top_p_normalize_probs_torch(probs, sampling_info.top_ps)
                 )
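For reference, here is a minimal sketch of what a torch-side top-p renormalization such as top_p_normalize_probs_torch typically does; the exact sglang implementation may differ:

import torch


def top_p_renormalize(probs: torch.Tensor, top_ps: torch.Tensor) -> torch.Tensor:
    """Zero out tokens beyond the top-p cutoff and rescale to sum to 1."""
    sorted_probs, sorted_idx = probs.sort(dim=-1, descending=True)
    cumsum = sorted_probs.cumsum(dim=-1)
    # (cumsum - sorted_probs) is the mass accumulated *before* each token,
    # so the token that crosses the threshold is still kept.
    mask = (cumsum - sorted_probs) > top_ps.unsqueeze(-1)
    sorted_probs = sorted_probs.masked_fill(mask, 0.0)
    # Scatter back to the original vocabulary order and renormalize.
    renorm = torch.zeros_like(probs).scatter_(-1, sorted_idx, sorted_probs)
    return renorm / renorm.sum(dim=-1, keepdim=True)


probs = torch.softmax(torch.randn(2, 8), dim=-1)
logprobs = torch.log(top_p_renormalize(probs, torch.tensor([0.9, 0.5])))

Taking torch.log of the renormalized distribution matches the call shown in the hunk above.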
7 changes: 3 additions & 4 deletions python/sglang/srt/model_executor/model_runner.py
@@ -36,7 +36,7 @@
 from sglang.srt.layers.attention.torch_native_backend import TorchNativeAttnBackend
 from sglang.srt.layers.attention.triton_backend import TritonAttnBackend
 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
-from sglang.srt.layers.sampler import Sampler, get_top_logprobs
+from sglang.srt.layers.sampler import Sampler
 from sglang.srt.layers.torchao_utils import apply_torchao_config_to_model
 from sglang.srt.lora.lora_manager import LoRAManager
 from sglang.srt.managers.schedule_batch import global_server_args_dict

@@ -191,10 +191,9 @@ def init_torch_distributed(self):
         torch.get_device_module(self.device).set_device(self.gpu_id)
         if self.device == "cuda":
             backend = "nccl"
-
-            # TODO(liangan1):Just use gloo to bypass the initilization fail
-            # Need to use xccl for xpu backend in the future
         elif self.device == "xpu":
+            # TODO(liangan1):Just use gloo to bypass the initilization fail
+            # Need to use xccl for xpu backend in the future
             backend = "gloo"
         elif self.device == "hpu":
             backend = "hccl"
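The hunk above only relocates the TODO comment to the branch it describes; the mapping itself is the usual per-device torch.distributed backend choice. A small sketch of that pattern under the same assumptions (the init call is shown commented out since it needs a live process group):

import torch.distributed as dist


def pick_backend(device: str) -> str:
    # NCCL for CUDA, HCCL for HPU, and gloo as a stopgap for XPU
    # until an xccl backend is available.
    return {"cuda": "nccl", "hpu": "hccl", "xpu": "gloo"}.get(device, "gloo")


# dist.init_process_group(backend=pick_backend("cuda"), rank=0, world_size=1)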
4 changes: 1 addition & 3 deletions python/sglang/srt/sampling/sampling_batch_info.py
@@ -244,7 +244,7 @@ def apply_logits_bias(self, logits: torch.Tensor):

         # repetition
         if self.scaling_penalties is not None:
-            logits = torch.where(
+            logits[:] = torch.where(
                 logits > 0,
                 logits / self.scaling_penalties,
                 logits * self.scaling_penalties,

@@ -253,5 +253,3 @@ def apply_logits_bias(self, logits: torch.Tensor):
         # Apply regex vocab_mask
         if self.vocab_mask is not None:
             self.apply_mask(logits=logits, vocab_mask=self.vocab_mask)
-
-        return logits
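The switch from logits = torch.where(...) to logits[:] = torch.where(...) is what allows dropping return logits: slice assignment writes into the caller's tensor instead of rebinding a local name. A minimal sketch of the distinction, with toy tensors rather than sglang's penalty logic:

import torch


def rebind(logits: torch.Tensor) -> None:
    # Creates a new tensor and rebinds the local name; the caller's
    # tensor is untouched, so the result would have to be returned.
    logits = torch.where(logits > 0, logits / 2, logits * 2)


def mutate_in_place(logits: torch.Tensor) -> None:
    # Slice assignment writes into the existing storage; the caller sees it.
    logits[:] = torch.where(logits > 0, logits / 2, logits * 2)


x = torch.tensor([1.0, -1.0])
rebind(x)
print(x)  # tensor([ 1., -1.]) -- unchanged
mutate_in_place(x)
print(x)  # tensor([ 0.5000, -2.0000]) -- updated in place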
2 changes: 1 addition & 1 deletion test/srt/test_srt_endpoint.py
@@ -227,7 +227,7 @@ def test_logprob_grammar(self):
                    "regex": "( Yes| No)",
                },
                "return_logprob": True,
-                "top_logprobs_num": 5,
+                "top_logprobs_num": 5,  # The grammar constraint allows all prefix tokens so we need to use a larger top_k.
                "return_text_in_logprobs": True,
            },
        )
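The hunk above sits inside a requests.post call to the server's /generate endpoint. A trimmed sketch of that request, with the URL and prompt text assumed rather than taken from the test:

import requests

response = requests.post(
    "http://localhost:30000/generate",  # assumed default sglang port
    json={
        "text": "Is Paris the capital of France? Answer:",  # assumed prompt
        "sampling_params": {"max_new_tokens": 2, "regex": "( Yes| No)"},
        "return_logprob": True,
        # The grammar allows all prefix tokens, so a larger top-k is needed.
        "top_logprobs_num": 5,
        "return_text_in_logprobs": True,
    },
)
print(response.json())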
