diff --git a/cpp/include/tensorrt_llm/batch_manager/llmRequest.h b/cpp/include/tensorrt_llm/batch_manager/llmRequest.h
index 086dc2bf4a5..d71b6e89f6a 100644
--- a/cpp/include/tensorrt_llm/batch_manager/llmRequest.h
+++ b/cpp/include/tensorrt_llm/batch_manager/llmRequest.h
@@ -2328,6 +2328,11 @@ class LlmRequest : public GenericLlmRequest<runtime::ITensor::SharedPtr>
     /// @return An optional Response
     std::optional<executor::Response> createResponse(bool useFastLogits = false, int32_t mpiWorldRank = 0);
 
+    std::optional<executor::Result> createResult(bool useFastLogits = false, int32_t mpiWorldRank = 0);
+
+    void createSerializedResult(
+        std::vector<char>& serializedResult, bool& isFinal, bool useFastLogits = false, int32_t mpiWorldRank = 0);
+
     void validate(SizeType32 maxInputLen, SizeType32 maxSequenceLen, SizeType32 maxDraftLen, SizeType32 vocabSizePadded,
         std::optional<SizeType32> maxEncoderInputLen = std::nullopt, bool enableKVCacheReuse = false);
 
diff --git a/cpp/tensorrt_llm/batch_manager/llmRequest.cpp b/cpp/tensorrt_llm/batch_manager/llmRequest.cpp
index 6fc7051ad7e..433f349b07d 100644
--- a/cpp/tensorrt_llm/batch_manager/llmRequest.cpp
+++ b/cpp/tensorrt_llm/batch_manager/llmRequest.cpp
@@ -16,6 +16,7 @@
  */
 
 #include "tensorrt_llm/batch_manager/llmRequest.h"
+#include "tensorrt_llm/executor/serializeUtils.h"
 #include "tensorrt_llm/kernels/beamSearchKernels.h"
 
 namespace tensorrt_llm::batch_manager
@@ -39,8 +40,34 @@ runtime::SizeType32 GenericLlmRequest<TTensor, TStream>::getBeamWidthByIter(bool
 
 template class GenericLlmRequest<runtime::ITensor::SharedPtr>;
 
-/// Note that there is some dependency on the order of operations in this method. Modify with care!
 std::optional<executor::Response> LlmRequest::createResponse(bool useFastLogits, int32_t mpiWorldRank)
+{
+    auto requestId = isChild() ? mParentRequestId : mRequestId;
+    auto result = createResult(useFastLogits, mpiWorldRank);
+    if (result.has_value())
+    {
+        return executor::Response(requestId, result.value(), mClientId);
+    }
+    return std::nullopt;
+}
+
+void LlmRequest::createSerializedResult(
+    std::vector<char>& serializedResult, bool& isFinal, bool useFastLogits, int32_t mpiWorldRank)
+{
+    auto result = createResult(useFastLogits, mpiWorldRank);
+    if (result.has_value())
+    {
+        std::ostringstream oStream;
+        executor::serialize_utils::serialize(result.value(), oStream);
+        auto str = oStream.str();
+        serializedResult.resize(str.size());
+        std::copy(str.begin(), str.end(), serializedResult.begin());
+        isFinal = result.value().isFinal;
+    }
+}
+
+/// Note that there is some dependency on the order of operations in this method. Modify with care!
+std::optional<executor::Result> LlmRequest::createResult(bool useFastLogits, int32_t mpiWorldRank)
 {
     TLLM_CHECK(!isDisaggContextCompleteState());
     if (!(isFinished() || (mIsStreaming && mState == LlmRequestState::kGENERATION_IN_PROGRESS)))
@@ -192,11 +219,7 @@ std::optional<executor::Response> LlmRequest::createResponse(bool useFastLogits,
 
     // Update position of last sent response
     setMaxSentTokenLen(maxNbTokens);
-
-    auto requestId = isChild() ? mParentRequestId : mRequestId;
-    auto response = executor::Response(requestId, std::move(result), mClientId);
-
-    return response;
+    return result;
 }
 
 void LlmRequest::validate(SizeType32 maxInputLen, SizeType32 maxSequenceLen, SizeType32 maxDraftLen,
diff --git a/cpp/tensorrt_llm/kernels/internal_cutlass_kernels/aarch64-linux-gnu/version.txt b/cpp/tensorrt_llm/kernels/internal_cutlass_kernels/aarch64-linux-gnu/version.txt
index 5c9c419e0fd..240af5e5cfa 100644
--- a/cpp/tensorrt_llm/kernels/internal_cutlass_kernels/aarch64-linux-gnu/version.txt
+++ b/cpp/tensorrt_llm/kernels/internal_cutlass_kernels/aarch64-linux-gnu/version.txt
@@ -1,2 +1,2 @@
 a1180829a0d8fe772ff37934b72573bb41671e7ed76dfa3bd5cd449348b9683a libtensorrt_llm_internal_cutlass_kernels_static.a
-commit 98a790a71a0734881180e434b8c4271ae0f21f34
+commit c767347ff934578193ee4bad58ba3b9398046245
diff --git a/cpp/tensorrt_llm/kernels/internal_cutlass_kernels/x86_64-linux-gnu/version.txt b/cpp/tensorrt_llm/kernels/internal_cutlass_kernels/x86_64-linux-gnu/version.txt
index 33171e56a12..d48c4297480 100644
--- a/cpp/tensorrt_llm/kernels/internal_cutlass_kernels/x86_64-linux-gnu/version.txt
+++ b/cpp/tensorrt_llm/kernels/internal_cutlass_kernels/x86_64-linux-gnu/version.txt
@@ -1,2 +1,2 @@
 e7130e36217c1df0d281788fc87764945d9c308bef11ad61b3b1a49c7d41c8af libtensorrt_llm_internal_cutlass_kernels_static.a
-commit 98a790a71a0734881180e434b8c4271ae0f21f34
+commit c767347ff934578193ee4bad58ba3b9398046245
diff --git a/cpp/tensorrt_llm/pybind/batch_manager/bindings.cpp b/cpp/tensorrt_llm/pybind/batch_manager/bindings.cpp
index 35f32a3b128..a6fd8b8e9d1 100644
--- a/cpp/tensorrt_llm/pybind/batch_manager/bindings.cpp
+++ b/cpp/tensorrt_llm/pybind/batch_manager/bindings.cpp
@@ -36,6 +36,7 @@
 #include
 #include
 #include
+#include
 
 namespace py = pybind11;
 namespace tb = tensorrt_llm::batch_manager;
@@ -360,6 +361,16 @@ void initBindings(pybind11::module_& m)
             py::arg("enable_kv_cache_reuse") = false)
         .def("create_response", &tb::LlmRequest::createResponse, py::arg("use_fast_logits") = false,
             py::arg("mpi_world_rank") = 0)
+        .def("create_result", &tb::LlmRequest::createResult, py::arg("use_fast_logits") = false,
+            py::arg("mpi_world_rank") = 0)
+        .def("create_serialized_result",
+            [](tb::LlmRequest& self, bool use_fast_logits = false, int mpi_world_rank = 0)
+            {
+                std::vector<char> serialized_result;
+                bool is_final = false;
+                self.createSerializedResult(serialized_result, is_final, use_fast_logits, mpi_world_rank);
+                return std::make_tuple(py::bytes(serialized_result.data(), serialized_result.size()), is_final);
+            })
         .def("move_prompt_embedding_table_to_gpu", &tb::LlmRequest::movePromptEmbeddingTableToGpu, py::arg("manager"))
         .def("move_lora_weights_to_gpu", &tb::LlmRequest::moveLoraWeightsToGpu, py::arg("manager"))
         .def("finish_by_reason", &tb::LlmRequest::finishByReason, py::arg("finish_reason"));
diff --git a/cpp/tensorrt_llm/pybind/executor/request.cpp b/cpp/tensorrt_llm/pybind/executor/request.cpp
index 85fb83ce272..e43ad80ce14 100644
--- a/cpp/tensorrt_llm/pybind/executor/request.cpp
+++ b/cpp/tensorrt_llm/pybind/executor/request.cpp
@@ -19,6 +19,7 @@
 #include "tensorrt_llm/common/assert.h"
 #include "tensorrt_llm/common/logger.h"
 #include "tensorrt_llm/executor/executor.h"
+#include "tensorrt_llm/executor/serializeUtils.h"
 #include "tensorrt_llm/executor/tensor.h"
 #include "tensorrt_llm/executor/types.h"
 #include "tensorrt_llm/runtime/cudaStream.h"
@@ -29,6 +30,7 @@
 #include
 #include
 #include
+#include
 #include
 #include
 
@@ -775,6 +777,13 @@ void
initRequestBindings(pybind11::module_& m) .def_readwrite("context_phase_params", &tle::Result::contextPhaseParams) .def(py::pickle(resultGetstate, resultSetstate)); + m.def("deserialize_result", + [](std::string& x) + { + std::istringstream is(x); + return tle::serialize_utils::deserialize(is); + }); + auto responseGetstate = [](tle::Response const& self) { return py::make_tuple(self.getRequestId(), self.getResult(), self.getClientId()); }; diff --git a/tensorrt_llm/_torch/pyexecutor/llm_request.py b/tensorrt_llm/_torch/pyexecutor/llm_request.py index 878b4abd216..e792be72ede 100644 --- a/tensorrt_llm/_torch/pyexecutor/llm_request.py +++ b/tensorrt_llm/_torch/pyexecutor/llm_request.py @@ -1,4 +1,5 @@ -from typing import List, Optional +from dataclasses import dataclass +from typing import List, Optional, Union import torch @@ -205,43 +206,36 @@ class LlmResult: py_result_properties = frozenset( ('context_logits', 'generation_logits', 'log_probs', 'cum_log_probs')) - def __init__(self, result: tensorrt_llm.bindings.executor.Result, - py_result: PyResult): + def __init__(self, + result: Union[bytes, tensorrt_llm.bindings.executor.Result], + py_result: PyResult, + is_final: bool = False): self._result = result self._py_result = py_result + self.is_final = is_final def __getattr__(self, item): if item in self.py_result_properties: return getattr(self._py_result, item) - return getattr(self._result, item) - + if item == 'is_final': + return object.__getattribute__(self, 'is_final') + result = object.__getattribute__(self, '_result') + return getattr(result, item) -class LlmResponse: - """LlmResponse wraps `bindings.executor.Response` but detour some features to Python implementation""" + def deserialize(self): + self._result = tensorrt_llm.bindings.executor.deserialize_result( + self._result) - def __init__(self, response: tensorrt_llm.bindings.executor.Response, - py_result: PyResult): - self._response = response - self._py_result = py_result - def __getstate__(self): - return self._response, self._py_result - - def __setstate__(self, state): - self._response, self._py_result = state - - @property - def result(self) -> tensorrt_llm.bindings.executor.Result: - return LlmResult( - self._response.result, - self._py_result) # LlmResult masquerades bindings.executor.Result - - @property - def _is_llm_response(self) -> bool: - return True +@dataclass +class LlmResponse: + request_id: int + error_msg: Optional[str] = None + result: Optional[LlmResult] = None + client_id: Optional[int] = None - def __getattr__(self, item): - return getattr(self._response, item) + def has_error(self): + return self.error_msg is not None class LlmRequest(tensorrt_llm.bindings.internal.batch_manager.LlmRequest): @@ -273,6 +267,7 @@ def __init__( **kwargs) self.py_client_id = client_id self.py_request_id = self.request_id + self.py_llm_request_type = self.llm_request_type self.py_end_id = self.end_id self.py_prompt_len = self.prompt_len self.py_orig_prompt_len = self.orig_prompt_len @@ -287,6 +282,8 @@ def __init__( self.is_cuda_graph_dummy = False self.py_lora_task_layer_module_configs = None + self.py_tokens = super().get_tokens() + self.py_return_log_probs = return_log_probs self.py_return_context_logits = return_context_logits self.py_return_generation_logits = return_generation_logits @@ -302,13 +299,19 @@ def __init__( return_generation_logits, exclude_last_generation_logits) + def is_generation_only_request(self): + return self.py_llm_request_type == LlmRequestType.LLMREQUEST_TYPE_GENERATION_ONLY + def 
create_response( self, use_fast_logits=False, mpi_world_rank=0) -> tensorrt_llm.bindings.executor.Response | None: - response = super().create_response(use_fast_logits, mpi_world_rank) - return LlmResponse(response, - self.py_result) if response is not None else None + result, is_final = super().create_serialized_result( + use_fast_logits, mpi_world_rank) + return LlmResponse( + request_id=self.py_request_id, + result=LlmResult(result, self.py_result, is_final), + client_id=self.py_client_id) if len(result) > 0 else None @property def is_dummy(self): diff --git a/tensorrt_llm/_torch/pyexecutor/model_engine.py b/tensorrt_llm/_torch/pyexecutor/model_engine.py index e2c1477c5ea..e35dc1a189c 100644 --- a/tensorrt_llm/_torch/pyexecutor/model_engine.py +++ b/tensorrt_llm/_torch/pyexecutor/model_engine.py @@ -1172,7 +1172,7 @@ def _prepare_tp_inputs( gather_ids.append(len(input_ids) - 1) sequence_lengths.append(len(prompt_tokens)) prompt_lengths.append(len(prompt_tokens)) - past_seen_token_num = request.context_current_position + past_seen_token_num = begin_compute num_cached_tokens_per_seq.append(past_seen_token_num) multimodal_embedding = request.multimodal_embedding if multimodal_embedding is not None: diff --git a/tensorrt_llm/_torch/pyexecutor/py_executor.py b/tensorrt_llm/_torch/pyexecutor/py_executor.py index 6456b874551..e6431f52143 100644 --- a/tensorrt_llm/_torch/pyexecutor/py_executor.py +++ b/tensorrt_llm/_torch/pyexecutor/py_executor.py @@ -30,8 +30,8 @@ from ..distributed import Distributed from .kv_cache_transceiver import KvCacheTransceiver -from .llm_request import (ExecutorRequest, ExecutorResponse, LlmRequest, - LlmRequestState, executor_request_to_llm_request) +from .llm_request import (ExecutorRequest, LlmRequest, LlmRequestState, + LlmResponse, executor_request_to_llm_request) from .model_engine import ModelEngine from .sampler import Sampler, SampleState, SampleStateTensors, TorchSampler from .scheduler import ScheduledRequests @@ -328,14 +328,14 @@ def await_responses( self, id: Optional[Union[List[int], int]] = None, timeout: Optional[datetime.timedelta] = None, - ) -> Union[List[List[ExecutorResponse]], List[ExecutorResponse]]: + ) -> Union[List[List[LlmResponse]], List[LlmResponse]]: """ Await for ready responses Args: id (Optional[Union[List[int], int]]): Request id timeout (Optional[datetime.timedelta]): The maximum time to wait for new responses Returns: - Union[List[tensorrt_llm.bindings.executor.Response], List[List[tensorrt_llm.bindings.executor.Response]]]: Responses + Union[List[LlmResponse], List[List[LlmResponse]]]: Responses """ timeout = timeout.total_seconds() if timeout is not None else None if id is None: @@ -1928,8 +1928,10 @@ def _handle_errors(self, error_msg: Optional[str] = None): req_id = request.py_request_id request.state = LlmRequestState.GENERATION_COMPLETE self._terminate_request(request) - error_responses[req_id] = ExecutorResponse( - req_id, error_msg, client_id=request.py_client_id) + error_responses[req_id] = LlmResponse( + request_id=req_id, + error_msg=error_msg, + client_id=request.py_client_id) self.active_requests.clear() self._enqueue_responses(error_responses) @@ -1973,7 +1975,7 @@ def _handle_cancelled_requests(self): self._enqueue_responses(cancelled_responses) @nvtx_range("_enqueue_responses") - def _enqueue_responses(self, responses: Dict[int, ExecutorResponse]): + def _enqueue_responses(self, responses: Dict[int, LlmResponse]): if 0 not in self.dist.mapping.tp_group and not self.gather_all_responses: return @@ -2030,7 +2032,7 @@ 
def _handle_responses(self): requests_to_terminate.append(request) continue - if request.is_generation_only_request: + if request.is_generation_only_request(): # If request is in transmission, so we don't need to emit a response # Also, for the first iteration with overlap, we should skip since first # token has already been emitted previously @@ -2042,7 +2044,7 @@ def _handle_responses(self): request.draft_tokens = request.py_draft_tokens request.decoding_iter = request.py_decoding_iter - response: Response = request.create_response(False, self.dist.rank) + response = request.create_response(False, self.dist.rank) request_done = False if response: request_done = response.result.is_final @@ -2069,7 +2071,7 @@ def _terminate_ctx_finished_requests(self): def _await_any_response(self, timeout: Optional[float] = None - ) -> List[ExecutorResponse]: + ) -> List[LlmResponse]: def any_responses_ready(): return len(self.responses) > 0 or self.is_shutdown @@ -2086,7 +2088,7 @@ def any_responses_ready(): def _await_single_response( self, id: int, - timeout: Optional[float] = None) -> List[ExecutorResponse]: + timeout: Optional[float] = None) -> List[LlmResponse]: with self.response_cv: def key_has_response(): diff --git a/tensorrt_llm/executor/result.py b/tensorrt_llm/executor/result.py index abd72cb66b5..2716f80d167 100644 --- a/tensorrt_llm/executor/result.py +++ b/tensorrt_llm/executor/result.py @@ -294,6 +294,9 @@ def _handle_response(self, handler(response.error_msg) response_result = response.result + if hasattr(response_result, "_result"): + response_result.deserialize() + self._done = response_result.is_final context_phase_params = response_result.context_phase_params self.decoding_iter = response_result.decoding_iter diff --git a/tensorrt_llm/executor/serialization.py b/tensorrt_llm/executor/serialization.py index 3f6cf4a026c..b295e9f0b07 100644 --- a/tensorrt_llm/executor/serialization.py +++ b/tensorrt_llm/executor/serialization.py @@ -2,6 +2,7 @@ # pickle is not secure, but but this whole file is a wrapper to make it # possible to mitigate the primary risk of code injection via pickle. import pickle # nosec B403 +from functools import partial # This is an example class (white list) to showcase how to guard serialization with approved classes. # If a class is needed routinely it should be added into the whitelist. If it is only needed in a single instance @@ -53,8 +54,8 @@ def find_class(self, module, name): # dump and dumps are just aliases because the serucity controls are on the deserialization # side. However they are included here so that in the future if a more secure serialization # soliton is identified, it can be added with less impact to the rest of the application. 
-dump = pickle.dump # nosec B301 -dumps = pickle.dumps # nosec B301 +dump = partial(pickle.dump, protocol=pickle.HIGHEST_PROTOCOL) # nosec B301 +dumps = partial(pickle.dumps, protocol=pickle.HIGHEST_PROTOCOL) # nosec B301 def load(file, diff --git a/tensorrt_llm/executor/utils.py b/tensorrt_llm/executor/utils.py index 238a6fddf66..fd4cd8444ec 100644 --- a/tensorrt_llm/executor/utils.py +++ b/tensorrt_llm/executor/utils.py @@ -8,7 +8,6 @@ from strenum import StrEnum from tensorrt_llm._utils import mpi_rank -from tensorrt_llm.bindings.executor import Response from tensorrt_llm.llmapi.utils import print_colored_debug from ..llmapi.mpi_session import (MpiCommSession, MpiPoolSession, MpiSession, @@ -144,5 +143,4 @@ class WorkerCommIpcAddrs(NamedTuple): def is_llm_response(instance): - return isinstance(instance, Response) or \ - (hasattr(instance, '_is_llm_response') and instance._is_llm_response) + return hasattr(instance, "result") diff --git a/tests/integration/defs/test_e2e.py b/tests/integration/defs/test_e2e.py index 94a45937e5e..7b8e122c120 100644 --- a/tests/integration/defs/test_e2e.py +++ b/tests/integration/defs/test_e2e.py @@ -677,7 +677,7 @@ def temp_extra_llm_api_options_file(request): "enable_block_reuse": False, "max_tokens": 40000 }, - "_num_postprocess_workers": 2, + "num_postprocess_workers": 2, } pytorch_backend_config = {} diff --git a/tests/integration/test_lists/qa/trt_llm_release_perf_cluster_test.yml b/tests/integration/test_lists/qa/trt_llm_release_perf_cluster_test.yml index 367cb38fb56..6e2f38f68b9 100644 --- a/tests/integration/test_lists/qa/trt_llm_release_perf_cluster_test.yml +++ b/tests/integration/test_lists/qa/trt_llm_release_perf_cluster_test.yml @@ -10,6 +10,9 @@ trt_llm_release_perf_cluster_test: - perf/test_perf.py::test_perf[llama_v3.1_8b-bench-bfloat16-maxbs:256-input_output_len:128,128-quant:fp8] - perf/test_perf.py::test_perf[llama_v3.1_8b-bench-bfloat16-maxbs:256-input_output_len:512,32-quant:fp8] - perf/test_perf.py::test_perf[llama_v3_8b_instruct_fp8-bench-pytorch-float8-input_output_len:128,128] + - perf/test_perf.py::test_perf[llama_v3_8b_instruct_fp8-bench-pytorch-float8-input_output_len:2000,500] + - perf/test_perf.py::test_perf[llama_v3_8b_instruct_fp8-bench-pytorch-streaming-float8-input_output_len:2000,500] + - perf/test_perf.py::test_perf[llama_v3_8b_instruct_fp8-bench-pytorch-float8-input_output_len:500,2000] - perf/test_perf.py::test_perf[t5-bench-float16-input_output_len:128,20] - perf/test_perf.py::test_perf[mistral_7b_v0.1-bench-float16-input_output_len:1000,1000-quant:fp8] - perf/test_perf.py::test_perf[mistral_7b_v0.1-bench-float16-input_output_len:500,2000-quant:fp8] diff --git a/tests/integration/test_lists/qa/trt_llm_release_perf_sanity_test.yml b/tests/integration/test_lists/qa/trt_llm_release_perf_sanity_test.yml index 1e7c5f75d26..5e1766b5eb0 100644 --- a/tests/integration/test_lists/qa/trt_llm_release_perf_sanity_test.yml +++ b/tests/integration/test_lists/qa/trt_llm_release_perf_sanity_test.yml @@ -59,6 +59,10 @@ trt_llm_release_perf_sanity_test: - perf/test_perf.py::test_perf[llama_v3.1_8b_instruct-cppmanager-exe-static_batching-plugin_ifb-float16-bs:8+64-input_output_len:128,128+512,32] - perf/test_perf.py::test_perf[llama_v3.1_8b_instruct-cppmanager-exe-plugin_ifb-bfloat16-gwp:0.0-input_output_len:128,128+512,32] - perf/test_perf.py::test_perf[llama_v3.1_8b_instruct-bench-bfloat16-input_output_len:512,32] + - perf/test_perf.py::test_perf[llama_v3.1_8b-bench-pytorch-bfloat16-input_output_len:128,128] + - 
perf/test_perf.py::test_perf[llama_v3.1_8b-bench-pytorch-bfloat16-input_output_len:512,32] + - perf/test_perf.py::test_perf[llama_v3.1_8b-bench-pytorch-streaming-bfloat16-input_output_len:128,128] + - perf/test_perf.py::test_perf[llama_v3.1_8b-bench-pytorch-streaming-bfloat16-input_output_len:512,32] - perf/test_perf.py::test_perf[qwen2_7b_instruct-bench-float16-input_output_len:128,128] # FP8 specific tests @@ -75,8 +79,8 @@ trt_llm_release_perf_sanity_test: tests: - perf/test_perf.py::test_perf[llama_v3.1_8b_instruct-bench-bfloat16-input_output_len:128,128-quant:fp8] - perf/test_perf.py::test_perf[llama_v3.1_8b_instruct-bench-bfloat16-input_output_len:512,32-quant:fp8] - - perf/test_perf.py::test_perf[llama_v3.1_8b_instruct-bench-pytorch-bfloat16-input_output_len:128,128-quant:fp8] - - perf/test_perf.py::test_perf[llama_v3.1_8b_instruct-bench-pytorch-bfloat16-input_output_len:512,32-quant:fp8] + - perf/test_perf.py::test_perf[llama_v3.1_8b_instruct_fp8-bench-pytorch-float8-input_output_len:128,128] + - perf/test_perf.py::test_perf[llama_v3.1_8b_instruct_fp8-bench-pytorch-float8-input_output_len:512,32] # Tests for systems with 2+ GPUs - condition: @@ -98,6 +102,7 @@ trt_llm_release_perf_sanity_test: - perf/test_perf.py::test_perf[llama_v3.1_8b-bench-bfloat16-maxbs:256-input_output_len:128,128-gpu:2] - perf/test_perf.py::test_perf[llama_v3.1_8b_instruct-bench-streaming-bfloat16-input_output_len:128,128-gpus:2] - perf/test_perf.py::test_perf[llama_v3.2_1b-bench-bfloat16-maxbs:1-input_output_len:128,128-reqs:10-gpus:2] + - perf/test_perf.py::test_perf[llama_v3.2_1b-bench-pytorch-bfloat16-input_output_len:128,128-gpus:2] # FP8 tests for systems with 2+ GPUs - condition: @@ -118,6 +123,7 @@ trt_llm_release_perf_sanity_test: - perf/test_perf.py::test_perf[llama_v3.1_8b-bench-bfloat16-input_output_len:128,128-quant:fp8-gpus:2] - perf/test_perf.py::test_perf[llama_v3.2_1b-bench-bfloat16-input_output_len:128,128-quant:fp8-gpus:2] - perf/test_perf.py::test_perf[mixtral_8x7b_v0.1-bench-float16-input_output_len:128,128-quant:fp8-gpus:2] + - perf/test_perf.py::test_perf[mixtral_8x7b_v0.1-bench-pytorch-float16-input_output_len:128,128-quant:fp8-gpus:2] # Tests for systems with 2+ GPUs and high memory - condition: @@ -151,6 +157,7 @@ trt_llm_release_perf_sanity_test: - '*h20*' tests: - perf/test_perf.py::test_perf[llama_v3.1_70b-bench-bfloat16-maxbs:1-input_output_len:128,128-reqs:10-gpus:4] + - perf/test_perf.py::test_perf[llama_v3.1_70b-bench-pytorch-bfloat16-maxbs:1-input_output_len:128,128-reqs:10-gpus:4] - perf/test_perf.py::test_perf[qwen_14b_chat-cppmanager-ootb_except_mha-float16-input_output_len:128,128-gpus:4] - perf/test_perf.py::test_perf[starcoder_15.5b-cppmanager-exe-plugin_ifb-float16-maxbs:1-input_output_len:512,200-reqs:10-gpus:4] @@ -174,6 +181,7 @@ trt_llm_release_perf_sanity_test: - perf/test_perf.py::test_perf[llama_v3.1_70b-bench-bfloat16-maxbs:1-input_output_len:200,2000-reqs:10-gpus:8] - perf/test_perf.py::test_perf[llama_v3.1_70b-bench-pytorch-bfloat16-maxbs:1-input_output_len:200,2000-reqs:10-gpus:8] - perf/test_perf.py::test_perf[gpt_20b-bench-float16-maxbs:1-input_output_len:128,128-reqs:10-gpus:8] + - perf/test_perf.py::test_perf[llama_v3.1_70b-bench-pytorch-bfloat16-maxbs:1-input_output_len:2000,200-reqs:10-gpus:8] # FP8 tests for systems with 8+ GPUs @@ -194,3 +202,19 @@ trt_llm_release_perf_sanity_test: - perf/test_perf.py::test_perf[llama_v3.1_70b-bench-bfloat16-maxbs:1-input_output_len:128,128-quant:fp8-gpus:8] - 
perf/test_perf.py::test_perf[llama_v3.1_70b-bench-bfloat16-maxbs:1-input_output_len:512,32-quant:fp8-gpus:8] - perf/test_perf.py::test_perf[llama_v3.3_70b_instruct_fp8-bench-pytorch-float8-input_output_len:128,128-gpus:8] + +- condition: + terms: + supports_fp8: true + ranges: + system_gpu_count: + gte: 8 + wildcards: + gpu: + - '*h100*' + - '*h200*' + - '*h20*' + + tests: + - perf/test_perf.py::test_perf[deepseek_v3_lite_fp8-bench-pytorch-float8-input_output_len:128,128] + - perf/test_perf.py::test_perf[deepseek_v3_lite_fp8-bench-streaming-pytorch-float8-input_output_len:128,128] diff --git a/tests/integration/test_lists/qa/trt_llm_release_perf_test.yml b/tests/integration/test_lists/qa/trt_llm_release_perf_test.yml index ce6734cfe3b..0b4f94ff591 100644 --- a/tests/integration/test_lists/qa/trt_llm_release_perf_test.yml +++ b/tests/integration/test_lists/qa/trt_llm_release_perf_test.yml @@ -30,6 +30,8 @@ trt_llm_release_perf_test: - perf/test_perf.py::test_perf[llama_v3.1_8b_instruct-bench-bfloat16-input_output_len:512,32] - perf/test_perf.py::test_perf[llama_v3.1_8b-bench-bfloat16-input_output_len:128,128] - perf/test_perf.py::test_perf[llama_v3.1_8b-bench-pytorch-bfloat16-input_output_len:128,128] + - perf/test_perf.py::test_perf[llama_v3.1_8b-bench-pytorch-bfloat16-input_output_len:512,32] + - perf/test_perf.py::test_perf[llama_v3.1_8b-bench-pytorch-streaming-bfloat16-input_output_len:128,128] - perf/test_perf.py::test_perf[llama_v3.1_8b-bench-bfloat16-input_output_len:512,32] - perf/test_perf.py::test_perf[qwen2_7b_instruct-bench-float16-input_output_len:128,128] - perf/test_perf.py::test_perf[starcoder2_3b-bench-pytorch-float16-input_output_len:512,200] @@ -112,7 +114,10 @@ trt_llm_release_perf_test: - perf/test_perf.py::test_perf[llama_v3.1_nemotron_nano_8b-bench-pytorch-bfloat16-maxbs:512-input_output_len:1000,1000-reqs:500-con:250] - perf/test_perf.py::test_perf[llama_v3.1_nemotron_nano_8b-bench-pytorch-bfloat16-maxbs:512-input_output_len:20000,2000-reqs:500-con:250] #need to extend context token to 20000 for l40s, timeout for h20, a100 # deepseek_v3_lite_fp8 - - perf/test_perf.py::test_perf[deepseek_v3_lite_fp8-bench-pytorch-float8-input_output_len:128,128] # not supported on L20, L40S + - perf/test_perf.py::test_perf[deepseek_v3_lite_fp8-bench-pytorch-float8-input_output_len:128,128] + - perf/test_perf.py::test_perf[deepseek_v3_lite_fp8-bench-pytorch-float8-input_output_len:2000,500] + - perf/test_perf.py::test_perf[deepseek_v3_lite_fp8-bench-pytorch-streaming-float8-input_output_len:2000,500] + - perf/test_perf.py::test_perf[deepseek_v3_lite_fp8-bench-pytorch-float8-input_output_len:500,2000] # FP8 specific tests - condition: @@ -192,12 +197,12 @@ trt_llm_release_perf_test: - '*a100*' - '*h20*' tests: - - perf/test_perf.py::test_perf[mixtral_8x7b_v0.1-bench-float16-input_output_len:128,128-tp:2-gpus:2] - perf/test_perf.py::test_perf[mixtral_8x7b_v0.1-bench-float16-input_output_len:128,128-loras:8-tp:2-gpus:2] - perf/test_perf.py::test_perf[llama_v3.1_70b-bench-bfloat16-input_output_len:1024,1024-tp:2-gpus:2] - perf/test_perf.py::test_perf[llama_70b_sq_per_tensor-cppmanager-exe-plugin_ifb-float16-input_output_len:128,128+512,32-gpus:2] - perf/test_perf.py::test_perf[mixtral_8x7b-cppmanager-exe-plugin_ifb-float16-mp-input_output_len:128,128+512,32-gpus:2] - perf/test_perf.py::test_perf[mixtral_8x7b_v0.1-bench-float16-input_output_len:128,128-gpus:2] + - perf/test_perf.py::test_perf[mixtral_8x7b_v0.1-bench-pytorch-float16-input_output_len:128,128-gpus:2] - 
perf/test_perf.py::test_perf[mixtral_8x7b_v0.1-bench-streaming-float16-input_output_len:128,128-gpus:2] # FP8 specific tests @@ -218,10 +223,13 @@ trt_llm_release_perf_test: - perf/test_perf.py::test_perf[llama_v3.2_1b-bench-bfloat16-input_output_len:128,128-quant:fp8-gpus:2] - perf/test_perf.py::test_perf[llama_v3.2_1b-bench-bfloat16-input_output_len:512,32-quant:fp8-gpus:2] - perf/test_perf.py::test_perf[llama_v3.2_1b-bench-bfloat16-input_output_len:512,200-quant:fp8-gpus:2] + - perf/test_perf.py::test_perf[llama_v3.2_1b-bench-pytorch-bfloat16-input_output_len:512,200-quant:fp8-gpus:2] - perf/test_perf.py::test_perf[mixtral_8x7b_v0.1-bench-float16-input_output_len:128,128-quant:fp8-gpus:2] - perf/test_perf.py::test_perf[mixtral_8x7b_v0.1-bench-float16-input_output_len:512,32-quant:fp8-gpus:2] - perf/test_perf.py::test_perf[mistral_7b_v0.1-bench-float16-input_output_len:1000,1000-quant:fp8-tp:2] + - perf/test_perf.py::test_perf[mistral_7b_v0.1-bench-pytorch-float16-input_output_len:1000,1000-quant:fp8-tp:2] - perf/test_perf.py::test_perf[mistral_7b_v0.1-bench-float16-input_output_len:500,2000-quant:fp8-tp:2] + - perf/test_perf.py::test_perf[mistral_7b_v0.1-bench-pytorch-float16-input_output_len:500,2000-quant:fp8-tp:2] - perf/test_perf.py::test_perf[phi_3_mini_128k_instruct-bench-float16-maxbs:128-input_output_len:1000,1000-quant:fp8-tp:2] - perf/test_perf.py::test_perf[phi_3_mini_128k_instruct-bench-float16-maxbs:128-input_output_len:500,2000-quant:fp8-tp:2] @@ -327,7 +335,8 @@ trt_llm_release_perf_test: - perf/test_perf.py::test_perf[llama_v3.1_70b_instruct-bench-pytorch-bfloat16-input_output_len:200,2000-reqs:64-con:200-gpus:8] - perf/test_perf.py::test_perf[llama_v3.1_70b_instruct-bench-bfloat16-input_output_len:200,2000-reqs:8-con:1-gpus:8] # timeout for h20, move to l2 test - perf/test_perf.py::test_perf[llama_v3.1_70b_instruct-bench-bfloat16-input_output_len:2000,200-reqs:64-gpus:8] - - perf/test_perf.py::test_perf[llama_v3.1_70b_instruct-bench-streaming-bfloat16-input_output_len:2000,200-reqs:64-gpus:8] + - perf/test_perf.py::test_perf[llama_v3.1_70b_instruct-bench-pytorch-bfloat16-input_output_len:2000,200-reqs:64-gpus:8] + - perf/test_perf.py::test_perf[llama_v3.1_70b_instruct-bench-pytorch-streaming-bfloat16-input_output_len:2000,200-reqs:64-gpus:8] - perf/test_perf.py::test_perf[llama_v3.3_70b_instruct-bench-bfloat16-input_output_len:128,128-gpus:8] - perf/test_perf.py::test_perf[llama_v3.3_70b_instruct-bench-bfloat16-maxbs:16-input_output_len:5000,500-reqs:64-con:250-gpus:8] - perf/test_perf.py::test_perf[gpt_20b-bench-float16-maxbs:8-input_output_len:128,128-reqs:80-gpus:8] @@ -375,6 +384,8 @@ trt_llm_release_perf_test: tests: - perf/test_perf.py::test_perf[deepseek_r1_fp8-bench-pytorch-float8-maxbs:32-input_output_len:128,128-ep:8-tp:8-gpus:8] - perf/test_perf.py::test_perf[deepseek_r1_fp8-bench-pytorch-streaming-float8-maxbs:32-input_output_len:128,128-ep:8-tp:8-gpus:8] + - perf/test_perf.py::test_perf[deepseek_r1_fp8-bench-pytorch-float8-maxbs:1-input_output_len:1000,2000-reqs:10-con:1-ep:4-tp:8-gpus:8] TIMEOUT(40)#min latency test + - perf/test_perf.py::test_perf[deepseek_r1_fp8-bench-pytorch-float8-maxbs:128-maxnt:1127-input_output_len:1000,2000-reqs:5120-con:1024-ep:8-tp:8-gpus:8] TIMEOUT(80) #max throughput test - perf/test_perf.py::test_perf[llama_v4_scout_17b_16e_instruct-bench-pytorch-bfloat16-input_output_len:128,128-ep:8-tp:8-gpus:8] TIMEOUT(20) - 
perf/test_perf.py::test_perf[llama_v4_scout_17b_16e_instruct-bench-pytorch-streaming-bfloat16-input_output_len:128,128-ep:8-tp:8-gpus:8] - perf/test_perf.py::test_perf[llama_v4_scout_17b_16e_instruct-bench-pytorch-bfloat16-input_output_len:500,2000-ep:8-tp:8-gpus:8] diff --git a/tests/integration/test_lists/waives.txt b/tests/integration/test_lists/waives.txt index 4a76a98b85c..0c22070850f 100644 --- a/tests/integration/test_lists/waives.txt +++ b/tests/integration/test_lists/waives.txt @@ -241,7 +241,6 @@ examples/test_qwen.py::test_llm_qwen_moe_multi_gpu_summary[qwen2_57b_a14b-tp2pp2 examples/test_mixtral.py::test_llm_mixtral_moe_plugin_fp8_lora_4gpus[Mixtral-8x7B-v0.1-chinese-mixtral-lora] SKIP (https://nvbugs/5064768) llmapi/test_llm_e2e.py::test_llmapi_build_command_parameters_align[llama-llama-models-v2/TinyLlama-1.1B-Chat-v1.0] SKIP (https://nvbugs/5061624) test_e2e.py::test_openai_consistent_chat SKIP (https://nvbugs/5112075) -test_e2e.py::test_trtllm_bench_pytorch_backend_sanity SKIP (https://nvbugs/5345720) full:B200/examples/test_gemma.py::test_llm_hf_gemma_quantization_1gpu[gemma-2-9b-it-fp8-bfloat16-8] SKIP (not supported on B200) full:B200/examples/test_gpt.py::test_llm_gpt2_starcoder_1gpus SKIP (not supported on B200) examples/test_medusa.py::test_mistral_medusa_1gpu[mistral-7b-v0.1] SKIP (https://nvbugs/5137575) diff --git a/tests/unittest/_torch/test_return_logits.py b/tests/unittest/_torch/test_return_logits.py index 8555b8bc5ab..2fa21ad4179 100644 --- a/tests/unittest/_torch/test_return_logits.py +++ b/tests/unittest/_torch/test_return_logits.py @@ -1,5 +1,4 @@ import os -import pickle import pytest import torch @@ -8,52 +7,12 @@ from tensorrt_llm import SamplingParams from tensorrt_llm._torch import LLM -from tensorrt_llm._torch.pyexecutor.llm_request import LlmResponse, PyResult -from tensorrt_llm.bindings.executor import Response, Result -from tensorrt_llm.executor.result import Logprob from tensorrt_llm.llmapi.llm_utils import BuildConfig, KvCacheConfig prompts = ["A B C"] global_kvcache_config = KvCacheConfig(max_tokens=10000) -def test_LlmResponse_pickle(): - result = Result() - result.decoding_iter = 1 - result.sequence_index = 1 - binding_response = Response(request_id=1, result=result, client_id=1) - py_result = PyResult(prompt_len=1, - max_new_tokens=1, - use_device_memory=True, - streaming=False, - return_log_probs=True, - return_context_logits=True, - return_generation_logits=True) - context_logits = torch.randn([1, 1, 128], device='cuda') - generation_logits = torch.randn([1, 1, 128], device='cuda') - logprobs = [[{1: Logprob(0.8, 1)}]] - py_result.append_context_logits(context_logits) - py_result.append_generation_logits(generation_logits) - py_result.append_log_probs(logprobs) - - response = LlmResponse(binding_response, py_result) - - data = pickle.dumps(response) - pickle_response: LlmResponse = pickle.loads(data) - - assert pickle_response._response.request_id == 1 - assert pickle_response._response.client_id == 1 - - pickle_result = pickle_response.result - - assert pickle_result.decoding_iter == 1 - assert pickle_result.sequence_index == 1 - assert torch.all(torch.eq(pickle_result.context_logits, context_logits)) - assert torch.all( - torch.eq(pickle_result.generation_logits, generation_logits)) - assert pickle_result.log_probs == logprobs - - @force_ampere # Save H100 resource @pytest.mark.parametrize("return_log_probs", [False, True]) @pytest.mark.parametrize("gather_generation_logits", [False, True])