[Fix] fix eos trim inconsistency (#1650)
Ying1123 authored Oct 13, 2024
1 parent c3f2fc5 commit 4876117
Showing 7 changed files with 77 additions and 27 deletions.
python/sglang/srt/managers/detokenizer_manager.py (31 additions & 10 deletions)

@@ -18,7 +18,7 @@
 import dataclasses
 import logging
 from collections import OrderedDict
-from typing import List
+from typing import List, Union

 import zmq

@@ -29,7 +29,7 @@
     BatchTokenIDOut,
     UpdateWeightReqOutput,
 )
-from sglang.srt.managers.schedule_batch import FINISH_MATCHED_STR
+from sglang.srt.managers.schedule_batch import FINISH_MATCHED_STR, FINISH_MATCHED_TOKEN
 from sglang.srt.server_args import PortArgs, ServerArgs
 from sglang.srt.utils import configure_logger, kill_parent_process
 from sglang.utils import find_printable_text, get_exception_traceback
@@ -75,6 +75,21 @@ def __init__(

         self.decode_status = LimitedCapacityDict()

+    def trim_eos(self, output: Union[str, List[int]], finished_reason, no_eos_trim):
+        if no_eos_trim:
+            return output
+
+        # Trim stop str. TODO(lmzheng): handle the case where multiple stop strs are hit
+        if isinstance(finished_reason, FINISH_MATCHED_STR) and isinstance(output, str):
+            pos = output.find(finished_reason.matched)
+            return output[:pos] if pos != -1 else output
+        if isinstance(finished_reason, FINISH_MATCHED_TOKEN) and isinstance(
+            output, list
+        ):
+            assert len(output) > 0
+            return output[:-1]
+        return output
+
     def event_loop(self):
         """The event loop that handles requests"""

@@ -122,7 +137,13 @@ def event_loop(self):
                    s = self.decode_status[rid]
                    s.decode_ids = recv_obj.decode_ids[i]

-                read_ids.append(s.decode_ids[s.surr_offset :])
+                read_ids.append(
+                    self.trim_eos(
+                        s.decode_ids[s.surr_offset :],
+                        recv_obj.finished_reason[i],
+                        recv_obj.no_eos_trim[i],
+                    )
+                )
                 surr_ids.append(s.decode_ids[s.surr_offset : s.read_offset])

             # TODO(lmzheng): handle skip_special_tokens/spaces_between_special_tokens per request
@@ -152,13 +173,13 @@ def event_loop(self):
                 else:
                     new_text = find_printable_text(new_text)

-                output_strs.append(s.decoded_text + new_text)
-
-                # Trim stop str. TODO(lmzheng): handle the case where multiple stop strs are hit
-                if isinstance(recv_obj.finished_reason[i], FINISH_MATCHED_STR):
-                    pos = output_strs[i].find(recv_obj.finished_reason[i].matched)
-                    if pos != -1:
-                        output_strs[i] = output_strs[i][:pos]
+                output_strs.append(
+                    self.trim_eos(
+                        s.decoded_text + new_text,
+                        recv_obj.finished_reason[i],
+                        recv_obj.no_eos_trim[i],
+                    )
+                )

             self.send_to_tokenizer.send_pyobj(
                 BatchStrOut(
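
The new trim_eos helper applies one trimming rule to both output representations: on a stop-string match it cuts the decoded text at the matched string, and on a stop-token match it drops the final token id. A minimal standalone sketch of that logic, using stand-in classes for FINISH_MATCHED_STR and FINISH_MATCHED_TOKEN (the real ones live in sglang.srt.managers.schedule_batch):

from dataclasses import dataclass
from typing import List, Union

@dataclass
class MatchedStr:  # stand-in for FINISH_MATCHED_STR
    matched: str

@dataclass
class MatchedToken:  # stand-in for FINISH_MATCHED_TOKEN
    matched: int

def trim_eos(output: Union[str, List[int]], finished_reason, no_eos_trim: bool):
    if no_eos_trim:
        return output
    if isinstance(finished_reason, MatchedStr) and isinstance(output, str):
        pos = output.find(finished_reason.matched)  # cut at the first stop-str hit
        return output[:pos] if pos != -1 else output
    if isinstance(finished_reason, MatchedToken) and isinstance(output, list):
        return output[:-1]  # drop the matched stop token id
    return output

assert trim_eos("Hello!</s>", MatchedStr("</s>"), no_eos_trim=False) == "Hello!"
assert trim_eos([15496, 0, 2], MatchedToken(2), no_eos_trim=False) == [15496, 0]
assert trim_eos("Hello!</s>", MatchedStr("</s>"), no_eos_trim=True) == "Hello!</s>"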
python/sglang/srt/managers/io_struct.py (1 addition & 0 deletions)

@@ -295,6 +295,7 @@ class BatchTokenIDOut:
     spaces_between_special_tokens: List[bool]
     meta_info: List[Dict]
     finished_reason: List[BaseFinishReason]
+    no_eos_trim: List[bool]


 @dataclass
python/sglang/srt/managers/scheduler.py (3 additions & 0 deletions)

@@ -883,6 +883,7 @@ def handle_finished_requests(self, batch: ScheduleBatch):
            output_read_offsets = []
            output_skip_special_tokens = []
            output_spaces_between_special_tokens = []
+           output_no_eos_trim = []
        else:  # embedding or reward model
            output_embeddings = []
        unfinished_indices = []
@@ -914,6 +915,7 @@ def handle_finished_requests(self, batch: ScheduleBatch):
                    output_spaces_between_special_tokens.append(
                        req.sampling_params.spaces_between_special_tokens
                    )
+                   output_no_eos_trim.append(req.sampling_params.no_eos_trim)

                    meta_info = {
                        "prompt_tokens": len(req.origin_input_ids),
@@ -961,6 +963,7 @@ def handle_finished_requests(self, batch: ScheduleBatch):
                        output_spaces_between_special_tokens,
                        output_meta_info,
                        output_finished_reason,
+                       output_no_eos_trim,
                    )
                )
            else:  # embedding or reward model
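
The scheduler change only threads the flag through: every per-request list packed into BatchTokenIDOut must stay index-aligned so the detokenizer can read finished_reason[i] and no_eos_trim[i] for the same request. A minimal sketch of that invariant (illustrative dataclass, not sglang code):

from dataclasses import dataclass, fields
from typing import List

@dataclass
class BatchOutSketch:  # illustrative stand-in for BatchTokenIDOut
    rids: List[str]
    finished_reason: List[object]
    no_eos_trim: List[bool]

    def __post_init__(self):
        # Every per-request list must carry one entry per request id.
        lengths = {f.name: len(getattr(self, f.name)) for f in fields(self)}
        assert len(set(lengths.values())) == 1, f"misaligned lists: {lengths}"

BatchOutSketch(rids=["r0", "r1"], finished_reason=[None, None], no_eos_trim=[False, True])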
python/sglang/srt/openai_api/adapter.py (32 additions & 17 deletions)

@@ -493,23 +493,38 @@ def v1_generate_request(
         top_logprobs_nums.append(
             request.logprobs if request.logprobs is not None else 0
         )
-        sampling_params_list.append(
-            {
-                "temperature": request.temperature,
-                "max_new_tokens": request.max_tokens,
-                "min_new_tokens": request.min_tokens,
-                "stop": request.stop,
-                "stop_token_ids": request.stop_token_ids,
-                "top_p": request.top_p,
-                "presence_penalty": request.presence_penalty,
-                "frequency_penalty": request.frequency_penalty,
-                "repetition_penalty": request.repetition_penalty,
-                "regex": request.regex,
-                "json_schema": request.json_schema,
-                "n": request.n,
-                "ignore_eos": request.ignore_eos,
-            }
-        )
+        sampling_params = []
+        if isinstance(request.no_eos_trim, list):
+            num_reqs = len(request.prompt)
+        else:
+            num_reqs = 1
+        for i in range(num_reqs):
+            sampling_params.append(
+                {
+                    "temperature": request.temperature,
+                    "max_new_tokens": request.max_tokens,
+                    "min_new_tokens": request.min_tokens,
+                    "stop": request.stop,
+                    "stop_token_ids": request.stop_token_ids,
+                    "top_p": request.top_p,
+                    "presence_penalty": request.presence_penalty,
+                    "frequency_penalty": request.frequency_penalty,
+                    "repetition_penalty": request.repetition_penalty,
+                    "regex": request.regex,
+                    "json_schema": request.json_schema,
+                    "n": request.n,
+                    "ignore_eos": request.ignore_eos,
+                    "no_eos_trim": (
+                        request.no_eos_trim
+                        if not isinstance(request.no_eos_trim, list)
+                        else request.no_eos_trim[i]
+                    ),
+                }
+            )
+        if num_reqs == 1:
+            sampling_params_list.append(sampling_params[0])
+        else:
+            sampling_params_list.append(sampling_params)

     if len(all_requests) == 1:
         prompt = prompts[0]
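
On the API side, no_eos_trim can now be sent either as a single bool for the whole request or as a per-prompt list. A hypothetical client call (assuming a local sglang server at http://localhost:30000 and the model name "default"):

import requests

url = "http://localhost:30000/v1/completions"

# One flag for a single prompt: keep the matched stop string in the output.
requests.post(url, json={
    "model": "default",
    "prompt": "Say hello.",
    "stop": ["\n"],
    "no_eos_trim": True,
})

# Per-prompt flags for a batched request: trim the first, keep the second.
requests.post(url, json={
    "model": "default",
    "prompt": ["Say hello.", "Say goodbye."],
    "stop": ["\n"],
    "no_eos_trim": [False, True],
})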
python/sglang/srt/openai_api/protocol.py (1 addition & 0 deletions)

@@ -174,6 +174,7 @@ class CompletionRequest(BaseModel):
     min_tokens: int = 0
     repetition_penalty: Optional[float] = 1.0
     stop_token_ids: Optional[List[int]] = Field(default_factory=list)
+    no_eos_trim: Union[bool, List[bool]] = False


 class CompletionResponseChoice(BaseModel):
python/sglang/srt/sampling/sampling_params.py (2 additions & 0 deletions)

@@ -40,6 +40,7 @@ def __init__(
         regex: Optional[str] = None,
         n: int = 1,
         json_schema: Optional[str] = None,
+        no_eos_trim: bool = False,
     ) -> None:
         self.temperature = temperature
         self.top_p = top_p
@@ -60,6 +61,7 @@
         self.regex = regex
         self.n = n
         self.json_schema = json_schema
+        self.no_eos_trim = no_eos_trim

         # Process some special cases
         if self.temperature < _SAMPLING_EPS:
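
Internally the flag ends up on SamplingParams. A minimal construction sketch, using only keyword arguments visible in this diff:

from sglang.srt.sampling.sampling_params import SamplingParams

# no_eos_trim=True asks the detokenizer to keep the matched stop sequence.
params = SamplingParams(temperature=0.7, no_eos_trim=True)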
python/sglang/srt/utils.py (7 additions & 0 deletions)

@@ -690,3 +690,10 @@ def pytorch_profile(name, func, *args, data_size=-1):
         prof.export_chrome_trace(f"trace/{name}_{step_counter}.json")
     step_counter += 1
     return result


+def first_rank_print(*args, **kwargs):
+    if torch.cuda.current_device() == 0:
+        print(*args, **kwargs)
+    else:
+        pass
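
One caveat for first_rank_print: torch.cuda.current_device() returns the process's current CUDA device index (0 unless torch.cuda.set_device was called), not a distributed rank, so the guard relies on per-process device pinning; the else: pass branch is redundant and could be dropped. A usage sketch under that assumption:

import torch
from sglang.srt.utils import first_rank_print

local_rank = 0  # assumed: each worker process pins its own device index
torch.cuda.set_device(local_rank)
first_rank_print("weights loaded")  # printed only by the process on device 0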