|
3 | 3 |
|
4 | 4 | import torch |
5 | 5 |
|
6 | | -from tensorrt_llm._torch.pyexecutor.llm_request import LlmRequest |
| 6 | +from tensorrt_llm._torch.pyexecutor.llm_request import (LlmRequest, |
| 7 | + LlmRequestState) |
7 | 8 | from tensorrt_llm._utils import nvtx_range |
8 | 9 | from tensorrt_llm.logger import logger |
9 | 10 |
|
@@ -92,18 +93,19 @@ def __call__( |
92 | 93 | (1, beam_width, 1))) |
93 | 94 |
|
94 | 95 | for llm_req in generation_requests: |
95 | | - additional_outputs = llm_req.py_additional_outputs |
| 96 | + if llm_req.state != LlmRequestState.GENERATION_COMPLETE: |
| 97 | + additional_outputs = llm_req.py_additional_outputs |
96 | 98 |
|
97 | | - for name in additional_outputs: |
98 | | - outputs_begin = (output_index_with_context |
99 | | - if gather_context[name] else |
100 | | - output_index_without_context) |
101 | | - outputs_end = outputs_begin + beam_width |
102 | | - |
103 | | - output_device_view = outputs[name][ |
104 | | - outputs_begin:outputs_end].reshape(1, beam_width, -1) |
105 | | - llm_req.py_result.append_additional_generation_outputs( |
106 | | - name, output_device_view) |
| 99 | + for name in additional_outputs: |
| 100 | + outputs_begin = (output_index_with_context |
| 101 | + if gather_context[name] else |
| 102 | + output_index_without_context) |
| 103 | + outputs_end = outputs_begin + beam_width |
| 104 | + |
| 105 | + output_device_view = outputs[name][ |
| 106 | + outputs_begin:outputs_end].reshape(1, beam_width, -1) |
| 107 | + llm_req.py_result.append_additional_generation_outputs( |
| 108 | + name, output_device_view) |
107 | 109 |
|
108 | 110 | output_index_with_context += beam_width |
109 | 111 | output_index_without_context += beam_width |
|
0 commit comments