Commits (46):
f038a82  cache tokens in Python side to reduce pybind reading overhead (QiJune, Jun 5, 2025)
bba6719  refine (QiJune, Jun 5, 2025)
f826f5a  Merge branch 'main' into clean_prepare (QiJune, Jun 5, 2025)
8a3e92a  Merge branch 'main' into clean_prepare (QiJune, Jun 5, 2025)
44b12cc  Merge branch 'main' into clean_prepare (QiJune, Jun 6, 2025)
a8bd10d  Merge branch 'main' into clean_prepare (QiJune, Jun 12, 2025)
3c83644  pure Python LlmResponse (QiJune, Jun 12, 2025)
f0bb7c8  pure Python LlmResponse (QiJune, Jun 12, 2025)
60ca761  clean (QiJune, Jun 12, 2025)
5f7e9ea  fix (QiJune, Jun 12, 2025)
1b3a7b7  fix (QiJune, Jun 12, 2025)
df60073  fix (QiJune, Jun 12, 2025)
91c904e  fix (QiJune, Jun 12, 2025)
5d62cea  fix (QiJune, Jun 12, 2025)
50eb5b9  fix (QiJune, Jun 12, 2025)
ea2f8cc  fix (QiJune, Jun 12, 2025)
431926d  fix (QiJune, Jun 12, 2025)
3fb4d84  fix (QiJune, Jun 12, 2025)
04370ba  polish (QiJune, Jun 12, 2025)
acd09a4  expose createSerializedResult api (QiJune, Jun 13, 2025)
48a999c  fix (QiJune, Jun 13, 2025)
1e36d77  fix (QiJune, Jun 13, 2025)
4fd71bc  fix (QiJune, Jun 13, 2025)
1a4920b  fix (QiJune, Jun 13, 2025)
5e9888a  fix (QiJune, Jun 13, 2025)
9bfa834  fix (QiJune, Jun 13, 2025)
b37703c  fix (QiJune, Jun 13, 2025)
3e4844d  serialize result (QiJune, Jun 13, 2025)
e5e3873  fix (QiJune, Jun 13, 2025)
39a791a  fix (QiJune, Jun 13, 2025)
883b9f0  fix (QiJune, Jun 13, 2025)
ab84d66  fix (QiJune, Jun 13, 2025)
3c53e79  fix (QiJune, Jun 13, 2025)
144677b  fix (QiJune, Jun 13, 2025)
e5a6eff  use HIGHEST_PROTOCOL (QiJune, Jun 14, 2025)
bb13278  rebase (QiJune, Jun 16, 2025)
11ffdb8  fix ci (QiJune, Jun 16, 2025)
424dcb9  Merge branch 'main' into clean_prepare_2 (QiJune, Jun 16, 2025)
3a66824  Merge branch 'main' into clean_prepare_2 (QiJune, Jun 16, 2025)
a54beae  polish code (QiJune, Jun 16, 2025)
faa38f4  Merge branch 'main' into clean_prepare_2 (QiJune, Jun 16, 2025)
df543f1  Merge branch 'main' into clean_prepare_2 (QiJune, Jun 16, 2025)
a2e8ae1  Update internal cutlass commit. (#5228) (Tracin, Jun 17, 2025)
bb23483  test: add more pytorch cases in perf test (#5237) (ruodil, Jun 17, 2025)
546274d  fix ci (#5259) (QiJune, Jun 17, 2025)
eb0681e  Merge branch 'main' into clean_prepare_2 (QiJune, Jun 17, 2025)
5 changes: 5 additions & 0 deletions cpp/include/tensorrt_llm/batch_manager/llmRequest.h
@@ -2328,6 +2328,11 @@ class LlmRequest : public GenericLlmRequest<runtime::ITensor::SharedPtr>
/// @return An optional Response
std::optional<executor::Response> createResponse(bool useFastLogits = false, int32_t mpiWorldRank = 0);

std::optional<executor::Result> createResult(bool useFastLogits = false, int32_t mpiWorldRank = 0);

void createSerializedResult(
std::vector<char>& serializedResult, bool& isFinal, bool useFastLogits = false, int32_t mpiWorldRank = 0);

void validate(SizeType32 maxInputLen, SizeType32 maxSequenceLen, SizeType32 maxDraftLen, SizeType32 vocabSizePadded,
std::optional<SizeType32> maxEncoderInputLen = std::nullopt, bool enableKVCacheReuse = false);

35 changes: 29 additions & 6 deletions cpp/tensorrt_llm/batch_manager/llmRequest.cpp
@@ -16,6 +16,7 @@
*/

#include "tensorrt_llm/batch_manager/llmRequest.h"
#include "tensorrt_llm/executor/serializeUtils.h"
#include "tensorrt_llm/kernels/beamSearchKernels.h"

namespace tensorrt_llm::batch_manager
@@ -39,8 +40,34 @@ runtime::SizeType32 GenericLlmRequest<TTensor, TStream>::getBeamWidthByIter(bool

template class GenericLlmRequest<runtime::ITensor::SharedPtr>;

/// Note that there is some dependency on the order of operations in this method. Modify with care!
std::optional<executor::Response> LlmRequest::createResponse(bool useFastLogits, int32_t mpiWorldRank)
{
auto requestId = isChild() ? mParentRequestId : mRequestId;
auto result = createResult(useFastLogits, mpiWorldRank);
if (result.has_value())
{
return executor::Response(requestId, result.value(), mClientId);
}
return std::nullopt;
}

void LlmRequest::createSerializedResult(
std::vector<char>& serializedResult, bool& isFinal, bool useFastLogits, int32_t mpiWorldRank)
{
auto result = createResult(useFastLogits, mpiWorldRank);
if (result.has_value())
{
std::ostringstream oStream;
executor::serialize_utils::serialize(result.value(), oStream);
auto str = oStream.str();
serializedResult.resize(str.size());
std::copy(str.begin(), str.end(), serializedResult.begin());
isFinal = result.value().isFinal;
}
}

/// Note that there is some dependency on the order of operations in this method. Modify with care!
std::optional<executor::Result> LlmRequest::createResult(bool useFastLogits, int32_t mpiWorldRank)
{
TLLM_CHECK(!isDisaggContextCompleteState());
if (!(isFinished() || (mIsStreaming && mState == LlmRequestState::kGENERATION_IN_PROGRESS)))
@@ -192,11 +219,7 @@ std::optional<executor::Response> LlmRequest::createResponse(bool useFastLogits,

// Update position of last sent response
setMaxSentTokenLen(maxNbTokens);

auto requestId = isChild() ? mParentRequestId : mRequestId;
auto response = executor::Response(requestId, std::move(result), mClientId);

return response;
return result;
}

void LlmRequest::validate(SizeType32 maxInputLen, SizeType32 maxSequenceLen, SizeType32 maxDraftLen,
@@ -1,2 +1,2 @@
a1180829a0d8fe772ff37934b72573bb41671e7ed76dfa3bd5cd449348b9683a libtensorrt_llm_internal_cutlass_kernels_static.a
commit 98a790a71a0734881180e434b8c4271ae0f21f34
commit c767347ff934578193ee4bad58ba3b9398046245
@@ -1,2 +1,2 @@
e7130e36217c1df0d281788fc87764945d9c308bef11ad61b3b1a49c7d41c8af libtensorrt_llm_internal_cutlass_kernels_static.a
commit 98a790a71a0734881180e434b8c4271ae0f21f34
commit c767347ff934578193ee4bad58ba3b9398046245
11 changes: 11 additions & 0 deletions cpp/tensorrt_llm/pybind/batch_manager/bindings.cpp
@@ -36,6 +36,7 @@
#include <pybind11/stl.h>
#include <pybind11/stl_bind.h>
#include <torch/extension.h>
#include <tuple>

namespace py = pybind11;
namespace tb = tensorrt_llm::batch_manager;
@@ -360,6 +361,16 @@ void initBindings(pybind11::module_& m)
py::arg("enable_kv_cache_reuse") = false)
.def("create_response", &tb::LlmRequest::createResponse, py::arg("use_fast_logits") = false,
py::arg("mpi_world_rank") = 0)
.def("create_result", &tb::LlmRequest::createResult, py::arg("use_fast_logits") = false,
py::arg("mpi_world_rank") = 0)
.def("create_serialized_result",
[](tb::LlmRequest& self, bool use_fast_logits = false, int mpi_world_rank = 0)
{
std::vector<char> serialized_result;
bool is_final = false;
self.createSerializedResult(serialized_result, is_final, use_fast_logits, mpi_world_rank);
return std::make_tuple(py::bytes(serialized_result.data(), serialized_result.size()), is_final);
})
.def("move_prompt_embedding_table_to_gpu", &tb::LlmRequest::movePromptEmbeddingTableToGpu, py::arg("manager"))
.def("move_lora_weights_to_gpu", &tb::LlmRequest::moveLoraWeightsToGpu, py::arg("manager"))
.def("finish_by_reason", &tb::LlmRequest::finishByReason, py::arg("finish_reason"));
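
For orientation, a minimal sketch of how the new create_serialized_result binding can be driven from Python. It assumes `req` is an already-constructed batch-manager LlmRequest binding object (setup omitted); the arguments are passed positionally because the lambda registers no py::arg names:

# Sketch only: `req` is assumed to be a live LlmRequest binding object.
serialized, is_final = req.create_serialized_result(False, 0)  # (use_fast_logits, mpi_world_rank)
# `serialized` is a bytes object holding the C++-serialized executor::Result.
# When no result is ready, the payload comes back empty, mirroring the
# std::optional check inside createSerializedResult.
if len(serialized) > 0:
    print(f"got a result, is_final={is_final}, payload={len(serialized)} bytes")
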
9 changes: 9 additions & 0 deletions cpp/tensorrt_llm/pybind/executor/request.cpp
@@ -19,6 +19,7 @@
#include "tensorrt_llm/common/assert.h"
#include "tensorrt_llm/common/logger.h"
#include "tensorrt_llm/executor/executor.h"
#include "tensorrt_llm/executor/serializeUtils.h"
#include "tensorrt_llm/executor/tensor.h"
#include "tensorrt_llm/executor/types.h"
#include "tensorrt_llm/runtime/cudaStream.h"
@@ -29,6 +30,7 @@
#include <pybind11/operators.h>
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include <sstream>

#include <optional>
#include <vector>
@@ -775,6 +777,13 @@ void initRequestBindings(pybind11::module_& m)
.def_readwrite("context_phase_params", &tle::Result::contextPhaseParams)
.def(py::pickle(resultGetstate, resultSetstate));

m.def("deserialize_result",
[](std::string& x)
{
std::istringstream is(x);
return tle::serialize_utils::deserialize<tle::Result>(is);
});

auto responseGetstate = [](tle::Response const& self)
{ return py::make_tuple(self.getRequestId(), self.getResult(), self.getClientId()); };

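
The matching consumer side is the new deserialize_result helper; a rough sketch of the round trip, assuming `serialized` holds the bytes produced by create_serialized_result above:

from tensorrt_llm.bindings import executor as tllm_executor

# Sketch only: turn the opaque payload back into a bindings.executor.Result.
result = tllm_executor.deserialize_result(serialized)
print(result.is_final, result.decoding_iter)  # regular Result attributes are available again
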
65 changes: 34 additions & 31 deletions tensorrt_llm/_torch/pyexecutor/llm_request.py
@@ -1,4 +1,5 @@
from typing import List, Optional
from dataclasses import dataclass
from typing import List, Optional, Union

import torch

@@ -205,43 +206,36 @@ class LlmResult:
py_result_properties = frozenset(
('context_logits', 'generation_logits', 'log_probs', 'cum_log_probs'))

def __init__(self, result: tensorrt_llm.bindings.executor.Result,
py_result: PyResult):
def __init__(self,
result: Union[bytes, tensorrt_llm.bindings.executor.Result],
py_result: PyResult,
is_final: bool = False):
self._result = result
self._py_result = py_result
self.is_final = is_final

def __getattr__(self, item):
if item in self.py_result_properties:
return getattr(self._py_result, item)
return getattr(self._result, item)

if item == 'is_final':
return object.__getattribute__(self, 'is_final')
result = object.__getattribute__(self, '_result')
return getattr(result, item)

class LlmResponse:
"""LlmResponse wraps `bindings.executor.Response` but detour some features to Python implementation"""
def deserialize(self):
self._result = tensorrt_llm.bindings.executor.deserialize_result(
self._result)

def __init__(self, response: tensorrt_llm.bindings.executor.Response,
py_result: PyResult):
self._response = response
self._py_result = py_result

def __getstate__(self):
return self._response, self._py_result

def __setstate__(self, state):
self._response, self._py_result = state

@property
def result(self) -> tensorrt_llm.bindings.executor.Result:
return LlmResult(
self._response.result,
self._py_result) # LlmResult masquerades bindings.executor.Result

@property
def _is_llm_response(self) -> bool:
return True
@dataclass
class LlmResponse:
request_id: int
error_msg: Optional[str] = None
result: Optional[LlmResult] = None
client_id: Optional[int] = None

def __getattr__(self, item):
return getattr(self._response, item)
def has_error(self):
return self.error_msg is not None


class LlmRequest(tensorrt_llm.bindings.internal.batch_manager.LlmRequest):
@@ -273,6 +267,7 @@ def __init__(
**kwargs)
self.py_client_id = client_id
self.py_request_id = self.request_id
self.py_llm_request_type = self.llm_request_type
self.py_end_id = self.end_id
self.py_prompt_len = self.prompt_len
self.py_orig_prompt_len = self.orig_prompt_len
@@ -287,6 +282,8 @@ def __init__(
self.is_cuda_graph_dummy = False
self.py_lora_task_layer_module_configs = None

self.py_tokens = super().get_tokens()

self.py_return_log_probs = return_log_probs
self.py_return_context_logits = return_context_logits
self.py_return_generation_logits = return_generation_logits
@@ -302,13 +299,19 @@ def __init__(
return_generation_logits,
exclude_last_generation_logits)

def is_generation_only_request(self):
return self.py_llm_request_type == LlmRequestType.LLMREQUEST_TYPE_GENERATION_ONLY

def create_response(
self,
use_fast_logits=False,
mpi_world_rank=0) -> tensorrt_llm.bindings.executor.Response | None:
response = super().create_response(use_fast_logits, mpi_world_rank)
return LlmResponse(response,
self.py_result) if response is not None else None
result, is_final = super().create_serialized_result(
use_fast_logits, mpi_world_rank)
return LlmResponse(
request_id=self.py_request_id,
result=LlmResult(result, self.py_result, is_final),
client_id=self.py_client_id) if len(result) > 0 else None

@property
def is_dummy(self):
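
Putting the pieces together, the pyexecutor-side LlmRequest now builds a pure-Python LlmResponse whose LlmResult initially wraps the serialized bytes. A rough sketch of the intended flow, assuming `llm_request` is a finished pyexecutor LlmRequest (the intermediate variable names are illustrative):

# Sketch only.
response = llm_request.create_response()   # LlmResponse dataclass, or None if nothing to report
if response is not None:
    result = response.result               # LlmResult holding serialized bytes + PyResult
    print(result.is_final)                 # stored directly on the wrapper, no pybind call
    print(result.log_probs)                # Python-side property, served by PyResult
    result.deserialize()                   # materialize the bindings.executor.Result
    print(result.decoding_iter)            # now forwarded to the deserialized Result
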
2 changes: 1 addition & 1 deletion tensorrt_llm/_torch/pyexecutor/model_engine.py
@@ -1172,7 +1172,7 @@ def _prepare_tp_inputs(
gather_ids.append(len(input_ids) - 1)
sequence_lengths.append(len(prompt_tokens))
prompt_lengths.append(len(prompt_tokens))
past_seen_token_num = request.context_current_position
past_seen_token_num = begin_compute
num_cached_tokens_per_seq.append(past_seen_token_num)
multimodal_embedding = request.multimodal_embedding
if multimodal_embedding is not None:
24 changes: 13 additions & 11 deletions tensorrt_llm/_torch/pyexecutor/py_executor.py
@@ -30,8 +30,8 @@

from ..distributed import Distributed
from .kv_cache_transceiver import KvCacheTransceiver
from .llm_request import (ExecutorRequest, ExecutorResponse, LlmRequest,
LlmRequestState, executor_request_to_llm_request)
from .llm_request import (ExecutorRequest, LlmRequest, LlmRequestState,
LlmResponse, executor_request_to_llm_request)
from .model_engine import ModelEngine
from .sampler import Sampler, SampleState, SampleStateTensors, TorchSampler
from .scheduler import ScheduledRequests
@@ -328,14 +328,14 @@ def await_responses(
self,
id: Optional[Union[List[int], int]] = None,
timeout: Optional[datetime.timedelta] = None,
) -> Union[List[List[ExecutorResponse]], List[ExecutorResponse]]:
) -> Union[List[List[LlmResponse]], List[LlmResponse]]:
"""
Await for ready responses
Args:
id (Optional[Union[List[int], int]]): Request id
timeout (Optional[datetime.timedelta]): The maximum time to wait for new responses
Returns:
Union[List[tensorrt_llm.bindings.executor.Response], List[List[tensorrt_llm.bindings.executor.Response]]]: Responses
Union[List[LlmResponse], List[List[LlmResponse]]]: Responses
"""
timeout = timeout.total_seconds() if timeout is not None else None
if id is None:
@@ -1928,8 +1928,10 @@ def _handle_errors(self, error_msg: Optional[str] = None):
req_id = request.py_request_id
request.state = LlmRequestState.GENERATION_COMPLETE
self._terminate_request(request)
error_responses[req_id] = ExecutorResponse(
req_id, error_msg, client_id=request.py_client_id)
error_responses[req_id] = LlmResponse(
request_id=req_id,
error_msg=error_msg,
client_id=request.py_client_id)
self.active_requests.clear()
self._enqueue_responses(error_responses)

@@ -1973,7 +1975,7 @@ def _handle_cancelled_requests(self):
self._enqueue_responses(cancelled_responses)

@nvtx_range("_enqueue_responses")
def _enqueue_responses(self, responses: Dict[int, ExecutorResponse]):
def _enqueue_responses(self, responses: Dict[int, LlmResponse]):
if 0 not in self.dist.mapping.tp_group and not self.gather_all_responses:
return

@@ -2030,7 +2032,7 @@ def _handle_responses(self):
requests_to_terminate.append(request)
continue

if request.is_generation_only_request:
if request.is_generation_only_request():
# If request is in transmission, so we don't need to emit a response
# Also, for the first iteration with overlap, we should skip since first
# token has already been emitted previously
@@ -2042,7 +2044,7 @@

request.draft_tokens = request.py_draft_tokens
request.decoding_iter = request.py_decoding_iter
response: Response = request.create_response(False, self.dist.rank)
response = request.create_response(False, self.dist.rank)
request_done = False
if response:
request_done = response.result.is_final
@@ -2069,7 +2071,7 @@ def _terminate_ctx_finished_requests(self):

def _await_any_response(self,
timeout: Optional[float] = None
) -> List[ExecutorResponse]:
) -> List[LlmResponse]:

def any_responses_ready():
return len(self.responses) > 0 or self.is_shutdown
@@ -2086,7 +2088,7 @@ def any_responses_ready():
def _await_single_response(
self,
id: int,
timeout: Optional[float] = None) -> List[ExecutorResponse]:
timeout: Optional[float] = None) -> List[LlmResponse]:
with self.response_cv:

def key_has_response():
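
Since error reporting no longer goes through the pybind Response, an error response is now just the dataclass with error_msg set. A small sketch of the two shapes the executor produces (the literal values are made up):

from tensorrt_llm._torch.pyexecutor.llm_request import LlmResponse

# Error path: no result attached, only a message.
error_resp = LlmResponse(request_id=42, error_msg="example error message", client_id=7)
assert error_resp.has_error() and error_resp.result is None

# Success path: create_response() on a finished request returns an LlmResponse
# whose result field is an LlmResult and whose error_msg stays None, so
# has_error() is False and response.result.is_final drives request termination.
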
3 changes: 3 additions & 0 deletions tensorrt_llm/executor/result.py
@@ -294,6 +294,9 @@ def _handle_response(self,
handler(response.error_msg)

response_result = response.result
if hasattr(response_result, "_result"):
response_result.deserialize()

self._done = response_result.is_final
context_phase_params = response_result.context_phase_params
self.decoding_iter = response_result.decoding_iter
5 changes: 3 additions & 2 deletions tensorrt_llm/executor/serialization.py
@@ -2,6 +2,7 @@
# pickle is not secure, but this whole file is a wrapper to make it
# possible to mitigate the primary risk of code injection via pickle.
import pickle # nosec B403
from functools import partial

# This is an example class (white list) to showcase how to guard serialization with approved classes.
# If a class is needed routinely it should be added into the whitelist. If it is only needed in a single instance
@@ -53,8 +54,8 @@ def find_class(self, module, name):
# dump and dumps are just aliases because the security controls are on the deserialization
# side. However they are included here so that in the future if a more secure serialization
# solution is identified, it can be added with less impact to the rest of the application.
dump = pickle.dump # nosec B301
dumps = pickle.dumps # nosec B301
dump = partial(pickle.dump, protocol=pickle.HIGHEST_PROTOCOL) # nosec B301
dumps = partial(pickle.dumps, protocol=pickle.HIGHEST_PROTOCOL) # nosec B301


def load(file,
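
The functools.partial wrappers only pin pickle to its newest protocol; the effect can be seen with the standard library alone (this snippet does not depend on TensorRT-LLM):

import pickle
from functools import partial

dumps = partial(pickle.dumps, protocol=pickle.HIGHEST_PROTOCOL)

payload = dumps({"tokens": list(range(8))})
# Protocol 2+ pickles start with the PROTO opcode (0x80) followed by the protocol
# number; HIGHEST_PROTOCOL is 5 on Python 3.8+, which adds out-of-band buffer
# support (PEP 574) on top of protocol 4's framing.
assert payload[:1] == b"\x80" and payload[1] == pickle.HIGHEST_PROTOCOL
print(pickle.loads(payload))  # nosec B301 - round-trips our own trusted payload
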
4 changes: 1 addition & 3 deletions tensorrt_llm/executor/utils.py
@@ -8,7 +8,6 @@
from strenum import StrEnum

from tensorrt_llm._utils import mpi_rank
from tensorrt_llm.bindings.executor import Response
from tensorrt_llm.llmapi.utils import print_colored_debug

from ..llmapi.mpi_session import (MpiCommSession, MpiPoolSession, MpiSession,
@@ -144,5 +143,4 @@ class WorkerCommIpcAddrs(NamedTuple):


def is_llm_response(instance):
return isinstance(instance, Response) or \
(hasattr(instance, '_is_llm_response') and instance._is_llm_response)
return hasattr(instance, "result")
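
is_llm_response now duck-types on a result attribute instead of importing the pybind Response class; the check below mirrors the new one-liner with stand-in classes (not real TensorRT-LLM types):

from dataclasses import dataclass

def is_llm_response(instance):          # same one-liner as the new utils.py version
    return hasattr(instance, "result")

@dataclass
class _FakeResponse:                    # stand-in for the pure-Python LlmResponse
    request_id: int
    result: object = None

assert is_llm_response(_FakeResponse(request_id=1))
assert not is_llm_response("definitely not a response")
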
2 changes: 1 addition & 1 deletion tests/integration/defs/test_e2e.py
@@ -677,7 +677,7 @@ def temp_extra_llm_api_options_file(request):
"enable_block_reuse": False,
"max_tokens": 40000
},
"_num_postprocess_workers": 2,
"num_postprocess_workers": 2,
}

pytorch_backend_config = {}