Skip to content

Commit b3b5034

Browse files
committed
[None][feat] Enable early exit with overlap scheduler
- Update MicroBatchScheduler bindings to skip scheduling after GENERATION_TO_COMPLETE state.
- Update PyExecutor to set GENERATION_TO_COMPLETE state for requests that will complete next iteration.
- Fix _executor_loop_overlap to finish previous batch if current batch is empty.

Signed-off-by: Robin Kobus <[email protected]>
1 parent 0019d99 commit b3b5034

File tree

6 files changed

+39
-22
lines changed

6 files changed

+39
-22
lines changed

cpp/tensorrt_llm/batch_manager/llmRequest.cpp

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,10 @@ void LlmRequest::createSerializedResult(
6969
/// Note that there is some dependency on the order of operations in this method. Modify with care!
7070
std::optional<executor::Result> LlmRequest::createResult(bool useFastLogits, int32_t mpiWorldRank)
7171
{
72-
if (!(isFinished() || (mIsStreaming && mState == LlmRequestState::kGENERATION_IN_PROGRESS)))
72+
if (!(isFinished()
73+
|| (mIsStreaming
74+
&& (mState == LlmRequestState::kGENERATION_IN_PROGRESS
75+
|| mState == LlmRequestState::kGENERATION_TO_COMPLETE))))
7376
{
7477
return std::nullopt;
7578
}

cpp/tensorrt_llm/nanobind/batch_manager/algorithms.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,7 @@ void tensorrt_llm::nanobind::batch_manager::algorithms::initBindings(nb::module_
6464
LlmRequestState>(),
6565
nb::arg("ctx_chunk_config") = std::nullopt, nb::arg("max_context_length") = std::nullopt,
6666
nb::arg("no_schedule_until_state") = LlmRequestState::kCONTEXT_INIT,
67-
nb::arg("no_schedule_after_state") = LlmRequestState::kGENERATION_COMPLETE)
67+
nb::arg("no_schedule_after_state") = LlmRequestState::kGENERATION_TO_COMPLETE)
6868
.def("__call__", &MicroBatchScheduler::operator(), nb::arg("active_requests"), nb::arg("inflight_req_ids"),
6969
nb::arg("max_batch_size_runtime"), nb::arg("max_num_tokens_runtime"))
7070
.def("name", [](MicroBatchScheduler const&) { return MicroBatchScheduler::name; });

cpp/tensorrt_llm/nanobind/batch_manager/bindings.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -103,6 +103,7 @@ void initBindings(nb::module_& m)
103103
.def("get_last_tokens", nb::overload_cast<>(&GenLlmReq::getLastTokens))
104104
.def("get_beam_width_by_iter", &GenLlmReq::getBeamWidthByIter, nb::arg("for_next_iteration") = false)
105105
.def_prop_ro("max_num_generated_tokens", &GenLlmReq::getMaxNumGeneratedTokens)
106+
.def("will_complete_next_iteration", &GenLlmReq::willCompleteNextIteration)
106107
.def("add_new_token", &GenLlmReq::addNewToken, nb::arg("token"), nb::arg("beam"))
107108
.def("add_new_tokens", &GenLlmReq::addNewTokens, nb::arg("beam_tokens"))
108109
.def_prop_ro("num_draft_tokens", &GenLlmReq::getNumDraftTokens)

cpp/tensorrt_llm/pybind/batch_manager/algorithms.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -65,8 +65,8 @@ void tensorrt_llm::pybind::batch_manager::algorithms::initBindings(pybind11::mod
6565
LlmRequestState>(),
6666
py::arg("ctx_chunk_config") = std::nullopt, py::arg("max_context_length") = std::nullopt,
6767
py::arg_v("no_schedule_until_state", LlmRequestState::kCONTEXT_INIT, "LlmRequestState.CONTEXT_INIT"),
68-
py::arg_v("no_schedule_after_state", LlmRequestState::kGENERATION_COMPLETE,
69-
"LlmRequestState.GENERATION_COMPLETE"))
68+
py::arg_v("no_schedule_after_state", LlmRequestState::kGENERATION_TO_COMPLETE,
69+
"LlmRequestState.GENERATION_TO_COMPLETE"))
7070
.def("__call__", &MicroBatchScheduler::operator(), py::arg("active_requests"), py::arg("inflight_req_ids"),
7171
py::arg("max_batch_size_runtime"), py::arg("max_num_tokens_runtime"))
7272
.def("name", [](MicroBatchScheduler const&) { return MicroBatchScheduler::name; });

cpp/tensorrt_llm/pybind/batch_manager/bindings.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -107,6 +107,7 @@ void initBindings(pybind11::module_& m)
107107
.def("get_last_tokens", py::overload_cast<>(&GenLlmReq::getLastTokens))
108108
.def("get_beam_width_by_iter", &GenLlmReq::getBeamWidthByIter, py::arg("for_next_iteration") = false)
109109
.def_property_readonly("max_num_generated_tokens", &GenLlmReq::getMaxNumGeneratedTokens)
110+
.def("will_complete_next_iteration", &GenLlmReq::willCompleteNextIteration)
110111
.def("add_new_token", &GenLlmReq::addNewToken, py::arg("token"), py::arg("beam"))
111112
.def("add_new_tokens", &GenLlmReq::addNewTokens, py::arg("beam_tokens"))
112113
.def_property_readonly("num_draft_tokens", &GenLlmReq::getNumDraftTokens)

tensorrt_llm/_torch/pyexecutor/py_executor.py

Lines changed: 30 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -853,7 +853,7 @@ def _executor_loop_pp(self):
853853
self.num_scheduled_requests = scheduled_batch.batch_size
854854

855855
logger.debug(
856-
f'has {len(self.active_requests)} active_request, '
856+
f'has {len(self.active_requests)} active_requests, '
857857
f'scheduled {len(scheduled_batch.context_requests)} context requests and '
858858
f'{len(scheduled_batch.generation_requests)} generation requests'
859859
)
@@ -1094,7 +1094,7 @@ def _prepare_and_schedule_batch(self):
10941094

10951095
self.num_scheduled_requests = scheduled_batch.batch_size
10961096
logger.debug(
1097-
f'has {len(self.active_requests)} active_request, '
1097+
f'has {len(self.active_requests)} active_requests, '
10981098
f'scheduled {len(scheduled_batch.context_requests)} context requests and '
10991099
f'{len(scheduled_batch.generation_requests)} generation requests')
11001100
return scheduled_batch, iter_stats
@@ -1374,21 +1374,22 @@ def _executor_loop_overlap(self):
13741374
if target_inputs is not None:
13751375
self._process_draft_results(scheduled_batch,
13761376
draft_outputs, draft_batch)
1377-
elif self.previous_batch is not None and not use_previous_draft_tokens:
1378-
self._update_requests(self.previous_batch.sample_state)
1377+
if target_inputs is None and self.previous_batch is not None and not use_previous_draft_tokens:
1378+
self._update_requests(self.previous_batch.sample_state)
13791379

1380-
if self.block_reuse_enabled and not self.kv_cache_manager.is_vswa and self.kv_cache_transceiver:
1381-
for req in self.previous_batch.sample_state.scheduled_requests.context_requests:
1382-
if req.is_context_only_request and (
1383-
req.is_context_finished
1384-
or req.is_finished_due_to_length):
1385-
block_id = self.kv_cache_manager.store_blocks_for_reuse(
1386-
req, True)
1387-
self.ctx_in_transmission_requests[
1388-
req.py_request_id] = (
1389-
(req, block_id,
1390-
self.ctx_in_transmission_counter))
1380+
if self.block_reuse_enabled and not self.kv_cache_manager.is_vswa and self.kv_cache_transceiver:
1381+
for req in self.previous_batch.sample_state.scheduled_requests.context_requests:
1382+
if req.is_context_only_request and (
1383+
req.is_context_finished
1384+
or req.is_finished_due_to_length):
1385+
block_id = self.kv_cache_manager.store_blocks_for_reuse(
1386+
req, True)
1387+
self.ctx_in_transmission_requests[
1388+
req.py_request_id] = (
1389+
(req, block_id,
1390+
self.ctx_in_transmission_counter))
13911391

1392+
if scheduled_batch.batch_size > 0:
13921393
if self.guided_decoder is not None:
13931394
# add_batch must be called again to have updated new tokens.
13941395
self.guided_decoder.add_batch(scheduled_batch)
@@ -1404,9 +1405,10 @@ def _executor_loop_overlap(self):
14041405
scheduled_batch.context_requests
14051406
) if self.kv_cache_transceiver else []
14061407

1407-
if self.previous_batch is not None:
1408-
self._process_previous_batch()
1408+
if self.previous_batch is not None:
1409+
self._process_previous_batch()
14091410

1411+
if scheduled_batch.batch_size > 0:
14101412
if self.enable_iter_perf_stats:
14111413
iter_stats.inflight_batching_stats.num_ctx_tokens = self.model_engine.iter_states[
14121414
'num_ctx_tokens']
@@ -1879,7 +1881,17 @@ def _update_request_states_tp(self, scheduled_requests: ScheduledRequests):
18791881
request.context_chunk_size)
18801882
request.move_to_next_context_chunk()
18811883
if request.context_remaining_length == 0:
1882-
request.state = LlmRequestState.GENERATION_IN_PROGRESS
1884+
if not self.disable_overlap_scheduler and request.will_complete_next_iteration(
1885+
):
1886+
request.state = LlmRequestState.GENERATION_TO_COMPLETE
1887+
else:
1888+
request.state = LlmRequestState.GENERATION_IN_PROGRESS
1889+
1890+
for request in scheduled_requests.generation_requests:
1891+
if request.state != LlmRequestState.GENERATION_COMPLETE:
1892+
if not self.disable_overlap_scheduler and request.will_complete_next_iteration(
1893+
):
1894+
request.state = LlmRequestState.GENERATION_TO_COMPLETE
18831895

18841896
def _update_request_states_star_attention(
18851897
self, scheduled_requests: ScheduledRequests):

0 commit comments

Comments (0)