Commit cd8f8a8

[fix] Generation logits length for overlap scheduler early exit

Signed-off-by: Robin Kobus <[email protected]>

1 parent 25051c0 commit cd8f8a8

File tree

2 files changed: +37 -22 lines changed

tensorrt_llm/_torch/pyexecutor/llm_request.py

Lines changed: 35 additions & 22 deletions
@@ -43,19 +43,19 @@ class LogitsStorage:
 
     def __init__(
         self,
+        *,
         seq_length: int,
         use_device_memory=True,
-        should_exclude_last=False,
+        extra_token_for_overlap_scheduler=False,
         use_chunked_generation_logits=False,
         chunk_size=8
     ):  # logic adpted from HandleGenerationLogits.cpp to use chunked transfer
-        if should_exclude_last:
+        if extra_token_for_overlap_scheduler:
             # Exclude last logits is used when overlap scheduler is used, that generates one extra token,
             # so we should make sure there's memory for that extra +1.
             seq_length += 1
         self.seq_length = seq_length
         self.use_device_memory = use_device_memory
-        self._should_exclude_last = should_exclude_last
         self.use_chunked_generation_logits = use_chunked_generation_logits
         self.chunk_size = chunk_size
         self._logits_indices = []
@@ -126,14 +126,14 @@ def append(self, logits: torch.Tensor):
                 non_blocking=True)
         self._logits_indices.append((position, new_position))
 
-    def get(self, all_logits: bool) -> torch.Tensor | None:
+    def get(self, all_logits: bool, exclude_last: bool) -> torch.Tensor | None:
         """Returns the used logits storage if there are any, otherwise, returns None.
         When all_logits is True then all set logits are returned, otherwise, only the last logits are returned."""
         if self._storage is None:
             return None
 
         try:
-            last = -2 if self._should_exclude_last else -1
+            last = -2 if exclude_last else -1
             start = 0 if all_logits else self._logits_indices[last][0]
             end = self._logits_indices[last][1]
             return self._storage[start:end]
@@ -175,9 +175,6 @@ def finalize_chunked_transfer(self):
         if self.use_chunked_generation_logits and self._device_fragments:
             self._transfer_chunk_to_host()
 
-    def set_exclude_last(self, should_exclude_last: bool) -> None:
-        self._should_exclude_last = should_exclude_last
-
 
 class LogProbStorage:
     beam_width: int = -1
@@ -225,6 +222,7 @@ class PyResult:
     """PyResult reimplements some features of `bindings.executor.Result` in Python"""
 
     def __init__(self,
+                 *,
                  prompt_len: int,
                  max_new_tokens: int,
                  use_device_memory=True,
@@ -240,16 +238,20 @@ def __init__(self,
             assert chunk_size == 1, "chunk_size must be 1 in streaming mode"
         self._streaming = streaming
         self._chunk_size = chunk_size
+        self._exclude_last_generation_logits = exclude_last_generation_logits
 
         # Note that in C++ implemnetation both context logits and generation logits are stored on host memory.
         # Here we only use host memory for generation logits if in chunked model.
         self._context_logits = LogitsStorage(
-            prompt_len, use_device_memory, use_chunked_generation_logits=False
+            seq_length=prompt_len,
+            use_device_memory=use_device_memory,
+            extra_token_for_overlap_scheduler=False,
+            use_chunked_generation_logits=False
         ) if return_context_logits else None
         self._generation_logits = LogitsStorage(
-            max_new_tokens,
-            use_device_memory,
-            exclude_last_generation_logits,
+            seq_length=max_new_tokens,
+            use_device_memory=use_device_memory,
+            extra_token_for_overlap_scheduler=exclude_last_generation_logits,
             use_chunked_generation_logits=use_chunked_generation_logits,
             chunk_size=self._chunk_size) if return_generation_logits else None
         self._log_probs = LogProbStorage() if return_log_probs else None
@@ -263,6 +265,10 @@ def __init__(self,
             for name in additional_outputs
         } if additional_outputs else None
 
+    def set_exclude_last_generation_logits(
+            self, exclude_last_generation_logits: bool):
+        self._exclude_last_generation_logits = exclude_last_generation_logits
+
     def append_context_logits(self, context_logits: torch.Tensor):
         if self._context_logits:
             self._context_logits.append(context_logits)
@@ -309,7 +315,7 @@ def set_log_probs(self, log_probs: list[TokenLogprobs],
     @property
     def context_logits(self) -> torch.Tensor | None:
         if self._context_logits is None or (storage := self._context_logits.get(
-                all_logits=True)) is None:
+                all_logits=True, exclude_last=False)) is None:
             return None
         return storage[:, 0]  # remove beam_width axis for context
 
@@ -320,7 +326,9 @@ def generation_logits(self) -> torch.Tensor | None:
         if not self._generation_logits:
             return None
 
-        storage = self._generation_logits.get(all_logits=not self._streaming)
+        storage = self._generation_logits.get(
+            all_logits=not self._streaming,
+            exclude_last=self._exclude_last_generation_logits)
         if storage is None:
             return None
         return storage.transpose(0, 1)
@@ -522,14 +530,14 @@ def __init__(
         self.py_stop_words_list = stop_words_list
 
         self.py_result = PyResult(
-            self.py_prompt_len,
-            self.py_max_new_tokens,
-            return_logits_device_memory,
-            self.streaming,
-            return_log_probs,
-            return_context_logits,
-            return_generation_logits,
-            exclude_last_generation_logits,
+            prompt_len=self.py_prompt_len,
+            max_new_tokens=self.py_max_new_tokens,
+            use_device_memory=return_logits_device_memory,
+            streaming=self.streaming,
+            return_log_probs=return_log_probs,
+            return_context_logits=return_context_logits,
+            return_generation_logits=return_generation_logits,
+            exclude_last_generation_logits=exclude_last_generation_logits,
             use_chunked_generation_logits=self.py_use_chunked_generation_logits,
             chunk_size=self.py_logits_chunk_size,
             additional_outputs=additional_outputs)
@@ -543,6 +551,11 @@ def __init__(
         else:
             self._py_embedding_bias_1d = self.embedding_bias
 
+    def set_exclude_last_generation_logits(
+            self, exclude_last_generation_logits: bool):
+        self.py_result.set_exclude_last_generation_logits(
+            exclude_last_generation_logits)
+
     @property
     def cached_tokens(self) -> int:
         return self._cached_tokens
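
Note: the signature change to get() above drops the stored _should_exclude_last flag in favour of a per-call exclude_last argument, while the constructor still reserves one extra slot when the overlap scheduler is in use. Below is a minimal runnable sketch of that indexing, assuming a toy 8-entry vocabulary and a hypothetical TinyLogitsStore class; it is not the library's LogitsStorage, only an illustration of the slicing logic.

import torch


class TinyLogitsStore:
    """Toy stand-in for LogitsStorage; stores logits as [position, beam, vocab]."""

    def __init__(self, *, seq_length: int, extra_token_for_overlap_scheduler: bool = False):
        if extra_token_for_overlap_scheduler:
            seq_length += 1  # reserve a slot for the overlap scheduler's extra token
        self._storage = torch.empty(seq_length, 1, 8)
        self._indices = []  # (start, end) positions of each appended chunk
        self._position = 0

    def append(self, logits: torch.Tensor):
        new_position = self._position + logits.shape[0]
        self._storage[self._position:new_position] = logits
        self._indices.append((self._position, new_position))
        self._position = new_position

    def get(self, all_logits: bool, exclude_last: bool) -> torch.Tensor | None:
        if not self._indices:
            return None
        last = -2 if exclude_last else -1  # skip the extra token's logits when asked
        start = 0 if all_logits else self._indices[last][0]
        return self._storage[start:self._indices[last][1]]


store = TinyLogitsStore(seq_length=4, extra_token_for_overlap_scheduler=True)
for _ in range(3):  # two "real" generation steps plus the extra overlap-scheduler step
    store.append(torch.randn(1, 1, 8))
print(store.get(all_logits=True, exclude_last=True).shape)   # torch.Size([2, 1, 8])
print(store.get(all_logits=True, exclude_last=False).shape)  # torch.Size([3, 1, 8])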

tensorrt_llm/_torch/pyexecutor/py_executor.py

Lines changed: 2 additions & 0 deletions
@@ -1883,6 +1883,7 @@ def _update_request_states_tp(self, scheduled_requests: ScheduledRequests):
             if request.context_remaining_length == 0:
                 if not self.disable_overlap_scheduler and request.will_complete_next_iteration(
                 ):
+                    request.set_exclude_last_generation_logits(False)
                     request.state = LlmRequestState.GENERATION_TO_COMPLETE
                 else:
                     request.state = LlmRequestState.GENERATION_IN_PROGRESS
@@ -1891,6 +1892,7 @@ def _update_request_states_tp(self, scheduled_requests: ScheduledRequests):
             if request.state != LlmRequestState.GENERATION_COMPLETE:
                 if not self.disable_overlap_scheduler and request.will_complete_next_iteration(
                 ):
+                    request.set_exclude_last_generation_logits(False)
                     request.state = LlmRequestState.GENERATION_TO_COMPLETE
 
     def _update_request_states_star_attention(
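
Note: a minimal runnable sketch of the early-exit decision these two hunks change, using a hypothetical ToyRequest class and plain-string states rather than the executor's LlmRequestState. Under the overlap scheduler a request normally produces one surplus set of logits that get() must exclude; when the request will complete on the very next iteration, that extra forward pass never runs, so the exclusion flag is switched off before the state transition.

class ToyRequest:

    def __init__(self, max_new_tokens: int):
        self.max_new_tokens = max_new_tokens
        self.generated = 0
        self.exclude_last_generation_logits = True  # default when overlap scheduler is on
        self.state = "GENERATION_IN_PROGRESS"

    def will_complete_next_iteration(self) -> bool:
        return self.generated + 1 >= self.max_new_tokens

    def set_exclude_last_generation_logits(self, value: bool):
        self.exclude_last_generation_logits = value


def update_state(request: ToyRequest, disable_overlap_scheduler: bool = False):
    # Mirrors the decision shown in the hunks above (simplified).
    if not disable_overlap_scheduler and request.will_complete_next_iteration():
        request.set_exclude_last_generation_logits(False)  # the fix in this commit
        request.state = "GENERATION_TO_COMPLETE"
    else:
        request.state = "GENERATION_IN_PROGRESS"


req = ToyRequest(max_new_tokens=4)
req.generated = 3  # one token left to generate
update_state(req)
print(req.state, req.exclude_last_generation_logits)  # GENERATION_TO_COMPLETE False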
