[None][feat] Add SM-level disaggregation support #9020
base: main
Conversation
Signed-off-by: Qiang Xu <[email protected]>
📝 Walkthrough
This PR introduces SM-level disaggregated context processing by adding configuration structures, a separate context model engine, disaggregated schedulers, and context-aware request filtering to enable independent execution of context and generation phases across multiple engines.
Sequence Diagram(s)
sequenceDiagram
participant Client
participant Executor as PyExecutor
participant ReqQueue as RequestQueue
participant MainSched as Scheduler
participant SmSched as SmDisaggCtxScheduler
participant CtxEngine as CtxModelEngine
participant MainEngine as MainModelEngine
Client->>Executor: enqueue_requests(requests)
Executor->>ReqQueue: fetch_new_requests(num_active_requests_on_engine)
alt SM Disaggregation Enabled
ReqQueue->>ReqQueue: _fetch_new_requests_sm_disagg()
ReqQueue->>MainSched: schedule_request()
MainSched->>SmSched: filter & process via SmDisaggCtxScheduler
SmSched->>SmSched: split context vs generation
Note over SmSched: Context phase
SmSched->>CtxEngine: forward context requests
Note over SmSched: Generation phase
SmSched->>MainEngine: forward generation requests
SmSched->>ReqQueue: return merged results
else Standard Path
ReqQueue->>MainSched: schedule_request()
MainSched->>MainEngine: forward all requests
end
ReqQueue-->>Executor: scheduled requests
Executor-->>Client: await_responses()
sequenceDiagram
participant Warmup
participant Engine1 as ModelEngine
participant Engine2 as CtxModelEngine
participant Engine3 as DraftEngine
participant Prof as Profiler
Warmup->>Engine1: set_is_warmup(True)
Warmup->>Engine2: set_is_warmup(True)
Warmup->>Engine3: set_is_warmup(True)
rect rgb(200, 220, 240)
Note over Prof: Profile Main Engine
Warmup->>Engine1: forward_step()
Engine1->>Prof: record metrics
end
rect rgb(220, 240, 200)
Note over Prof: Profile Context Engine
Warmup->>Engine2: forward_step()
Engine2->>Prof: record metrics
end
rect rgb(240, 220, 200)
Note over Prof: Profile Draft Engine
Warmup->>Engine3: forward_step()
Engine3->>Prof: record metrics
end
Warmup->>Engine1: set_is_warmup(False)
Warmup->>Engine2: set_is_warmup(False)
Warmup->>Engine3: set_is_warmup(False)
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~60 minutes
Pre-merge checks and finishing touches: ❌ Failed checks (1 warning)
✅ Passed checks (2 passed)
Actionable comments posted: 6
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (4)
tensorrt_llm/_torch/pyexecutor/executor_request_queue.py (1)
1-1: Add NVIDIA Apache-2.0 header (2025). All source files must start with the NVIDIA Apache-2.0 header per coding guidelines.
Apply at top of file:
+ # Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ # http://www.apache.org/licenses/LICENSE-2.0
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
tensorrt_llm/llmapi/llm_args.py (2)
1-1: Add NVIDIA Apache-2.0 header (2025). Please prepend the standard NVIDIA Apache-2.0 header.
+ # Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ # http://www.apache.org/licenses/LICENSE-2.0
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
2843-2848: Message nit: match the check. The validator allows 0; the error message says "greater than 0". Suggest ">= 0".
- if self.batch_wait_timeout_ms < 0:
-     raise ValueError("batch_wait_timeout_ms must be greater than 0")
+ if self.batch_wait_timeout_ms < 0:
+     raise ValueError("batch_wait_timeout_ms must be >= 0")
tensorrt_llm/_torch/pyexecutor/model_engine.py (1)
1-1: Add NVIDIA Apache-2.0 header (2025). Please prepend the standard header.
+ # Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ # http://www.apache.org/licenses/LICENSE-2.0
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
🧹 Nitpick comments (6)
tensorrt_llm/_torch/pyexecutor/executor_request_queue.py (1)
53-55: SM-disagg wiring looks fine; document semantics. Constructor and self.is_sm_disagg flag are clear. Please document in the class docstring what "SM-disaggregation mode" implies for fetch capacity and scheduling so future readers don't confuse it with network disaggregation.
Also applies to: 64-64
tensorrt_llm/llmapi/llm_args.py (2)
318-335: SmDisaggConfig: clarify bounds and behavior in docstring. Suggest noting the valid range for context_sm_percent and that non-positive ctx limits inherit generation limits. No functional change needed.
2888-2900: Support YAML ingestion for sm_disagg_config. Add SmDisaggConfig to field_mapping so dicts in extra_llm_api_options are parsed consistently (mirrors other nested configs).
field_mapping = {
    "quant_config": QuantConfig,
    "calib_config": CalibConfig,
    "build_config": BuildConfig,
    "decoding_config": DecodingConfig,
    "enable_build_cache": BuildCacheConfig,
    "speculative_config": DecodingBaseConfig,
    "lora_config": LoraConfig,
    "moe_config": MoeConfig,
    "attention_dp_config": AttentionDpConfig,
    "sparse_attention_config": BaseSparseAttentionConfig,
+   "sm_disagg_config": SmDisaggConfig,
}
tensorrt_llm/_torch/pyexecutor/model_engine.py (2)
1363-1371: Operator precedence: parenthesize 'and' for clarity (RUF021). Make the new SM-disagg condition explicit.
- if next_draft_tokens_device is None or request.is_dummy or request.py_batch_idx is None \
-     or self.sm_disagg_enabled and request.max_num_generated_tokens == 0:
+ if (next_draft_tokens_device is None or request.is_dummy
+         or request.py_batch_idx is None
+         or (self.sm_disagg_enabled
+             and request.max_num_generated_tokens == 0)):
1470-1472: Operator precedence: parenthesize 'and' for clarity (RUF021). Same as above in the generation path.
- if new_tokens_device is None or request.is_dummy or request.py_batch_idx is None \
-     or self.sm_disagg_enabled and request.max_num_generated_tokens == 0:
+ if (new_tokens_device is None or request.is_dummy
+         or request.py_batch_idx is None
+         or (self.sm_disagg_enabled
+             and request.max_num_generated_tokens == 0)):
tensorrt_llm/_torch/pyexecutor/py_executor_creator.py (1)
352-363: Consider extracting validation logic to a helper function. The validation checks are correct but could improve maintainability by extracting to a helper function like _validate_sm_disagg_config(llm_args). This would make the main flow cleaner and the validation logic more testable.
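A minimal sketch of the suggested helper, mirroring the checks quoted later in this review from py_executor_creator.py; the scheduler_config parameter and the import path are assumptions about how it would be wired in:
from tensorrt_llm.llmapi.llm_args import CapacitySchedulerPolicy


def _validate_sm_disagg_config(llm_args, scheduler_config):
    """Fail fast on configurations that SM-level disaggregation does not support."""
    if llm_args.sm_disagg_config is None:
        return
    if llm_args.cache_transceiver_config is not None:
        raise ValueError(
            "SM-level disaggregation is not compatible with disaggregated serving.")
    if llm_args.parallel_config.world_size > 1:
        raise NotImplementedError(
            "SM-level disaggregation is not supported with parallelism.")
    if scheduler_config.capacity_scheduler_policy != CapacitySchedulerPolicy.GUARANTEED_NO_EVICT:
        raise NotImplementedError(
            "SM-level disaggregation is only supported with guaranteed no evict scheduler policy.")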
📜 Review details
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (11)
tensorrt_llm/_torch/pyexecutor/_util.py (5 hunks), tensorrt_llm/_torch/pyexecutor/executor_request_queue.py (3 hunks), tensorrt_llm/_torch/pyexecutor/llm_request.py (1 hunk), tensorrt_llm/_torch/pyexecutor/model_engine.py (6 hunks), tensorrt_llm/_torch/pyexecutor/model_loader.py (3 hunks), tensorrt_llm/_torch/pyexecutor/py_executor.py (24 hunks), tensorrt_llm/_torch/pyexecutor/py_executor_creator.py (7 hunks), tensorrt_llm/_torch/pyexecutor/scheduler.py (2 hunks), tensorrt_llm/_torch/virtual_memory.py (1 hunk), tensorrt_llm/llmapi/__init__.py (2 hunks), tensorrt_llm/llmapi/llm_args.py (3 hunks)
🧰 Additional context used
📓 Path-based instructions (3)
**/*.{h,hpp,hh,hxx,cpp,cxx,cc,cu,cuh,py}
📄 CodeRabbit inference engine (CODING_GUIDELINES.md)
Use only spaces, no tabs; indent with 4 spaces.
Files:
tensorrt_llm/_torch/virtual_memory.py, tensorrt_llm/_torch/pyexecutor/llm_request.py, tensorrt_llm/_torch/pyexecutor/_util.py, tensorrt_llm/_torch/pyexecutor/scheduler.py, tensorrt_llm/llmapi/__init__.py, tensorrt_llm/_torch/pyexecutor/py_executor.py, tensorrt_llm/_torch/pyexecutor/executor_request_queue.py, tensorrt_llm/_torch/pyexecutor/model_loader.py, tensorrt_llm/_torch/pyexecutor/model_engine.py, tensorrt_llm/llmapi/llm_args.py, tensorrt_llm/_torch/pyexecutor/py_executor_creator.py
**/*.py
📄 CodeRabbit inference engine (CODING_GUIDELINES.md)
**/*.py: Python code must target Python 3.8+.
Indent Python code with 4 spaces; do not use tabs.
Maintain module namespace when importing; prefer 'from package.subpackage import foo' then 'foo.SomeClass()' instead of importing the class directly.
Python filenames should be snake_case (e.g., some_file.py).
Python classes use PascalCase names.
Functions and methods use snake_case names.
Local variables use snake_case; prefix 'k' for variables that start with a number (e.g., k_99th_percentile).
Global variables use upper SNAKE_CASE prefixed with 'G' (e.g., G_MY_GLOBAL).
Constants use upper SNAKE_CASE (e.g., MY_CONSTANT).
Avoid shadowing variables from an outer scope.
Initialize all externally visible members of a class in the constructor.
Prefer docstrings for interfaces that may be used outside a file; comments for in-function or file-local interfaces.
Use Google-style docstrings for classes and functions (Sphinx-parsable).
Document attributes and variables inline so they render under the class/function docstring.
Avoid reflection when a simpler, explicit approach suffices (e.g., avoid dict(**locals()) patterns).
In try/except, catch the most specific exceptions possible.
For duck-typing try/except, keep the try body minimal and use else for the main logic.
Files:
tensorrt_llm/_torch/virtual_memory.py, tensorrt_llm/_torch/pyexecutor/llm_request.py, tensorrt_llm/_torch/pyexecutor/_util.py, tensorrt_llm/_torch/pyexecutor/scheduler.py, tensorrt_llm/llmapi/__init__.py, tensorrt_llm/_torch/pyexecutor/py_executor.py, tensorrt_llm/_torch/pyexecutor/executor_request_queue.py, tensorrt_llm/_torch/pyexecutor/model_loader.py, tensorrt_llm/_torch/pyexecutor/model_engine.py, tensorrt_llm/llmapi/llm_args.py, tensorrt_llm/_torch/pyexecutor/py_executor_creator.py
**/*.{cpp,cxx,cc,h,hpp,hh,hxx,cu,cuh,py}
📄 CodeRabbit inference engine (CODING_GUIDELINES.md)
Prepend the NVIDIA Apache-2.0 copyright header with current year to the top of all source files (e.g., .cpp, .h, .cu, .py).
Files:
tensorrt_llm/_torch/virtual_memory.py, tensorrt_llm/_torch/pyexecutor/llm_request.py, tensorrt_llm/_torch/pyexecutor/_util.py, tensorrt_llm/_torch/pyexecutor/scheduler.py, tensorrt_llm/llmapi/__init__.py, tensorrt_llm/_torch/pyexecutor/py_executor.py, tensorrt_llm/_torch/pyexecutor/executor_request_queue.py, tensorrt_llm/_torch/pyexecutor/model_loader.py, tensorrt_llm/_torch/pyexecutor/model_engine.py, tensorrt_llm/llmapi/llm_args.py, tensorrt_llm/_torch/pyexecutor/py_executor_creator.py
🧠 Learnings (9)
📚 Learning: 2025-07-22T09:22:14.726Z
Learnt from: yechank-nvidia
Repo: NVIDIA/TensorRT-LLM PR: 6254
File: tensorrt_llm/_torch/pyexecutor/model_engine.py:1201-1204
Timestamp: 2025-07-22T09:22:14.726Z
Learning: In TensorRT-LLM's multimodal processing pipeline, shared tensor recovery using `from_shared_tensor()` is only needed during the context phase. Generation requests reuse the already-recovered tensor data and only need to call `strip_for_generation()` to remove unnecessary multimodal data while preserving the recovered tensors. This avoids redundant tensor recovery operations during generation.
Applied to files:
tensorrt_llm/_torch/pyexecutor/llm_request.py
📚 Learning: 2025-08-26T09:37:10.463Z
Learnt from: jiaganc
Repo: NVIDIA/TensorRT-LLM PR: 7031
File: tensorrt_llm/bench/dataclasses/configuration.py:90-104
Timestamp: 2025-08-26T09:37:10.463Z
Learning: In TensorRT-LLM, the `get_pytorch_perf_config()` method returns `self.pytorch_config` which can contain default `cuda_graph_config` values, so `llm_args` may already have this config before the extra options processing.
Applied to files:
tensorrt_llm/_torch/pyexecutor/_util.py, tensorrt_llm/llmapi/__init__.py, tensorrt_llm/_torch/pyexecutor/py_executor.py, tensorrt_llm/llmapi/llm_args.py, tensorrt_llm/_torch/pyexecutor/py_executor_creator.py
📚 Learning: 2025-08-26T09:37:10.463Z
Learnt from: jiaganc
Repo: NVIDIA/TensorRT-LLM PR: 7031
File: tensorrt_llm/bench/dataclasses/configuration.py:90-104
Timestamp: 2025-08-26T09:37:10.463Z
Learning: In TensorRT-LLM's bench configuration, the `get_pytorch_perf_config()` method returns `self.pytorch_config` which is a Dict[str, Any] that can contain default values including `cuda_graph_config`, making the fallback `llm_args["cuda_graph_config"]` safe to use.
Applied to files:
tensorrt_llm/_torch/pyexecutor/_util.py, tensorrt_llm/llmapi/__init__.py, tensorrt_llm/_torch/pyexecutor/py_executor_creator.py
📚 Learning: 2025-08-14T15:38:01.771Z
Learnt from: MatthiasKohl
Repo: NVIDIA/TensorRT-LLM PR: 6904
File: cpp/tensorrt_llm/pybind/thop/bindings.cpp:55-57
Timestamp: 2025-08-14T15:38:01.771Z
Learning: In TensorRT-LLM Python bindings, tensor parameter collections like mla_tensor_params and spec_decoding_tensor_params are kept as required parameters without defaults to maintain API consistency, even when it might affect backward compatibility.
Applied to files:
tensorrt_llm/_torch/pyexecutor/_util.py
📚 Learning: 2025-09-29T15:14:28.503Z
Learnt from: amitz-nv
Repo: NVIDIA/TensorRT-LLM PR: 8063
File: tensorrt_llm/lora_manager.py:1080-1112
Timestamp: 2025-09-29T15:14:28.503Z
Learning: In tensorrt_llm/lora_manager.py, when calculating part_sizes for attn_qkv fused LoRA modules, the sizes are correctly multiplied by tp_size because model_config.num_heads and model_config.num_kv_heads are already divided by tp_size (per-TP-rank values), so multiplication is needed to get the original full concatenated dimension size. The interleave_fused_lora_weights_for_tp function provides proper validation.
Applied to files:
tensorrt_llm/_torch/pyexecutor/_util.py
📚 Learning: 2025-08-19T12:45:11.997Z
Learnt from: amitz-nv
Repo: NVIDIA/TensorRT-LLM PR: 7033
File: tensorrt_llm/_torch/pyexecutor/model_engine.py:0-0
Timestamp: 2025-08-19T12:45:11.997Z
Learning: In tensorrt_llm/_torch/pyexecutor/model_engine.py, DoRA (Delta Orthogonal Rank Adaptation) functionality was removed from the PyTorch flow to eliminate issues with inverted DoRA detection logic. The original is_dora condition was checking if scaling_vec_pointer == 0, which was potentially incorrect.
Applied to files:
tensorrt_llm/_torch/pyexecutor/py_executor.py, tensorrt_llm/_torch/pyexecutor/model_loader.py, tensorrt_llm/_torch/pyexecutor/model_engine.py, tensorrt_llm/_torch/pyexecutor/py_executor_creator.py
📚 Learning: 2025-08-28T10:22:02.288Z
Learnt from: ixlmar
Repo: NVIDIA/TensorRT-LLM PR: 7294
File: tensorrt_llm/_torch/pyexecutor/sampler.py:1191-1197
Timestamp: 2025-08-28T10:22:02.288Z
Learning: In tensorrt_llm/_torch/pyexecutor/sampler.py, the object identity comparison `softmax_req_indices is not group_req_indices_cuda` on line ~1191 is intentional and used as an optimization to determine whether to reuse an existing indexer or create a new one, based on which code path was taken during tensor assignment.
Applied to files:
tensorrt_llm/_torch/pyexecutor/executor_request_queue.py
📚 Learning: 2025-08-26T06:07:02.166Z
Learnt from: shaharmor98
Repo: NVIDIA/TensorRT-LLM PR: 7231
File: tensorrt_llm/_torch/pyexecutor/_util.py:504-509
Timestamp: 2025-08-26T06:07:02.166Z
Learning: In tensorrt_llm/_torch/pyexecutor/_util.py, when calling model_engine.set_lora_model_config(), pass model_binding_config.mlp_hidden_size directly without multiplying by mapping.tp_size, as the mlp_hidden_size from get_bindings_model_config() is already the per-TP rank value needed for LoRA weight packaging.
Applied to files:
tensorrt_llm/_torch/pyexecutor/model_loader.py, tensorrt_llm/_torch/pyexecutor/model_engine.py, tensorrt_llm/_torch/pyexecutor/py_executor_creator.py
📚 Learning: 2025-09-03T13:16:06.824Z
Learnt from: nvpohanh
Repo: NVIDIA/TensorRT-LLM PR: 7478
File: tensorrt_llm/_torch/models/modeling_llama.py:1315-1315
Timestamp: 2025-09-03T13:16:06.824Z
Learning: The Llama4VisionEncoder.load_weights method signature is `def load_weights(self, weights: Dict)` and should not be confused with Llama4ForConditionalGeneration.load_weights which has a different signature including weight_mapper parameter.
Applied to files:
tensorrt_llm/_torch/pyexecutor/model_loader.py
🧬 Code graph analysis (9)
tensorrt_llm/_torch/pyexecutor/_util.py (4)
tensorrt_llm/llmapi/llm_args.py (1)
SmDisaggConfig(318-338)
tensorrt_llm/_torch/pyexecutor/scheduler.py (4)
SimpleScheduler(198-218), SmDisaggCtxScheduler(221-243), BindCapacityScheduler(72-99), BindMicroBatchScheduler(171-195)
tensorrt_llm/_torch/pyexecutor/model_engine.py (1)
PyTorchModelEngine(128-2549)
tensorrt_llm/mapping.py (1)
has_pp(254-255)
tensorrt_llm/_torch/pyexecutor/scheduler.py (2)
tensorrt_llm/_torch/pyexecutor/llm_request.py (2)
LlmRequest(423-643), get_context_requests(802-803)
tensorrt_llm/_torch/pyexecutor/kv_cache_connector.py (1)
SchedulerOutput(76-81)
tensorrt_llm/llmapi/__init__.py (1)
tensorrt_llm/llmapi/llm_args.py (2)
SmDisaggConfig(318-338), TorchLlmArgs(2427-2880)
tensorrt_llm/_torch/pyexecutor/py_executor.py (5)
tensorrt_llm/_torch/pyexecutor/llm_request.py (5)
get_context_requests(802-803), get_draft_token_length(788-799), get_generation_requests(806-807), get(129-141), LlmRequest(423-643)
tensorrt_llm/_torch/pyexecutor/scheduler.py (7)
RequestScheduler(44-55), schedule_request(47-55), schedule_request(61-69), schedule_request(95-99), schedule_request(112-153), schedule_request(206-218), schedule_request(229-243)
tensorrt_llm/_torch/pyexecutor/model_engine.py (3)
ModelEngine(69-92), forward(76-84), forward(2286-2387)
tensorrt_llm/_torch/pyexecutor/executor_request_queue.py (1)
fetch_new_requests(337-347)
tensorrt_llm/_torch/pyexecutor/resource_manager.py (4)
prepare_resources(81-82), prepare_resources(407-447), prepare_resources(1307-1310), prepare_resources(1432-1448)
tensorrt_llm/_torch/pyexecutor/executor_request_queue.py (1)
tensorrt_llm/_torch/pyexecutor/llm_request.py (1)
LlmRequest(423-643)
tensorrt_llm/_torch/pyexecutor/model_loader.py (1)
examples/models/core/enc_dec/convert_checkpoint.py (1)
state_dict(1629-1630)
tensorrt_llm/_torch/pyexecutor/model_engine.py (2)
tensorrt_llm/_torch/pyexecutor/py_executor_creator.py (1)
drafting_loop_wrapper(394-400)
tensorrt_llm/_torch/pyexecutor/llm_request.py (1)
is_dummy(611-612)
tensorrt_llm/llmapi/llm_args.py (4)
tensorrt_llm/builder.py (1)
default(45-50)
tensorrt_llm/models/modeling_utils.py (3)
from_dict(253-263), from_dict(325-334), from_dict(487-492)
tests/unittest/api_stability/api_stability_core.py (3)
from_dict(116-123), from_dict(172-178), from_dict(319-328)
tensorrt_llm/mapping.py (1)
from_dict(314-315)
tensorrt_llm/_torch/pyexecutor/py_executor_creator.py (4)
tensorrt_llm/_torch/virtual_memory.py (1)
ExecutorMemoryType(70-82)
tensorrt_llm/llmapi/llm_args.py (4)
parallel_config(1775-1776), world_size(372-373), world_size(382-386), CapacitySchedulerPolicy(1077-1083)
tensorrt_llm/_torch/pyexecutor/model_engine.py (1)
PyTorchModelEngine(128-2549)
tensorrt_llm/_torch/pyexecutor/cuda_graph_runner.py (1)
attn_metadata(131-132)
🪛 Ruff (0.14.3)
tensorrt_llm/_torch/pyexecutor/model_engine.py
1365-1365: Parenthesize 'a and b' expressions when chaining 'and' and 'or' together, to make the precedence clear
Parenthesize the 'and' subexpression
(RUF021)
1471-1471: Parenthesize 'a and b' expressions when chaining 'and' and 'or' together, to make the precedence clear
Parenthesize the 'and' subexpression
(RUF021)
tensorrt_llm/llmapi/llm_args.py
2833-2835: Avoid specifying long messages outside the exception class
(TRY003)
tensorrt_llm/_torch/pyexecutor/py_executor_creator.py
354-356: Avoid specifying long messages outside the exception class
(TRY003)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
- GitHub Check: Pre-commit Check
🔇 Additional comments (16)
tensorrt_llm/_torch/pyexecutor/model_engine.py (2)
205-206: Weight sharing passthrough: LGTM. Passing weight_sharing_model into ModelLoader aligns with shared-weights engines.
154-157: Review comment references incorrect line numbers; underlying concern is unfounded. The review points to lines 154-157 and 175, which do not contain max_num_generated_tokens access. Lines 154-157 initialize SM disagg config values, and line 175 checks sm_disagg_enabled (a boolean flag). The actual max_num_generated_tokens access occurs at lines 1365 and 1471, where it is accessed directly without getattr(). Evidence shows the attribute is guaranteed to exist:
- Accessed directly throughout the codebase (lines 1365, 1471, plus test cases at 367, 379) without errors or defensive checks
- Inherited from C++ bindings (tensorrt_llm.bindings.internal.batch_manager.LlmRequest) where it is always initialized
- Tests confirm consistent availability
The existing code pattern validates that no getattr() wrapper is needed.
Likely an incorrect or invalid review comment.
tensorrt_llm/llmapi/llm_args.py (1)
2464-2468: No action needed: SmDisaggConfig is already properly exported. The verification confirms SmDisaggConfig is imported on line 17 and added to the __all__ export list on line 65, making it available for downstream code to import directly from tensorrt_llm/llmapi.
tensorrt_llm/_torch/pyexecutor/executor_request_queue.py (1)
338-348: All call sites verified; no issues found. The single call site at py_executor.py:1819 correctly passes both required parameters (activate_requests and num_active_requests_on_engine) with matching types and order. No runtime errors will occur from missing parameters.
tensorrt_llm/_torch/pyexecutor/py_executor_creator.py (4)
81-104: LGTM! Memory monitoring support for context model is well-integrated. The additions for MODEL_ENGINE_CTX follow the existing patterns and provide helpful tuning guidance consistent with the main model engine.
474-476: LGTM! Safe iteration with proper None checks. The loop correctly checks for None before accessing engine attributes.
666-667: LGTM! Consistent propagation of SM-disagg parameters. Both calls to create_py_executor_instance correctly pass ctx_model_engine and sm_disagg_config.
Also applies to: 725-726
696-702: LGTM! Proper resource cleanup with safe guards. The loop correctly handles None engines and checks both attn_metadata and cuda_graph_runner.enabled before releasing resources.
tensorrt_llm/_torch/pyexecutor/py_executor.py (8)
50-51: LGTM! Clean imports and well-documented configuration. The new imports and environment variable follow existing conventions and are clearly documented.
Also applies to: 61-63
124-125: LGTM! Proper initialization with correct synchronization primitives. The new parameters, fields, and synchronization primitives (lock and condition variables) are correctly initialized. Using the same lock (sm_disagg_lock) for both condition variables is the right approach for coordinating the two executor loops.
Also applies to: 164-168, 186-187, 217-219
254-260: LGTM! Consistent multi-engine handling with proper guards. The warmup, property setter, and cleanup operations correctly iterate over all engines with appropriate None checks.
Also applies to: 373-378, 473-474
537-647: LGTM! Well-designed profiler enhancements for multi-engine support. The _profiler context manager is properly extended to support:
- Multiple engines via the model_engine parameter with a safe default
- Per-stream profiling via the stream parameter
- Phase identification via phase_name for clearer logs
- Profiling toggle via the enable_profiler parameter
The changes are backward compatible and follow the existing patterns.
1513-1694: Verify synchronization correctness under edge cases. The synchronization between context and generation loops uses condition variables correctly for the common case:
- Context waits on gen_request_cv when it has pending work but can't schedule (line 1555)
- Generation waits on ctx_request_cv when it has no work (line 1641)
- Each notifies the other after processing
However, verify behavior in these edge cases:
Startup: When both loops start with no scheduled work, does generation block indefinitely on ctx_request_cv? The context loop will continue and eventually fetch requests, but there's a window where generation is blocked.
Resource starvation: If the context loop can't schedule due to resource constraints (line 1554 condition is true), and generation is also blocked, ensure one will eventually make progress.
Consider adding a timeout to the condition variable waits or additional logging to help diagnose potential hangs in production.
To verify the synchronization logic is sound, you could add assertion checks or logging:
# In context loop before waiting
if scheduled_batch.batch_size == 0 and (len(ctx_requests) > 0 or self.executor_request_queue.get_waiting_queue_size() > 0):
    logger.debug(f"Context loop waiting for generation to finish. Pending context requests: {len(ctx_requests)}")
    self.gen_request_cv.wait()

# In generation loop before waiting
if scheduled_batch.batch_size == 0:
    logger.debug(f"Generation loop waiting for context to finish. Active requests: {len(self.active_requests)}")
    self.ctx_request_cv.wait()
1814-1820: LGTM! Correct capacity calculation for SM-disaggregation. When ctx_model_engine is present, the calculation correctly counts only context requests since that's what's relevant for the context engine's capacity. This aligns with the SM-disaggregation design where context and generation phases have separate capacity limits.
1914-1922: LGTM! Flexible scheduling interface with safe defaults. The optional scheduler and active_requests parameters enable disaggregated scheduling while maintaining backward compatibility through sensible defaults.
2154-2171: LGTM! Flexible forward step supporting multiple engines. The optional model_engine parameter enables forwarding through either the main or context engine while preserving backward compatibility. The implementation correctly defaults to self.model_engine and uses the provided engine consistently throughout.
def _fetch_new_requests_sm_disagg(
        self, num_active_requests: int,
        num_active_requests_on_engine: int) -> List[LlmRequest]:
    """Handle SM-level disaggregation request fetching."""
    total_max_num_active_requests = (self.max_num_active_requests +
                                     num_active_requests -
                                     num_active_requests_on_engine)

    # fetch and process requests into waiting queue
    new_requests = self._fetch_and_process_requests(
        num_active_requests_on_engine,
        total_max_num_active_requests,
        enable_attention_dp=False)

    # Merge requests and add to active list
    merged_requests = self._merge_requests(new_requests)
    return merged_requests
🧩 Analysis chain
Bug: SM-disagg capacity calculation can oversubscribe total active requests.
As written, total_max_num_active_requests = max + num_active_requests - num_active_requests_on_engine allows this engine to claim more than the global remaining capacity if num_active_requests represents a cross-engine total. The safe pattern is:
- remaining = max(0, max_num_active_requests - total_active_across_both_engines)
- allow this engine to fetch at most ‘remaining’ new requests: set total_max_for_this_engine = current_on_this_engine + remaining
This keeps the combined active count bounded by max_num_active_requests.
Apply:
- def _fetch_new_requests_sm_disagg(
- self, num_active_requests: int,
- num_active_requests_on_engine: int) -> List[LlmRequest]:
+ def _fetch_new_requests_sm_disagg(
+ self, num_active_requests: int,
+ num_active_requests_on_engine: int) -> List[LlmRequest]:
"""Handle SM-level disaggregation request fetching."""
- total_max_num_active_requests = (self.max_num_active_requests +
- num_active_requests -
- num_active_requests_on_engine)
+ # num_active_requests is expected to be the cross-engine total when available.
+ # Compute remaining global slots and cap this engine accordingly.
+ remaining_global_slots = max(
+ 0, self.max_num_active_requests - num_active_requests)
+ total_max_num_active_requests = (num_active_requests_on_engine +
+ remaining_global_slots)
# fetch and process requests into waiting queue
new_requests = self._fetch_and_process_requests(
- num_active_requests_on_engine,
- total_max_num_active_requests,
- enable_attention_dp=False)
+ num_active_requests_on_engine,
+ total_max_num_active_requests,
+ enable_attention_dp=False)
# Merge requests and add to active list
merged_requests = self._merge_requests(new_requests)
    return merged_requests
If num_active_requests actually equals num_active_requests_on_engine at call sites today, this change is a no-op; it future-proofs the logic when passing cross-engine totals.
Confirm the intended meaning of num_active_requests (engine-local vs cross-engine) at call sites.
I'll verify this review comment by examining the call sites and parameter semantics to understand whether the capacity calculation bug exists.
🏁 Script executed:
#!/bin/bash
# Find all call sites of _fetch_new_requests_sm_disagg
echo "=== Call sites of _fetch_new_requests_sm_disagg ==="
rg -n "_fetch_new_requests_sm_disagg" --type=py -C 5
echo ""
echo "=== Full method implementation context ==="
# Get the file and surrounding context
fd -e py "executor_request_queue.py" --exec head -n 400 {} \; | tail -n 100
Length of output: 5320
I need to understand the semantics of activate_requests and who calls fetch_new_requests. Let me investigate further.
🏁 Script executed:
#!/bin/bash
# Find all callers of fetch_new_requests
echo "=== Callers of fetch_new_requests ==="
rg -n "\.fetch_new_requests\(" --type=py -B 3 -A 3
echo ""
echo "=== Class definition and __init__ ==="
rg -n "class ExecutorRequestQueue" --type=py -A 30
echo ""
echo "=== Method _fetch_and_process_requests signature ==="
rg -n "def _fetch_and_process_requests" --type=py -B 2 -A 15Length of output: 6186
I need to understand the SM-disagg semantics and what get_context_requests does to confirm the bug:
🏁 Script executed:
#!/bin/bash
# Find get_context_requests definition and usage
echo "=== get_context_requests definition ==="
rg -n "def get_context_requests" --type=py -A 10
echo ""
echo "=== SM-disagg context and comments ==="
rg -n "sm_disagg\|SM.disagg" --type=py -B 2 -A 2
echo ""
echo "=== Complete method _fetch_new_requests_sm_disagg ==="
rg -n "_fetch_new_requests_sm_disagg" --type=py -A 25 | head -n 50Length of output: 4249
Confirm the SM-disagg capacity calculation fix is necessary.
Based on verification:
- num_active_requests passed to _fetch_new_requests_sm_disagg equals len(self.active_requests), the total active requests across all phases.
- num_active_requests_on_engine equals the count of context-phase requests only (subset obtained via get_context_requests).
- The current formula max + num_active_requests - num_active_requests_on_engine can exceed the global limit. For example: with max=100, total_active=80, context_only=30, the formula yields 150.
- The proposed fix correctly computes remaining global capacity and ensures the engine cannot claim more than its fair share.
def _fetch_new_requests_sm_disagg(
self, num_active_requests: int,
num_active_requests_on_engine: int) -> List[LlmRequest]:
"""Handle SM-level disaggregation request fetching."""
- total_max_num_active_requests = (self.max_num_active_requests +
- num_active_requests -
- num_active_requests_on_engine)
+ # num_active_requests is the total active across all phases.
+ # Compute remaining global slots and cap this engine accordingly.
+ remaining_global_slots = max(
+ 0, self.max_num_active_requests - num_active_requests)
+ total_max_num_active_requests = (num_active_requests_on_engine +
+ remaining_global_slots)
# fetch and process requests into waiting queue
new_requests = self._fetch_and_process_requests(
num_active_requests_on_engine,
total_max_num_active_requests,
enable_attention_dp=False)
# Merge requests and add to active list
merged_requests = self._merge_requests(new_requests)
    return merged_requests
🤖 Prompt for AI Agents
In tensorrt_llm/_torch/pyexecutor/executor_request_queue.py around lines 349 to
366, the current SM-disagg capacity calculation total_max_num_active_requests =
(self.max_num_active_requests + num_active_requests -
num_active_requests_on_engine) can exceed the global limit; replace it with a
computation of remaining global capacity: remaining =
self.max_num_active_requests - (num_active_requests -
num_active_requests_on_engine); clamp remaining to >=0; then set
total_max_num_active_requests = num_active_requests_on_engine + remaining so the
engine only claims up to the remaining global capacity and the global max is
never exceeded.
This looks like reasonable feedback.
@pcastonguay, great! The suggested fix correctly addresses the capacity oversubscription issue by computing the remaining global slots and ensuring the engine-specific limit respects the global maximum.
The key improvement is:
remaining_global_slots = max(0, self.max_num_active_requests - num_active_requests)
total_max_num_active_requests = num_active_requests_on_engine + remaining_global_slots
This ensures that even when num_active_requests represents the cross-engine total, the combined active count will never exceed max_num_active_requests. Let me know if you need any clarification on the implementation! 🐰
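A quick numeric check of both formulas, using the example figures from the earlier analysis (max=100, total active=80, context-only=30); this is illustrative arithmetic only:
max_num_active_requests = 100
num_active_requests = 80             # total active across all phases
num_active_requests_on_engine = 30   # context-phase requests only

# Original formula: can exceed the global limit.
original = max_num_active_requests + num_active_requests - num_active_requests_on_engine
assert original == 150  # > 100, oversubscribed

# Proposed formula: bounded by the remaining global capacity.
remaining = max(0, max_num_active_requests - num_active_requests)
proposed = num_active_requests_on_engine + remaining
assert proposed == 50   # <= 100, within the global limit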
is_sm_disagg_ctx_phase: bool = False,
is_draft_model: bool = False,
drafting_loop_wrapper: Optional[Callable[[torch.nn.Module],
                                         torch.nn.Module]] = None,
model: Optional[torch.nn.Module] = None,
weight_sharing_model: Optional[torch.nn.Module] = None,
):
🛠️ Refactor suggestion | 🟠 Major
Constructor additions look good; assert config presence when ctx-phase is enabled.
Prevent AttributeError if is_sm_disagg_ctx_phase=True but sm_disagg_config is None.
spec_config: Optional["DecodingBaseConfig"] = None,
is_sm_disagg_ctx_phase: bool = False,
is_draft_model: bool = False,
@@
) = llm_args.get_runtime_sizes()
- if is_sm_disagg_ctx_phase:
- max_num_tokens = llm_args.sm_disagg_config.context_max_num_tokens
- max_batch_size = llm_args.sm_disagg_config.context_max_batch_size
+ if is_sm_disagg_ctx_phase:
+ if llm_args.sm_disagg_config is None:
+ raise ValueError(
+ "is_sm_disagg_ctx_phase=True requires sm_disagg_config"
+ )
+ max_num_tokens = llm_args.sm_disagg_config.context_max_num_tokens
+ max_batch_size = llm_args.sm_disagg_config.context_max_batch_size
Committable suggestion skipped: line range outside the PR's diff.
🤖 Prompt for AI Agents
In tensorrt_llm/_torch/pyexecutor/model_engine.py around lines 139 to 145, when
is_sm_disagg_ctx_phase=True the constructor may later access sm_disagg_config
and raise AttributeError if it's None; add an explicit check at construction
start that if is_sm_disagg_ctx_phase is True then sm_disagg_config is not None,
and raise a clear ValueError or use assert with a descriptive message indicating
sm_disagg_config is required for SM disaggregation context phase so callers get
an immediate, informative failure.
if self.weight_sharing_model is not None:
    model.load_state_dict(self.weight_sharing_model.state_dict(),
                          assign=True)
    # Free up duplicate model weights allocated before weight sharing
    torch.cuda.empty_cache()
Keep shared weights on-device when assigning
state_dict() without keep_vars=True produces detached CPU tensors. With assign=True, those CPU tensors replace the module’s CUDA parameters, so this branch forces the newly built engine to run with CPU weights and immediately triggers device-mismatch failures instead of sharing memory. Please grab the on-device Parameter objects before assigning.
- model.load_state_dict(self.weight_sharing_model.state_dict(),
- assign=True)
+ shared_state = self.weight_sharing_model.state_dict(
+ keep_vars=True)
+ model.load_state_dict(shared_state, assign=True)
📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
if self.weight_sharing_model is not None:
    shared_state = self.weight_sharing_model.state_dict(
        keep_vars=True)
    model.load_state_dict(shared_state, assign=True)
    # Free up duplicate model weights allocated before weight sharing
    torch.cuda.empty_cache()
🤖 Prompt for AI Agents
In tensorrt_llm/_torch/pyexecutor/model_loader.py around lines 312 to 316, the
code calls self.weight_sharing_model.state_dict() which returns detached CPU
tensors and then uses assign=True, causing CPU tensors to replace CUDA
parameters; instead obtain the on-device Parameter objects by calling
state_dict(keep_vars=True) (or otherwise capture the
weight_sharing_model.parameters()/buffers as Variables on their current device)
and pass that mapping into model.load_state_dict(..., assign=True); ensure any
torch.cuda.empty_cache() call happens after assignment if needed.
if llm_args.sm_disagg_config is not None:
    if llm_args.cache_transceiver_config is not None:
        raise ValueError(
            "SM-level disaggregation is not compatible with disaggregated serving."
        )
    if llm_args.parallel_config.world_size > 1:
        raise NotImplementedError(
            "SM-level disaggregation is not supported with parallelism.")
    if scheduler_config.capacity_scheduler_policy != CapacitySchedulerPolicy.GUARANTEED_NO_EVICT:
        raise NotImplementedError(
            "SM-level disaggregation is only supported with guaranteed no evict scheduler policy."
        )

    with allocation_scope(ExecutorMemoryType.MODEL_ENGINE_CTX,
                          RestoreMode.PINNED):
        ctx_llm_args = copy.copy(llm_args)
        ctx_llm_args.cuda_graph_config = None
        ctx_model_engine = PyTorchModelEngine(
            model_path=checkpoint_dir,
            llm_args=ctx_llm_args,
            mapping=mapping,
            attn_runtime_features=attn_runtime_features,
            dist=dist,
            spec_config=spec_config,
            weight_sharing_model=model_engine.model,
        )
else:
    ctx_model_engine = None
Critical: Missing is_sm_disagg_ctx_phase=True flag and potential config sharing issue.
Two critical issues:
- Missing phase flag: Lines 369-376 should pass is_sm_disagg_ctx_phase=True to PyTorchModelEngine. According to model_engine.py (lines 127-210 in snippets), this flag controls whether to use sm_disagg_config.context_max_num_tokens and sm_disagg_config.context_max_batch_size. Without it, the context model will use the wrong capacity limits.
- Shallow copy risk: Line 368 uses copy.copy(llm_args), which creates a shallow copy. If llm_args contains mutable objects (e.g., nested configs), modifications to those objects will affect both llm_args and ctx_llm_args. While setting cuda_graph_config = None on line 369 is safe if it replaces a reference, other mutable fields could be problematic.
ctx_llm_args = copy.copy(llm_args)
ctx_llm_args.cuda_graph_config = None
ctx_model_engine = PyTorchModelEngine(
model_path=checkpoint_dir,
llm_args=ctx_llm_args,
mapping=mapping,
attn_runtime_features=attn_runtime_features,
dist=dist,
spec_config=spec_config,
+ is_sm_disagg_ctx_phase=True,
weight_sharing_model=model_engine.model,
)
Consider using copy.deepcopy(llm_args) instead of copy.copy(llm_args) at line 368 to avoid unintended config sharing between the main and context engines.
🧰 Tools
🪛 Ruff (0.14.3)
354-356: Avoid specifying long messages outside the exception class
(TRY003)
🤖 Prompt for AI Agents
tensorrt_llm/_torch/pyexecutor/py_executor_creator.py around lines 352 to 380:
the context model engine is created without the is_sm_disagg_ctx_phase=True flag
and uses a shallow copy of llm_args; change copy.copy(llm_args) to
copy.deepcopy(llm_args) to avoid shared mutable state, set
ctx_llm_args.cuda_graph_config = None as before, and pass
is_sm_disagg_ctx_phase=True into the PyTorchModelEngine constructor so the
context engine uses sm_disagg_config context limits.
def split_device_green_ctx(self):
    device = torch.device("cuda", self.device_id)
    device_properties = torch.cuda.get_device_properties(device)
    sm_count = device_properties.multi_processor_count
    if device_properties.major >= 9:
        sm_min = 8
        sm_align = 8
    else:
        sm_min = 4 if device_properties.major == 8 else 2
        sm_align = 2

    from flashinfer import green_ctx

    def split_device_green_ctx_aligned(sm_s1):
        sm_s1 = round(sm_s1 / sm_align) * sm_align
        sm_s1 = min(max(sm_s1, sm_min), sm_count - sm_min)
        return green_ctx.split_device_green_ctx_by_sm_count(device, [sm_s1])

    sm_ctx = round(sm_count * self.sm_disagg_ctx_sm_percent)
    sm_gen = sm_count - sm_ctx
    # Choose the split closer to user-specified percentage when sm_count is not divisible by sm_align
    sm_ctx_dist = min(sm_ctx % sm_align, sm_align - (sm_ctx % sm_align))
    sm_gen_dist = min(sm_gen % sm_align, sm_align - (sm_gen % sm_align))
    if sm_gen_dist < sm_ctx_dist:
        (stream_gen,
         stream_ctx), (res_gen,
                       res_ctx) = split_device_green_ctx_aligned(sm_gen)
    else:
        (stream_ctx,
         stream_gen), (res_ctx,
                       res_gen) = split_device_green_ctx_aligned(sm_ctx)
    logger.info(
        f"Green contexts allocated {res_ctx.sm.smCount} SMs for context phase and {res_gen.sm.smCount} SMs for generation phase."
    )
    return stream_ctx, stream_gen
🛠️ Refactor suggestion | 🟠 Major
Move flashinfer import to module level for consistency and early failure detection.
Line 1718 imports flashinfer.green_ctx inside the split_device_green_ctx method. This is problematic because:
- Late failure: If flashinfer is not available or incompatible, the error won't surface until this method is called during execution, making debugging harder.
- Inconsistent style: Other imports in this file are at the module level.
- Repeated import overhead: Although Python caches imports, it's better to import once at module level.
Move the import to the top of the file:
from tensorrt_llm._torch.pyexecutor.resource_manager import (
ResourceManagerType, request_context)
+try:
+ from flashinfer import green_ctx
+except ImportError:
+ green_ctx = None
from tensorrt_llm._utils import (customized_gc_thresholds, is_trace_enabled,
mpi_disabled, nvtx_range, trace_func)
Then add a validation check in __init__ or at the start of split_device_green_ctx:
def split_device_green_ctx(self):
    if green_ctx is None:
        raise ImportError("flashinfer is required for SM-level disaggregation but is not available")
    device = torch.device("cuda", self.device_id)
    # ... rest of the method
🤖 Prompt for AI Agents
tensorrt_llm/_torch/pyexecutor/py_executor.py around lines 1707 to 1741: the
function currently imports flashinfer.green_ctx inside split_device_green_ctx
causing late failures and inconsistent style; move "from flashinfer import
green_ctx" to the module-level imports at top of the file, then add a validation
(in __init__ or at the start of split_device_green_ctx) that raises ImportError
with a clear message if green_ctx is None or missing so callers fail fast;
ensure no other behavior changes and remove the local import from this method.
@model_validator(mode='after')
def validate_and_sync_sm_disagg_config(self) -> 'TorchLlmArgs':
    """Validate SM-level disaggregation configuration."""
    if self.sm_disagg_config is None:
        return self

    config = self.sm_disagg_config
    if not 0 < config.context_sm_percent < 1:
        raise ValueError(
            "sm_disagg_config.context_sm_percent must be in the range (0, 1)"
        )
    if config.context_max_num_tokens <= 0:
        config.context_max_num_tokens = self.max_num_tokens
    if config.context_max_batch_size <= 0:
        config.context_max_batch_size = self.max_batch_size
    return self
Allow 1.0 for context_sm_percent and clamp ctx limits to runtime maxima.
- The current check excludes 1.0; allowing full SM allocation to context is reasonable.
- Also clamp context_max_num_tokens/batch_size to runtime maxima to avoid illegal overprovisioning.
Apply:
def validate_and_sync_sm_disagg_config(self) -> 'TorchLlmArgs':
"""Validate SM-level disaggregation configuration."""
if self.sm_disagg_config is None:
return self
config = self.sm_disagg_config
- if not 0 < config.context_sm_percent < 1:
+ if not (0 < config.context_sm_percent <= 1):
raise ValueError(
- "sm_disagg_config.context_sm_percent must be in the range (0, 1)"
+ "sm_disagg_config.context_sm_percent must be in the range (0, 1]"
)
if config.context_max_num_tokens <= 0:
config.context_max_num_tokens = self.max_num_tokens
if config.context_max_batch_size <= 0:
config.context_max_batch_size = self.max_batch_size
+ # Clamp to runtime maxima if provided
+ if self.max_num_tokens is not None and config.context_max_num_tokens is not None:
+ config.context_max_num_tokens = min(config.context_max_num_tokens,
+ self.max_num_tokens)
+ if self.max_batch_size is not None and config.context_max_batch_size is not None:
+ config.context_max_batch_size = min(config.context_max_batch_size,
+ self.max_batch_size)
    return self
🧰 Tools
🪛 Ruff (0.14.3)
2833-2835: Avoid specifying long messages outside the exception class
(TRY003)
🤖 Prompt for AI Agents
tensorrt_llm/llmapi/llm_args.py lines 2825-2841: update the validator to allow
context_sm_percent == 1.0 (change the range check to require 0 <
context_sm_percent <= 1) and ensure context_max_num_tokens and
context_max_batch_size are clamped to runtime maxima — if each is <= 0 set it to
self.max_num_tokens / self.max_batch_size respectively, otherwise set it to
min(their value, self.max_num_tokens/self.max_batch_size) to prevent
overprovisioning.
/bot run
PR_Github #23881 [ run ] triggered by Bot. Commit:
Does this PR support dynamically allocating SM resources between the ctx and gen engines based on the per-batch workload? If so, is there a cost model used to determine the SM partitioning?
No. A fixed SM ratio is passed through llm args. The fixed ratio seems to be good enough for the workloads we looked into (including OpenOrca which has varying ISL/OSL per request). So dynamic SM allocation is not planned unless there's a compelling use case.
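For reference, a usage sketch of the fixed-ratio configuration. The field names (context_sm_percent, context_max_num_tokens, context_max_batch_size) come from the SmDisaggConfig added in this PR; the exact LLM constructor wiring, model path, and values below are assumptions for illustration:
from tensorrt_llm.llmapi import LLM, SmDisaggConfig

# Hypothetical model path; the ratio and limits are example values only.
llm = LLM(
    model="/path/to/model",
    sm_disagg_config=SmDisaggConfig(
        context_sm_percent=0.6,       # fixed fraction of SMs for the context phase
        context_max_num_tokens=8192,  # non-positive values fall back to max_num_tokens
        context_max_batch_size=32,    # non-positive values fall back to max_batch_size
    ),
)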
PR_Github #23881 [ run ] completed with state
sm_min = 4 if device_properties.major == 8 else 2
sm_align = 2

from flashinfer import green_ctx
Though TRTLLM already integrates FlashInfer, is it a better idea to directly integrate the low-level green_context API provided by CUDA to reduce the abstraction level here?
self._kv_connector_terminate_requests()

def _executor_loop_sm_disagg_ctx(self, stream):
There is a large amount of duplicated code from the main executor loop. Could we refactor to avoid so much duplication? This duplicated code makes it very hard to maintain.
iter_stats=iter_stats,
iter_start_time=iter_start_time))

def _executor_loop_sm_disagg_gen_overlap(self, stream):
There is a large amount of duplicated code from the main executor loop with overlap. Could we refactor to avoid so much duplication? This duplicated code makes it very hard to maintain.
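One possible shape for the deduplication asked for in these two comments; this is only a sketch, and _run_sm_disagg_loop, _schedule_batch, _forward_and_respond, and is_shutdown are hypothetical names standing in for whatever the common loop body gets extracted into:
import torch  # assumed to already be imported at module level in py_executor.py


def _run_sm_disagg_loop(self, model_engine, scheduler, stream, wait_cv, notify_cv):
    """Hypothetical shared body for the SM-disagg context and generation loops."""
    with torch.cuda.stream(stream):
        while not self.is_shutdown:  # assumed shutdown flag
            scheduled_batch = self._schedule_batch(scheduler)  # hypothetical helper
            if scheduled_batch.batch_size == 0:
                # Nothing to run in this phase; wait for the other phase to notify.
                with self.sm_disagg_lock:
                    wait_cv.wait()
                continue
            self._forward_and_respond(scheduled_batch, model_engine)  # hypothetical helper
            # Wake the other phase, which may now be able to schedule work.
            with self.sm_disagg_lock:
                notify_cv.notify()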
else:
    return self._fetch_new_requests_attention_tp(len(activate_requests))

def _fetch_new_requests_sm_disagg(
We should add unit tests for this new function.
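A pytest-style sketch of such a test; it assumes the method can be exercised on a bare ExecutorRequestQueue instance with its collaborators stubbed out, and the assertion encodes the capacity bound proposed earlier in this review:
from tensorrt_llm._torch.pyexecutor.executor_request_queue import ExecutorRequestQueue


def test_sm_disagg_fetch_respects_global_limit():
    # Bypass __init__ so the test only exercises the capacity arithmetic.
    queue = object.__new__(ExecutorRequestQueue)
    queue.max_num_active_requests = 100

    captured = {}

    def fake_fetch(current_on_engine, limit, enable_attention_dp):
        captured["limit"] = limit
        return []

    queue._fetch_and_process_requests = fake_fetch
    queue._merge_requests = lambda reqs: reqs

    result = queue._fetch_new_requests_sm_disagg(
        num_active_requests=80, num_active_requests_on_engine=30)

    assert result == []
    # With the proposed fix, the per-engine limit never exceeds the global maximum.
    assert captured["limit"] <= queue.max_num_active_requests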
new_requests = self._fetch_and_process_requests(
    num_active_requests_on_engine,
    total_max_num_active_requests,
    enable_attention_dp=False)
Why are we simply setting this to False? If we don't support attention dp, we should raise an error.
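A sketch of the guard being suggested; whether the queue carries an enable_attention_dp flag (and under what name) is an assumption:
# Hypothetical attribute name; adapt to however attention DP is tracked in this class.
if getattr(self, "enable_attention_dp", False):
    raise NotImplementedError(
        "SM-level disaggregation does not support attention DP yet.")
new_requests = self._fetch_and_process_requests(
    num_active_requests_on_engine,
    total_max_num_active_requests,
    enable_attention_dp=False)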
def _fetch_new_requests_sm_disagg(
        self, num_active_requests: int,
        num_active_requests_on_engine: int) -> List[LlmRequest]:
    """Handle SM-level disaggregation request fetching."""
    total_max_num_active_requests = (self.max_num_active_requests +
                                     num_active_requests -
                                     num_active_requests_on_engine)

    # fetch and process requests into waiting queue
    new_requests = self._fetch_and_process_requests(
        num_active_requests_on_engine,
        total_max_num_active_requests,
        enable_attention_dp=False)

    # Merge requests and add to active list
    merged_requests = self._merge_requests(new_requests)
    return merged_requests
This looks like reasonable feedback.
self.sm_disagg_lock = threading.Lock()
self.ctx_request_cv = threading.Condition(self.sm_disagg_lock)
self.gen_request_cv = threading.Condition(self.sm_disagg_lock)
Should those get initialized only if ctx_model_engine is not None?
We will need to add integration/unit tests to validate the new code.
Summary by CodeRabbit
Release Notes
Description
Background and overview
The SOTA request scheduling scheme - chunked prefill - piggybacks chunked context requests with generation requests to achieve stable TPOT and improved GPU utilization. However, TPOT latency is inflated because the generation requests now often need to be processed together with many more context tokens.
Alternatively, disaggregated serving processes context and generation requests on different nodes to achieve low TPOT. However, the generation nodes can be underutilized, especially when using smaller batch sizes to meet the TPOT target. In addition, there are deployment scenarios where disaggregated serving is not suitable, e.g., lack of a high-speed interconnect.
This PR aims to achieve better throughput@latency by implementing a new feature called SM-level disaggregation. The feature achieves low TPOT latency by decoupling context and generation requests (as in disaggregated serving), but still runs the decoupled requests on the same GPU for better GPU utilization (as in chunked prefill). To achieve that, we partition the GPU (using Green Contexts) and allocate SMs to the context and generation phases, and the context and generation requests are asynchronously scheduled onto two different streams on the same GPU.
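To make the partitioning concrete, here is a rough sketch that reuses flashinfer's green_ctx helper the same way this PR's split_device_green_ctx does; the 60% split and device index are arbitrary example values:
import torch
from flashinfer import green_ctx

device = torch.device("cuda", 0)
sm_total = torch.cuda.get_device_properties(device).multi_processor_count
sm_ctx = round(sm_total * 0.6)  # example: 60% of SMs for the context phase

# One stream/resource pair for the requested partition, plus the remainder.
(stream_ctx, stream_gen), (res_ctx, res_gen) = \
    green_ctx.split_device_green_ctx_by_sm_count(device, [sm_ctx])

print(f"context SMs: {res_ctx.sm.smCount}, generation SMs: {res_gen.sm.smCount}")
# Context and generation batches can then be launched on stream_ctx and stream_gen.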
This is the first PR that implements the core functionality of SM-level disaggregation. Follow-up PRs are planned to address doc and examples, persistent kernel perf issues, and parallelism support.
Design
Performance results
Limitation: Note that the context and generation workloads are relatively balanced in above cases (in terms of compute time). This feature (as well as any other disaggregation schemes) won't show much benefit if the workload is dominated by context or generation.
Test Coverage
Please suggest appropriate test cases to guard the feature.
PR Checklist
Please review the following before submitting your PR:
PR description clearly explains what and why. If using CodeRabbit's summary, please make sure it makes sense.
PR Follows TRT-LLM CODING GUIDELINES to the best of your knowledge.
Test cases are provided for new code paths (see test instructions)
Any new dependencies have been scanned for license and vulnerabilities
CODEOWNERS updated if ownership changes
Documentation updated as needed
Update tava architecture diagram if there is a significant design change in PR.
The reviewers assigned automatically/manually are appropriate for the PR.
Please check this after reviewing the above items as appropriate for this PR.
GitHub Bot Help
/bot [-h] ['run', 'kill', 'skip', 'reuse-pipeline'] ...
Provide a user friendly way for developers to interact with a Jenkins server.
Run /bot [-h|--help] to print this help message. See details below for each supported subcommand.
run [--reuse-test (optional)pipeline-id --disable-fail-fast --skip-test --stage-list "A10-PyTorch-1, xxx" --gpu-type "A30, H100_PCIe" --test-backend "pytorch, cpp" --add-multi-gpu-test --only-multi-gpu-test --disable-multi-gpu-test --post-merge --extra-stage "H100_PCIe-TensorRT-Post-Merge-1, xxx" --detailed-log --debug(experimental)]
Launch build/test pipelines. All previously running jobs will be killed.
--reuse-test (optional)pipeline-id (OPTIONAL): Allow the new pipeline to reuse build artifacts and skip successful test stages from a specified pipeline or the last pipeline if no pipeline-id is indicated. If the Git commit ID has changed, this option will be always ignored. The DEFAULT behavior of the bot is to reuse build artifacts and successful test results from the last pipeline.
--disable-reuse-test (OPTIONAL): Explicitly prevent the pipeline from reusing build artifacts and skipping successful test stages from a previous pipeline. Ensure that all builds and tests are run regardless of previous successes.
--disable-fail-fast (OPTIONAL): Disable fail fast on build/tests/infra failures.
--skip-test (OPTIONAL): Skip all test stages, but still run build stages, package stages and sanity check stages. Note: Does NOT update GitHub check status.
--stage-list "A10-PyTorch-1, xxx" (OPTIONAL): Only run the specified test stages. Examples: "A10-PyTorch-1, xxx". Note: Does NOT update GitHub check status.
--gpu-type "A30, H100_PCIe" (OPTIONAL): Only run the test stages on the specified GPU types. Examples: "A30, H100_PCIe". Note: Does NOT update GitHub check status.
--test-backend "pytorch, cpp" (OPTIONAL): Skip test stages which don't match the specified backends. Only support [pytorch, cpp, tensorrt, triton]. Examples: "pytorch, cpp" (does not run test stages with tensorrt or triton backend). Note: Does NOT update GitHub pipeline status.
--only-multi-gpu-test (OPTIONAL): Only run the multi-GPU tests. Note: Does NOT update GitHub check status.
--disable-multi-gpu-test (OPTIONAL): Disable the multi-GPU tests. Note: Does NOT update GitHub check status.
--add-multi-gpu-test (OPTIONAL): Force run the multi-GPU tests in addition to running L0 pre-merge pipeline.
--post-merge (OPTIONAL): Run the L0 post-merge pipeline instead of the ordinary L0 pre-merge pipeline.
--extra-stage "H100_PCIe-TensorRT-Post-Merge-1, xxx" (OPTIONAL): Run the ordinary L0 pre-merge pipeline and specified test stages. Examples: --extra-stage "H100_PCIe-TensorRT-Post-Merge-1, xxx".
--detailed-log (OPTIONAL): Enable flushing out all logs to the Jenkins console. This will significantly increase the log volume and may slow down the job.
--debug (OPTIONAL): Experimental feature. Enable access to the CI container for debugging purpose. Note: Specify exactly one stage in the stage-list parameter to access the appropriate container environment. Note: Does NOT update GitHub check status.
For guidance on mapping tests to stage names, see docs/source/reference/ci-overview.md and the scripts/test_to_stage_mapping.py helper.
kill
kill: Kill all running builds associated with pull request.
skip
skip --comment COMMENT: Skip testing for latest commit on pull request. --comment "Reason for skipping build/test" is required. IMPORTANT NOTE: This is dangerous since lack of user care and validation can cause top of tree to break.
reuse-pipeline
reuse-pipeline: Reuse a previous pipeline to validate current commit. This action will also kill all currently running builds associated with the pull request. IMPORTANT NOTE: This is dangerous since lack of user care and validation can cause top of tree to break.