Skip to content

Commit 30dc787

Browse files
committed
[fix] Skip logits and additional outputs handling in extra iteration for overlap scheduler
Signed-off-by: Robin Kobus <[email protected]>
1 parent 8774fd7 commit 30dc787

File tree

2 files changed

+19
-13
lines changed

2 files changed

+19
-13
lines changed

tensorrt_llm/_torch/pyexecutor/handle_additional_outputs.py

Lines changed: 14 additions & 12 deletions
Original file line number | Diff line number | Diff line change
@@ -3,7 +3,8 @@
33

44
import torch
55

6-
from tensorrt_llm._torch.pyexecutor.llm_request import LlmRequest
6+
from tensorrt_llm._torch.pyexecutor.llm_request import (LlmRequest,
7+
LlmRequestState)
78
from tensorrt_llm._utils import nvtx_range
89
from tensorrt_llm.logger import logger
910

@@ -92,18 +93,19 @@ def __call__(
9293
(1, beam_width, 1)))
9394

9495
for llm_req in generation_requests:
95-
additional_outputs = llm_req.py_additional_outputs
96+
if llm_req.state != LlmRequestState.GENERATION_COMPLETE:
97+
additional_outputs = llm_req.py_additional_outputs
9698

97-
for name in additional_outputs:
98-
outputs_begin = (output_index_with_context
99-
if gather_context[name] else
100-
output_index_without_context)
101-
outputs_end = outputs_begin + beam_width
102-
103-
output_device_view = outputs[name][
104-
outputs_begin:outputs_end].reshape(1, beam_width, -1)
105-
llm_req.py_result.append_additional_generation_outputs(
106-
name, output_device_view)
99+
for name in additional_outputs:
100+
outputs_begin = (output_index_with_context
101+
if gather_context[name] else
102+
output_index_without_context)
103+
outputs_end = outputs_begin + beam_width
104+
105+
output_device_view = outputs[name][
106+
outputs_begin:outputs_end].reshape(1, beam_width, -1)
107+
llm_req.py_result.append_additional_generation_outputs(
108+
name, output_device_view)
107109

108110
output_index_with_context += beam_width
109111
output_index_without_context += beam_width

tensorrt_llm/_torch/pyexecutor/handle_logits.py

Lines changed: 5 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -3,7 +3,8 @@
33

44
import torch
55

6-
from tensorrt_llm._torch.pyexecutor.llm_request import LlmRequest
6+
from tensorrt_llm._torch.pyexecutor.llm_request import (LlmRequest,
7+
LlmRequestState)
78
from tensorrt_llm._utils import nvtx_range
89
from tensorrt_llm.logger import logger
910

@@ -72,6 +73,9 @@ def __call__(
7273

7374
total_context_logits = num_context_logits_prefix_sum[-1]
7475
for batch_index, llm_req in enumerate(generation_requests):
76+
if llm_req.state == LlmRequestState.GENERATION_COMPLETE:
77+
continue
78+
7579
logits_begin = total_context_logits + batch_index * beam_width
7680
logits_end = logits_begin + beam_width
7781

0 commit comments

Comments
 (0)