
Conversation

@qiangxu1996 commented Nov 7, 2025

Summary by CodeRabbit

Release Notes

  • New Features
    • Added support for SM-level disaggregation mode to optimize context execution with configurable parameters including context execution percentage, token limits, and batch size settings.
    • Enabled context-disaggregated execution capabilities for improved performance optimization.

Description

Background and overview

The state-of-the-art request scheduling scheme, chunked prefill, piggybacks chunked context requests onto generation requests to achieve stable TPOT and improved GPU utilization. However, TPOT latency is inflated because generation requests now often have to be processed together with many more context tokens.

Alternatively, disaggregated serving processes context and generation requests on different nodes to achieve low TPOT. However, the generation nodes can be underutilized, especially when smaller batch sizes are used to meet the TPOT target. In addition, there are deployment scenarios where disaggregated serving is not suitable, e.g., a lack of high-speed interconnect.

This PR aims to achieve better throughput@latency by implementing a new feature called SM-level disaggregation. The feature achieves low TPOT latency by decoupling context and generation requests (as in disaggregated serving), but still runs the decoupled requests on the same GPU for better GPU utilization (as in chunked prefill). To achieve that, we partition the GPU using Green Contexts, allocating SMs to the context and generation phases, and the context and generation requests are asynchronously scheduled onto two different streams on the same GPU.
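As a rough sketch of the partitioning step (the PR's actual helper, split_device_green_ctx in py_executor.py, appears later in this page; this simplified version skips the SM-count alignment the PR performs and assumes flashinfer's green_ctx API as used there):

import torch
from flashinfer import green_ctx  # optional dependency; needed for green-context partitioning

device = torch.device("cuda", 0)
sm_count = torch.cuda.get_device_properties(device).multi_processor_count

# Reserve e.g. 60% of the SMs for the context phase; the remainder runs generation.
sm_ctx = round(sm_count * 0.6)

# One stream/resource pair is returned per requested partition, plus one for the
# remaining SMs (return layout assumed from the PR's usage; may vary by flashinfer version).
(stream_ctx, stream_gen), (res_ctx, res_gen) = \
    green_ctx.split_device_green_ctx_by_sm_count(device, [sm_ctx])

print(f"context SMs: {res_ctx.sm.smCount}, generation SMs: {res_gen.sm.smCount}")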


This is the first PR implementing the core functionality of SM-level disaggregation. Follow-up PRs are planned to address docs and examples, persistent-kernel performance issues, and parallelism support.
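Since docs and examples land in follow-ups, here is a hedged sketch of how the feature might be enabled through the LLM API. SmDisaggConfig and its fields (context_sm_percent, context_max_num_tokens, context_max_batch_size) come from this PR; the exact LLM constructor wiring and the values shown are illustrative only.

from tensorrt_llm.llmapi import LLM, SmDisaggConfig

# Placeholder values; tune per workload.
sm_disagg_config = SmDisaggConfig(
    context_sm_percent=0.6,       # fraction of SMs assigned to the context phase
    context_max_num_tokens=8192,  # per-iteration token budget for the context scheduler
    context_max_batch_size=8,     # batch-size cap for the context scheduler
)

llm = LLM(
    model="<model checkpoint>",   # placeholder path
    sm_disagg_config=sm_disagg_config,
)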

Design

  • The Python executor starts two executor loops when the feature is enabled, one for the context phase and one for the generation phase
  • Two model engines are instantiated (with model weights still shared), because the engines are non-reentrant
  • Handling of shared variables/resources (active_requests and the KV cache) is lock protected (see the sketch after this list)
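A minimal sketch of the coordination pattern implied by these bullets. The names sm_disagg_lock, ctx_request_cv, and gen_request_cv come from the reviewed code below; the loop bodies are heavily simplified placeholders, not the actual PyExecutor logic.

import threading

sm_disagg_lock = threading.Lock()
ctx_request_cv = threading.Condition(sm_disagg_lock)  # generation waits on this for new context work
gen_request_cv = threading.Condition(sm_disagg_lock)  # context waits on this for generation progress
active_requests = []  # shared between both loops, guarded by sm_disagg_lock

def ctx_executor_loop(fetch_new_requests, forward_context_step):
    while True:
        with sm_disagg_lock:
            active_requests.extend(fetch_new_requests())
            batch = [r for r in active_requests if r.is_context_init_state]
            if not batch:
                gen_request_cv.wait()  # nothing to prefill; let the generation loop run
                continue
        forward_context_step(batch)    # runs on the context stream / context model engine
        with sm_disagg_lock:
            ctx_request_cv.notify()    # wake the generation loop

def gen_executor_loop(forward_generation_step):
    while True:
        with sm_disagg_lock:
            batch = [r for r in active_requests if not r.is_context_init_state]
            if not batch:
                ctx_request_cv.wait()  # wait until the context loop hands over requests
                continue
        forward_generation_step(batch)  # runs on the generation stream / main model engine
        with sm_disagg_lock:
            gen_request_cv.notify()

In the PR, these two loops would run concurrently (one per phase), each driving its own green-context stream and model engine.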

Performance results

  • GPT-OSS-120B, ISL/OSL = 10k/600 on 1x B200: Up to 26.2% improvement on user throughput or up to 8.9% improvement on GPU throughput over chunked prefill. Disaggregated serving has the same perf as chunked prefill in this case.
  • Llama 2 70B, OpenOrca (avg ISL/OSL = 221/276) on 1x B200: Up to 29.0% improvement on user throughput or up to 14.1% improvement on GPU throughput over chunked prefill.

Limitation: Note that the context and generation workloads are relatively balanced (in terms of compute time) in the above cases. This feature (like any other disaggregation scheme) won't show much benefit if the workload is dominated by either context or generation.

Test Coverage

Please suggest appropriate test cases to guard the feature.

PR Checklist

Please review the following before submitting your PR:

  • PR description clearly explains what and why. If using CodeRabbit's summary, please make sure it makes sense.

  • PR Follows TRT-LLM CODING GUIDELINES to the best of your knowledge.

  • Test cases are provided for new code paths (see test instructions)

  • Any new dependencies have been scanned for license and vulnerabilities

  • CODEOWNERS updated if ownership changes

  • Documentation updated as needed

  • Update tava architecture diagram if there is a significant design change in PR.

  • The reviewers assigned automatically/manually are appropriate for the PR.

  • Please check this after reviewing the above items as appropriate for this PR.

GitHub Bot Help

/bot [-h] ['run', 'kill', 'skip', 'reuse-pipeline'] ...

Provide a user friendly way for developers to interact with a Jenkins server.

Run /bot [-h|--help] to print this help message.

See details below for each supported subcommand.

run [--reuse-test (optional)pipeline-id --disable-fail-fast --skip-test --stage-list "A10-PyTorch-1, xxx" --gpu-type "A30, H100_PCIe" --test-backend "pytorch, cpp" --add-multi-gpu-test --only-multi-gpu-test --disable-multi-gpu-test --post-merge --extra-stage "H100_PCIe-TensorRT-Post-Merge-1, xxx" --detailed-log --debug(experimental)]

Launch build/test pipelines. All previously running jobs will be killed.

--reuse-test (optional)pipeline-id (OPTIONAL) : Allow the new pipeline to reuse build artifacts and skip successful test stages from a specified pipeline or the last pipeline if no pipeline-id is indicated. If the Git commit ID has changed, this option will be always ignored. The DEFAULT behavior of the bot is to reuse build artifacts and successful test results from the last pipeline.

--disable-reuse-test (OPTIONAL) : Explicitly prevent the pipeline from reusing build artifacts and skipping successful test stages from a previous pipeline. Ensure that all builds and tests are run regardless of previous successes.

--disable-fail-fast (OPTIONAL) : Disable fail fast on build/tests/infra failures.

--skip-test (OPTIONAL) : Skip all test stages, but still run build stages, package stages and sanity check stages. Note: Does NOT update GitHub check status.

--stage-list "A10-PyTorch-1, xxx" (OPTIONAL) : Only run the specified test stages. Examples: "A10-PyTorch-1, xxx". Note: Does NOT update GitHub check status.

--gpu-type "A30, H100_PCIe" (OPTIONAL) : Only run the test stages on the specified GPU types. Examples: "A30, H100_PCIe". Note: Does NOT update GitHub check status.

--test-backend "pytorch, cpp" (OPTIONAL) : Skip test stages which don't match the specified backends. Only support [pytorch, cpp, tensorrt, triton]. Examples: "pytorch, cpp" (does not run test stages with tensorrt or triton backend). Note: Does NOT update GitHub pipeline status.

--only-multi-gpu-test (OPTIONAL) : Only run the multi-GPU tests. Note: Does NOT update GitHub check status.

--disable-multi-gpu-test (OPTIONAL) : Disable the multi-GPU tests. Note: Does NOT update GitHub check status.

--add-multi-gpu-test (OPTIONAL) : Force run the multi-GPU tests in addition to running L0 pre-merge pipeline.

--post-merge (OPTIONAL) : Run the L0 post-merge pipeline instead of the ordinary L0 pre-merge pipeline.

--extra-stage "H100_PCIe-TensorRT-Post-Merge-1, xxx" (OPTIONAL) : Run the ordinary L0 pre-merge pipeline and specified test stages. Examples: --extra-stage "H100_PCIe-TensorRT-Post-Merge-1, xxx".

--detailed-log (OPTIONAL) : Enable flushing out all logs to the Jenkins console. This will significantly increase the log volume and may slow down the job.

--debug (OPTIONAL) : Experimental feature. Enable access to the CI container for debugging purpose. Note: Specify exactly one stage in the stage-list parameter to access the appropriate container environment. Note: Does NOT update GitHub check status.

For guidance on mapping tests to stage names, see docs/source/reference/ci-overview.md
and the scripts/test_to_stage_mapping.py helper.

kill

kill

Kill all running builds associated with pull request.

skip

skip --comment COMMENT

Skip testing for latest commit on pull request. --comment "Reason for skipping build/test" is required. IMPORTANT NOTE: This is dangerous since lack of user care and validation can cause top of tree to break.

reuse-pipeline

reuse-pipeline

Reuse a previous pipeline to validate current commit. This action will also kill all currently running builds associated with the pull request. IMPORTANT NOTE: This is dangerous since lack of user care and validation can cause top of tree to break.


coderabbitai bot commented Nov 8, 2025

📝 Walkthrough

Walkthrough

This PR introduces SM-disaggregated context processing by adding configuration structures, a separate context model engine, disaggregated schedulers, and context-aware request filtering to enable independent execution of the context and generation phases across multiple engines.

Changes

  • Configuration & API — tensorrt_llm/llmapi/llm_args.py, tensorrt_llm/llmapi/__init__.py, tensorrt_llm/_torch/virtual_memory.py: Added SmDisaggConfig model with context percentage, max tokens, and max batch size fields; added sm_disagg_config optional field to TorchLlmArgs with validation; exported SmDisaggConfig from the llmapi package; added MODEL_ENGINE_CTX enum member to ExecutorMemoryType.
  • Request Filtering & Utilities — tensorrt_llm/_torch/pyexecutor/llm_request.py: Added get_context_requests() and get_generation_requests() helper functions to filter requests by is_context_init_state.
  • Request Queue — tensorrt_llm/_torch/pyexecutor/executor_request_queue.py: Extended ExecutorRequestQueue with an is_sm_disagg flag; updated fetch_new_requests() signature to accept num_active_requests_on_engine; added an SM-disaggregation path via _fetch_new_requests_sm_disagg().
  • Scheduler — tensorrt_llm/_torch/pyexecutor/scheduler.py: Imported get_context_requests; introduced SmDisaggCtxScheduler class that composes capacity and micro-batch schedulers to filter and schedule context requests separately.
  • Model Engine & Loader — tensorrt_llm/_torch/pyexecutor/model_engine.py, tensorrt_llm/_torch/pyexecutor/model_loader.py: Added is_sm_disagg_ctx_phase and weight_sharing_model parameters to PyTorchModelEngine; added weight_sharing_model parameter to ModelLoader with state-dict loading logic; extended input preparation logic to gate first-step handling for SM-disaggregated phases.
  • Executor Creation & Initialization — tensorrt_llm/_torch/pyexecutor/_util.py, tensorrt_llm/_torch/pyexecutor/py_executor_creator.py: Extended create_py_executor_instance() signature with ctx_model_engine and sm_disagg_config; added conditional logic to configure SmDisaggCtxScheduler when disaggregation is enabled; added compatibility checks; propagated the context engine through model and executor construction.
  • Core Executor — tensorrt_llm/_torch/pyexecutor/py_executor.py: Added ctx_scheduler and ctx_model_engine parameters to PyExecutor; extended warmup and profiling to iterate over multiple engines; enhanced event loops with dedicated SM-disaggregation paths and synchronization primitives (locks, condition variables); updated _profiler(), _forward_step(), and _schedule() to accept optional engine/scheduler parameters; added phase-aware profiling output.
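For reference, the new helpers in llm_request.py are essentially thin filters over is_context_init_state; a minimal sketch of their shape (treating get_generation_requests as the complement is an assumption here):

from typing import List

def get_context_requests(requests: List["LlmRequest"]) -> List["LlmRequest"]:
    # Requests still in their context (prefill) phase.
    return [req for req in requests if req.is_context_init_state]

def get_generation_requests(requests: List["LlmRequest"]) -> List["LlmRequest"]:
    # Requests past the context phase, i.e. generating tokens.
    return [req for req in requests if not req.is_context_init_state]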

Sequence Diagram(s)

sequenceDiagram
    participant Client
    participant Executor as PyExecutor
    participant ReqQueue as RequestQueue
    participant MainSched as Scheduler
    participant SmSched as SmDisaggCtxScheduler
    participant CtxEngine as CtxModelEngine
    participant MainEngine as MainModelEngine

    Client->>Executor: enqueue_requests(requests)
    Executor->>ReqQueue: fetch_new_requests(num_active_requests_on_engine)
    
    alt SM Disaggregation Enabled
        ReqQueue->>ReqQueue: _fetch_new_requests_sm_disagg()
        ReqQueue->>MainSched: schedule_request()
        MainSched->>SmSched: filter & process via SmDisaggCtxScheduler
        SmSched->>SmSched: split context vs generation
        Note over SmSched: Context phase
        SmSched->>CtxEngine: forward context requests
        Note over SmSched: Generation phase
        SmSched->>MainEngine: forward generation requests
        SmSched->>ReqQueue: return merged results
    else Standard Path
        ReqQueue->>MainSched: schedule_request()
        MainSched->>MainEngine: forward all requests
    end

    ReqQueue-->>Executor: scheduled requests
    Executor-->>Client: await_responses()
sequenceDiagram
    participant Warmup
    participant Engine1 as ModelEngine
    participant Engine2 as CtxModelEngine
    participant Engine3 as DraftEngine
    participant Prof as Profiler

    Warmup->>Engine1: set_is_warmup(True)
    Warmup->>Engine2: set_is_warmup(True)
    Warmup->>Engine3: set_is_warmup(True)
    
    rect rgb(200, 220, 240)
    Note over Prof: Profile Main Engine
    Warmup->>Engine1: forward_step()
    Engine1->>Prof: record metrics
    end
    
    rect rgb(220, 240, 200)
    Note over Prof: Profile Context Engine
    Warmup->>Engine2: forward_step()
    Engine2->>Prof: record metrics
    end
    
    rect rgb(240, 220, 200)
    Note over Prof: Profile Draft Engine
    Warmup->>Engine3: forward_step()
    Engine3->>Prof: record metrics
    end

    Warmup->>Engine1: set_is_warmup(False)
    Warmup->>Engine2: set_is_warmup(False)
    Warmup->>Engine3: set_is_warmup(False)

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~60 minutes

Areas requiring extra attention:

  • py_executor.py: Substantial additions to executor initialization, warmup, and event loops with new synchronization primitives (locks, condition variables); careful review needed for deadlock risks and race conditions in context/generation coordination.
  • SmDisaggCtxScheduler: New scheduler composition logic that filters and splits requests; verify correct handling of context vs. generation request routing and state management.
  • py_executor_creator.py: Complex propagation of context engines through model construction and memory management; validate SM-disaggregation compatibility checks and weight-sharing initialization.
  • Multi-engine state management: Multiple forward paths now route through different engines (main, context, draft); ensure iter counters, iter_states, and profiling are correctly scoped per-engine.
  • Synchronization logic: New condition variables and locks for context/generation phase coordination; review for proper acquire/release patterns and deadlock avoidance.

Pre-merge checks and finishing touches

❌ Failed checks (1 warning)
  • Docstring Coverage ⚠️ Warning: Docstring coverage is 23.08%, which is below the required threshold of 80.00%. Resolution: run @coderabbitai generate docstrings to improve docstring coverage.
✅ Passed checks (2 passed)
  • Title check ✅ Passed: The title '[None][feat] Add SM-level disaggregation support' accurately describes the main feature being introduced and follows the repository's naming conventions.
  • Description check ✅ Passed: The PR description includes background context, design overview, performance results, and acknowledges limitations. However, the 'Test Coverage' section asks for suggestions rather than providing actual test cases.
✨ Finishing touches
  • 📝 Generate docstrings
🧪 Generate unit tests (beta)
  • Create PR with unit tests
  • Post copyable unit tests in a comment


Comment @coderabbitai help to get the list of available commands and usage tips.

coderabbitai bot left a comment

Actionable comments posted: 6

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (4)
tensorrt_llm/_torch/pyexecutor/executor_request_queue.py (1)

1-1: Add NVIDIA Apache-2.0 header (2025).

All source files must start with the NVIDIA Apache-2.0 header per coding guidelines.

Apply at top of file:

+ # Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
tensorrt_llm/llmapi/llm_args.py (2)

1-1: Add NVIDIA Apache-2.0 header (2025).

Please prepend the standard NVIDIA Apache-2.0 header.

+ # Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.

2843-2848: Message nit: match the check.

The validator allows 0; error message says “greater than 0”. Suggest “>= 0”.

-        if self.batch_wait_timeout_ms < 0:
-            raise ValueError("batch_wait_timeout_ms must be greater than 0")
+        if self.batch_wait_timeout_ms < 0:
+            raise ValueError("batch_wait_timeout_ms must be >= 0")
tensorrt_llm/_torch/pyexecutor/model_engine.py (1)

1-1: Add NVIDIA Apache-2.0 header (2025).

Please prepend the standard header.

+ # Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
🧹 Nitpick comments (6)
tensorrt_llm/_torch/pyexecutor/executor_request_queue.py (1)

53-55: SM-disagg wiring looks fine; document semantics.

Constructor and self.is_sm_disagg flag are clear. Please document in the class docstring what “SM-disaggregation mode” implies for fetch capacity and scheduling so future readers don’t confuse it with network disaggregation.

Also applies to: 64-64

tensorrt_llm/llmapi/llm_args.py (2)

318-335: SmDisaggConfig: clarify bounds and behavior in docstring.

Suggest noting valid range for context_sm_percent and that non-positive ctx limits inherit generation limits. No functional change needed.
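One possible shape for that docstring (illustrative only; the base class is shown just for context and should follow the existing conventions in llm_args.py, and the inherit-from-generation-limits wording mirrors the suggestion above):

from pydantic import BaseModel

class SmDisaggConfig(BaseModel):  # base class is an assumption; match llm_args.py
    """Configuration for SM-level disaggregation.

    Attributes:
        context_sm_percent: Fraction of SMs reserved for the context phase.
            Must be strictly between 0 and 1.
        context_max_num_tokens: Per-iteration token budget for the context
            scheduler. Non-positive values inherit the generation-side limit.
        context_max_batch_size: Batch-size cap for the context scheduler.
            Non-positive values inherit the generation-side limit.
    """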


2888-2900: Support YAML ingestion for sm_disagg_config.

Add SmDisaggConfig to field_mapping so dicts in extra_llm_api_options are parsed consistently (mirrors other nested configs).

     field_mapping = {
         "quant_config": QuantConfig,
         "calib_config": CalibConfig,
         "build_config": BuildConfig,
         "decoding_config": DecodingConfig,
         "enable_build_cache": BuildCacheConfig,
         "speculative_config": DecodingBaseConfig,
         "lora_config": LoraConfig,
         "moe_config": MoeConfig,
         "attention_dp_config": AttentionDpConfig,
         "sparse_attention_config": BaseSparseAttentionConfig,
+        "sm_disagg_config": SmDisaggConfig,
     }
tensorrt_llm/_torch/pyexecutor/model_engine.py (2)

1363-1371: Operator precedence: parenthesize ‘and’ for clarity (RUF021).

Make the new SM-disagg condition explicit.

-            if next_draft_tokens_device is None or request.is_dummy or request.py_batch_idx is None \
-                    or self.sm_disagg_enabled and request.max_num_generated_tokens == 0:
+            if (next_draft_tokens_device is None or request.is_dummy
+                    or request.py_batch_idx is None
+                    or (self.sm_disagg_enabled
+                        and request.max_num_generated_tokens == 0)):

1470-1472: Operator precedence: parenthesize ‘and’ for clarity (RUF021).

Same as above in generation path.

-                if new_tokens_device is None or request.is_dummy or request.py_batch_idx is None \
-                        or self.sm_disagg_enabled and request.max_num_generated_tokens == 0:
+                if (new_tokens_device is None or request.is_dummy
+                        or request.py_batch_idx is None
+                        or (self.sm_disagg_enabled
+                            and request.max_num_generated_tokens == 0)):
tensorrt_llm/_torch/pyexecutor/py_executor_creator.py (1)

352-363: Consider extracting validation logic to a helper function.

The validation checks are correct but could improve maintainability by extracting to a helper function like _validate_sm_disagg_config(llm_args). This would make the main flow cleaner and the validation logic more testable.
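A sketch of that extraction (the helper name and exact signature are suggestions; the checks mirror the ones added in this PR):

from tensorrt_llm.llmapi.llm_args import CapacitySchedulerPolicy  # import path per llm_args.py

def _validate_sm_disagg_config(llm_args, scheduler_config) -> None:
    """Fail fast on configurations that SM-level disaggregation does not support."""
    if llm_args.sm_disagg_config is None:
        return
    if llm_args.cache_transceiver_config is not None:
        raise ValueError(
            "SM-level disaggregation is not compatible with disaggregated serving.")
    if llm_args.parallel_config.world_size > 1:
        raise NotImplementedError(
            "SM-level disaggregation is not supported with parallelism.")
    if scheduler_config.capacity_scheduler_policy != CapacitySchedulerPolicy.GUARANTEED_NO_EVICT:
        raise NotImplementedError(
            "SM-level disaggregation is only supported with the guaranteed-no-evict scheduler policy."
        )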

📜 Review details

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 6ff82ea and 3e96074.

📒 Files selected for processing (11)
  • tensorrt_llm/_torch/pyexecutor/_util.py (5 hunks)
  • tensorrt_llm/_torch/pyexecutor/executor_request_queue.py (3 hunks)
  • tensorrt_llm/_torch/pyexecutor/llm_request.py (1 hunks)
  • tensorrt_llm/_torch/pyexecutor/model_engine.py (6 hunks)
  • tensorrt_llm/_torch/pyexecutor/model_loader.py (3 hunks)
  • tensorrt_llm/_torch/pyexecutor/py_executor.py (24 hunks)
  • tensorrt_llm/_torch/pyexecutor/py_executor_creator.py (7 hunks)
  • tensorrt_llm/_torch/pyexecutor/scheduler.py (2 hunks)
  • tensorrt_llm/_torch/virtual_memory.py (1 hunks)
  • tensorrt_llm/llmapi/__init__.py (2 hunks)
  • tensorrt_llm/llmapi/llm_args.py (3 hunks)
🧰 Additional context used
📓 Path-based instructions (3)
**/*.{h,hpp,hh,hxx,cpp,cxx,cc,cu,cuh,py}

📄 CodeRabbit inference engine (CODING_GUIDELINES.md)

Use only spaces, no tabs; indent with 4 spaces.

Files:

  • tensorrt_llm/_torch/virtual_memory.py
  • tensorrt_llm/_torch/pyexecutor/llm_request.py
  • tensorrt_llm/_torch/pyexecutor/_util.py
  • tensorrt_llm/_torch/pyexecutor/scheduler.py
  • tensorrt_llm/llmapi/__init__.py
  • tensorrt_llm/_torch/pyexecutor/py_executor.py
  • tensorrt_llm/_torch/pyexecutor/executor_request_queue.py
  • tensorrt_llm/_torch/pyexecutor/model_loader.py
  • tensorrt_llm/_torch/pyexecutor/model_engine.py
  • tensorrt_llm/llmapi/llm_args.py
  • tensorrt_llm/_torch/pyexecutor/py_executor_creator.py
**/*.py

📄 CodeRabbit inference engine (CODING_GUIDELINES.md)

**/*.py: Python code must target Python 3.8+.
Indent Python code with 4 spaces; do not use tabs.
Maintain module namespace when importing; prefer 'from package.subpackage import foo' then 'foo.SomeClass()' instead of importing the class directly.
Python filenames should be snake_case (e.g., some_file.py).
Python classes use PascalCase names.
Functions and methods use snake_case names.
Local variables use snake_case; prefix 'k' for variables that start with a number (e.g., k_99th_percentile).
Global variables use upper SNAKE_CASE prefixed with 'G' (e.g., G_MY_GLOBAL).
Constants use upper SNAKE_CASE (e.g., MY_CONSTANT).
Avoid shadowing variables from an outer scope.
Initialize all externally visible members of a class in the constructor.
Prefer docstrings for interfaces that may be used outside a file; comments for in-function or file-local interfaces.
Use Google-style docstrings for classes and functions (Sphinx-parsable).
Document attributes and variables inline so they render under the class/function docstring.
Avoid reflection when a simpler, explicit approach suffices (e.g., avoid dict(**locals()) patterns).
In try/except, catch the most specific exceptions possible.
For duck-typing try/except, keep the try body minimal and use else for the main logic.

Files:

  • tensorrt_llm/_torch/virtual_memory.py
  • tensorrt_llm/_torch/pyexecutor/llm_request.py
  • tensorrt_llm/_torch/pyexecutor/_util.py
  • tensorrt_llm/_torch/pyexecutor/scheduler.py
  • tensorrt_llm/llmapi/__init__.py
  • tensorrt_llm/_torch/pyexecutor/py_executor.py
  • tensorrt_llm/_torch/pyexecutor/executor_request_queue.py
  • tensorrt_llm/_torch/pyexecutor/model_loader.py
  • tensorrt_llm/_torch/pyexecutor/model_engine.py
  • tensorrt_llm/llmapi/llm_args.py
  • tensorrt_llm/_torch/pyexecutor/py_executor_creator.py
**/*.{cpp,cxx,cc,h,hpp,hh,hxx,cu,cuh,py}

📄 CodeRabbit inference engine (CODING_GUIDELINES.md)

Prepend the NVIDIA Apache-2.0 copyright header with current year to the top of all source files (e.g., .cpp, .h, .cu, .py).

Files:

  • tensorrt_llm/_torch/virtual_memory.py
  • tensorrt_llm/_torch/pyexecutor/llm_request.py
  • tensorrt_llm/_torch/pyexecutor/_util.py
  • tensorrt_llm/_torch/pyexecutor/scheduler.py
  • tensorrt_llm/llmapi/__init__.py
  • tensorrt_llm/_torch/pyexecutor/py_executor.py
  • tensorrt_llm/_torch/pyexecutor/executor_request_queue.py
  • tensorrt_llm/_torch/pyexecutor/model_loader.py
  • tensorrt_llm/_torch/pyexecutor/model_engine.py
  • tensorrt_llm/llmapi/llm_args.py
  • tensorrt_llm/_torch/pyexecutor/py_executor_creator.py
🧠 Learnings (9)
📚 Learning: 2025-07-22T09:22:14.726Z
Learnt from: yechank-nvidia
Repo: NVIDIA/TensorRT-LLM PR: 6254
File: tensorrt_llm/_torch/pyexecutor/model_engine.py:1201-1204
Timestamp: 2025-07-22T09:22:14.726Z
Learning: In TensorRT-LLM's multimodal processing pipeline, shared tensor recovery using `from_shared_tensor()` is only needed during the context phase. Generation requests reuse the already-recovered tensor data and only need to call `strip_for_generation()` to remove unnecessary multimodal data while preserving the recovered tensors. This avoids redundant tensor recovery operations during generation.

Applied to files:

  • tensorrt_llm/_torch/pyexecutor/llm_request.py
📚 Learning: 2025-08-26T09:37:10.463Z
Learnt from: jiaganc
Repo: NVIDIA/TensorRT-LLM PR: 7031
File: tensorrt_llm/bench/dataclasses/configuration.py:90-104
Timestamp: 2025-08-26T09:37:10.463Z
Learning: In TensorRT-LLM, the `get_pytorch_perf_config()` method returns `self.pytorch_config` which can contain default `cuda_graph_config` values, so `llm_args` may already have this config before the extra options processing.

Applied to files:

  • tensorrt_llm/_torch/pyexecutor/_util.py
  • tensorrt_llm/llmapi/__init__.py
  • tensorrt_llm/_torch/pyexecutor/py_executor.py
  • tensorrt_llm/llmapi/llm_args.py
  • tensorrt_llm/_torch/pyexecutor/py_executor_creator.py
📚 Learning: 2025-08-26T09:37:10.463Z
Learnt from: jiaganc
Repo: NVIDIA/TensorRT-LLM PR: 7031
File: tensorrt_llm/bench/dataclasses/configuration.py:90-104
Timestamp: 2025-08-26T09:37:10.463Z
Learning: In TensorRT-LLM's bench configuration, the `get_pytorch_perf_config()` method returns `self.pytorch_config` which is a Dict[str, Any] that can contain default values including `cuda_graph_config`, making the fallback `llm_args["cuda_graph_config"]` safe to use.

Applied to files:

  • tensorrt_llm/_torch/pyexecutor/_util.py
  • tensorrt_llm/llmapi/__init__.py
  • tensorrt_llm/_torch/pyexecutor/py_executor_creator.py
📚 Learning: 2025-08-14T15:38:01.771Z
Learnt from: MatthiasKohl
Repo: NVIDIA/TensorRT-LLM PR: 6904
File: cpp/tensorrt_llm/pybind/thop/bindings.cpp:55-57
Timestamp: 2025-08-14T15:38:01.771Z
Learning: In TensorRT-LLM Python bindings, tensor parameter collections like mla_tensor_params and spec_decoding_tensor_params are kept as required parameters without defaults to maintain API consistency, even when it might affect backward compatibility.

Applied to files:

  • tensorrt_llm/_torch/pyexecutor/_util.py
📚 Learning: 2025-09-29T15:14:28.503Z
Learnt from: amitz-nv
Repo: NVIDIA/TensorRT-LLM PR: 8063
File: tensorrt_llm/lora_manager.py:1080-1112
Timestamp: 2025-09-29T15:14:28.503Z
Learning: In tensorrt_llm/lora_manager.py, when calculating part_sizes for attn_qkv fused LoRA modules, the sizes are correctly multiplied by tp_size because model_config.num_heads and model_config.num_kv_heads are already divided by tp_size (per-TP-rank values), so multiplication is needed to get the original full concatenated dimension size. The interleave_fused_lora_weights_for_tp function provides proper validation.

Applied to files:

  • tensorrt_llm/_torch/pyexecutor/_util.py
📚 Learning: 2025-08-19T12:45:11.997Z
Learnt from: amitz-nv
Repo: NVIDIA/TensorRT-LLM PR: 7033
File: tensorrt_llm/_torch/pyexecutor/model_engine.py:0-0
Timestamp: 2025-08-19T12:45:11.997Z
Learning: In tensorrt_llm/_torch/pyexecutor/model_engine.py, DoRA (Delta Orthogonal Rank Adaptation) functionality was removed from the PyTorch flow to eliminate issues with inverted DoRA detection logic. The original is_dora condition was checking if scaling_vec_pointer == 0, which was potentially incorrect.

Applied to files:

  • tensorrt_llm/_torch/pyexecutor/py_executor.py
  • tensorrt_llm/_torch/pyexecutor/model_loader.py
  • tensorrt_llm/_torch/pyexecutor/model_engine.py
  • tensorrt_llm/_torch/pyexecutor/py_executor_creator.py
📚 Learning: 2025-08-28T10:22:02.288Z
Learnt from: ixlmar
Repo: NVIDIA/TensorRT-LLM PR: 7294
File: tensorrt_llm/_torch/pyexecutor/sampler.py:1191-1197
Timestamp: 2025-08-28T10:22:02.288Z
Learning: In tensorrt_llm/_torch/pyexecutor/sampler.py, the object identity comparison `softmax_req_indices is not group_req_indices_cuda` on line ~1191 is intentional and used as an optimization to determine whether to reuse an existing indexer or create a new one, based on which code path was taken during tensor assignment.

Applied to files:

  • tensorrt_llm/_torch/pyexecutor/executor_request_queue.py
📚 Learning: 2025-08-26T06:07:02.166Z
Learnt from: shaharmor98
Repo: NVIDIA/TensorRT-LLM PR: 7231
File: tensorrt_llm/_torch/pyexecutor/_util.py:504-509
Timestamp: 2025-08-26T06:07:02.166Z
Learning: In tensorrt_llm/_torch/pyexecutor/_util.py, when calling model_engine.set_lora_model_config(), pass model_binding_config.mlp_hidden_size directly without multiplying by mapping.tp_size, as the mlp_hidden_size from get_bindings_model_config() is already the per-TP rank value needed for LoRA weight packaging.

Applied to files:

  • tensorrt_llm/_torch/pyexecutor/model_loader.py
  • tensorrt_llm/_torch/pyexecutor/model_engine.py
  • tensorrt_llm/_torch/pyexecutor/py_executor_creator.py
📚 Learning: 2025-09-03T13:16:06.824Z
Learnt from: nvpohanh
Repo: NVIDIA/TensorRT-LLM PR: 7478
File: tensorrt_llm/_torch/models/modeling_llama.py:1315-1315
Timestamp: 2025-09-03T13:16:06.824Z
Learning: The Llama4VisionEncoder.load_weights method signature is `def load_weights(self, weights: Dict)` and should not be confused with Llama4ForConditionalGeneration.load_weights which has a different signature including weight_mapper parameter.

Applied to files:

  • tensorrt_llm/_torch/pyexecutor/model_loader.py
🧬 Code graph analysis (9)
tensorrt_llm/_torch/pyexecutor/_util.py (4)
tensorrt_llm/llmapi/llm_args.py (1)
  • SmDisaggConfig (318-338)
tensorrt_llm/_torch/pyexecutor/scheduler.py (4)
  • SimpleScheduler (198-218)
  • SmDisaggCtxScheduler (221-243)
  • BindCapacityScheduler (72-99)
  • BindMicroBatchScheduler (171-195)
tensorrt_llm/_torch/pyexecutor/model_engine.py (1)
  • PyTorchModelEngine (128-2549)
tensorrt_llm/mapping.py (1)
  • has_pp (254-255)
tensorrt_llm/_torch/pyexecutor/scheduler.py (2)
tensorrt_llm/_torch/pyexecutor/llm_request.py (2)
  • LlmRequest (423-643)
  • get_context_requests (802-803)
tensorrt_llm/_torch/pyexecutor/kv_cache_connector.py (1)
  • SchedulerOutput (76-81)
tensorrt_llm/llmapi/__init__.py (1)
tensorrt_llm/llmapi/llm_args.py (2)
  • SmDisaggConfig (318-338)
  • TorchLlmArgs (2427-2880)
tensorrt_llm/_torch/pyexecutor/py_executor.py (5)
tensorrt_llm/_torch/pyexecutor/llm_request.py (5)
  • get_context_requests (802-803)
  • get_draft_token_length (788-799)
  • get_generation_requests (806-807)
  • get (129-141)
  • LlmRequest (423-643)
tensorrt_llm/_torch/pyexecutor/scheduler.py (7)
  • RequestScheduler (44-55)
  • schedule_request (47-55)
  • schedule_request (61-69)
  • schedule_request (95-99)
  • schedule_request (112-153)
  • schedule_request (206-218)
  • schedule_request (229-243)
tensorrt_llm/_torch/pyexecutor/model_engine.py (3)
  • ModelEngine (69-92)
  • forward (76-84)
  • forward (2286-2387)
tensorrt_llm/_torch/pyexecutor/executor_request_queue.py (1)
  • fetch_new_requests (337-347)
tensorrt_llm/_torch/pyexecutor/resource_manager.py (4)
  • prepare_resources (81-82)
  • prepare_resources (407-447)
  • prepare_resources (1307-1310)
  • prepare_resources (1432-1448)
tensorrt_llm/_torch/pyexecutor/executor_request_queue.py (1)
tensorrt_llm/_torch/pyexecutor/llm_request.py (1)
  • LlmRequest (423-643)
tensorrt_llm/_torch/pyexecutor/model_loader.py (1)
examples/models/core/enc_dec/convert_checkpoint.py (1)
  • state_dict (1629-1630)
tensorrt_llm/_torch/pyexecutor/model_engine.py (2)
tensorrt_llm/_torch/pyexecutor/py_executor_creator.py (1)
  • drafting_loop_wrapper (394-400)
tensorrt_llm/_torch/pyexecutor/llm_request.py (1)
  • is_dummy (611-612)
tensorrt_llm/llmapi/llm_args.py (4)
tensorrt_llm/builder.py (1)
  • default (45-50)
tensorrt_llm/models/modeling_utils.py (3)
  • from_dict (253-263)
  • from_dict (325-334)
  • from_dict (487-492)
tests/unittest/api_stability/api_stability_core.py (3)
  • from_dict (116-123)
  • from_dict (172-178)
  • from_dict (319-328)
tensorrt_llm/mapping.py (1)
  • from_dict (314-315)
tensorrt_llm/_torch/pyexecutor/py_executor_creator.py (4)
tensorrt_llm/_torch/virtual_memory.py (1)
  • ExecutorMemoryType (70-82)
tensorrt_llm/llmapi/llm_args.py (4)
  • parallel_config (1775-1776)
  • world_size (372-373)
  • world_size (382-386)
  • CapacitySchedulerPolicy (1077-1083)
tensorrt_llm/_torch/pyexecutor/model_engine.py (1)
  • PyTorchModelEngine (128-2549)
tensorrt_llm/_torch/pyexecutor/cuda_graph_runner.py (1)
  • attn_metadata (131-132)
🪛 Ruff (0.14.3)
tensorrt_llm/_torch/pyexecutor/model_engine.py

1365-1365: Parenthesize a and b expressions when chaining and and or together, to make the precedence clear

Parenthesize the and subexpression

(RUF021)


1471-1471: Parenthesize a and b expressions when chaining and and or together, to make the precedence clear

Parenthesize the and subexpression

(RUF021)

tensorrt_llm/llmapi/llm_args.py

2833-2835: Avoid specifying long messages outside the exception class

(TRY003)

tensorrt_llm/_torch/pyexecutor/py_executor_creator.py

354-356: Avoid specifying long messages outside the exception class

(TRY003)

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: Pre-commit Check
🔇 Additional comments (16)
tensorrt_llm/_torch/pyexecutor/model_engine.py (2)

205-206: Weight sharing passthrough: LGTM.

Passing weight_sharing_model into ModelLoader aligns with shared-weights engines.


154-157: Review comment references incorrect line numbers; underlying concern is unfounded.

The review points to lines 154-157 and 175, which do not contain max_num_generated_tokens access. Lines 154-157 initialize SM disagg config values, and line 175 checks sm_disagg_enabled (a boolean flag). The actual max_num_generated_tokens access occurs at lines 1365 and 1471, where it is accessed directly without getattr().

Evidence shows the attribute is guaranteed to exist:

  • Accessed directly throughout the codebase (lines 1365, 1471, plus test cases at 367, 379) without errors or defensive checks
  • Inherited from C++ bindings (tensorrt_llm.bindings.internal.batch_manager.LlmRequest) where it is always initialized
  • Tests confirm consistent availability

The existing code pattern validates that no getattr() wrapper is needed.

Likely an incorrect or invalid review comment.

tensorrt_llm/llmapi/llm_args.py (1)

2464-2468: No action needed—SmDisaggConfig is already properly exported.

The verification confirms SmDisaggConfig is imported on line 17 and added to the __all__ export list on line 65, making it available for downstream code to import directly from tensorrt_llm/llmapi.

tensorrt_llm/_torch/pyexecutor/executor_request_queue.py (1)

338-348: All call sites verified—no issues found.

The single call site at py_executor.py:1819 correctly passes both required parameters (activate_requests and num_active_requests_on_engine) with matching types and order. No runtime errors will occur from missing parameters.

tensorrt_llm/_torch/pyexecutor/py_executor_creator.py (4)

81-104: LGTM! Memory monitoring support for context model is well-integrated.

The additions for MODEL_ENGINE_CTX follow the existing patterns and provide helpful tuning guidance consistent with the main model engine.


474-476: LGTM! Safe iteration with proper None checks.

The loop correctly checks for None before accessing engine attributes.


666-667: LGTM! Consistent propagation of SM-disagg parameters.

Both calls to create_py_executor_instance correctly pass ctx_model_engine and sm_disagg_config.

Also applies to: 725-726


696-702: LGTM! Proper resource cleanup with safe guards.

The loop correctly handles None engines and checks both attn_metadata and cuda_graph_runner.enabled before releasing resources.

tensorrt_llm/_torch/pyexecutor/py_executor.py (8)

50-51: LGTM! Clean imports and well-documented configuration.

The new imports and environment variable follow existing conventions and are clearly documented.

Also applies to: 61-63


124-125: LGTM! Proper initialization with correct synchronization primitives.

The new parameters, fields, and synchronization primitives (lock and condition variables) are correctly initialized. Using the same lock (sm_disagg_lock) for both condition variables is the right approach for coordinating the two executor loops.

Also applies to: 164-168, 186-187, 217-219


254-260: LGTM! Consistent multi-engine handling with proper guards.

The warmup, property setter, and cleanup operations correctly iterate over all engines with appropriate None checks.

Also applies to: 373-378, 473-474


537-647: LGTM! Well-designed profiler enhancements for multi-engine support.

The _profiler context manager is properly extended to support:

  • Multiple engines via model_engine parameter with safe default
  • Per-stream profiling via stream parameter
  • Phase identification via phase_name for clearer logs
  • Profiling toggle via enable_profiler parameter

The changes are backward compatible and follow the existing patterns.


1513-1694: Verify synchronization correctness under edge cases.

The synchronization between context and generation loops uses condition variables correctly for the common case:

  • Context waits on gen_request_cv when it has pending work but can't schedule (line 1555)
  • Generation waits on ctx_request_cv when it has no work (line 1641)
  • Each notifies the other after processing

However, verify behavior in these edge cases:

  1. Startup: When both loops start with no scheduled work, does generation block indefinitely on ctx_request_cv? The context loop will continue and eventually fetch requests, but there's a window where generation is blocked.

  2. Resource starvation: If context loop can't schedule due to resource constraints (line 1554 condition is true), and generation is also blocked, ensure one will eventually make progress.

Consider adding a timeout to the condition variable waits or additional logging to help diagnose potential hangs in production.

To verify the synchronization logic is sound, you could add assertion checks or logging:

# In context loop before waiting
if scheduled_batch.batch_size == 0 and (len(ctx_requests) > 0 or self.executor_request_queue.get_waiting_queue_size() > 0):
    logger.debug(f"Context loop waiting for generation to finish. Pending context requests: {len(ctx_requests)}")
    self.gen_request_cv.wait()

# In generation loop before waiting  
if scheduled_batch.batch_size == 0:
    logger.debug(f"Generation loop waiting for context to finish. Active requests: {len(self.active_requests)}")
    self.ctx_request_cv.wait()

1814-1820: LGTM! Correct capacity calculation for SM-disaggregation.

When ctx_model_engine is present, the calculation correctly counts only context requests since that's what's relevant for the context engine's capacity. This aligns with the SM-disaggregation design where context and generation phases have separate capacity limits.


1914-1922: LGTM! Flexible scheduling interface with safe defaults.

The optional scheduler and active_requests parameters enable disaggregated scheduling while maintaining backward compatibility through sensible defaults.


2154-2171: LGTM! Flexible forward step supporting multiple engines.

The optional model_engine parameter enables forwarding through either the main or context engine while preserving backward compatibility. The implementation correctly defaults to self.model_engine and uses the provided engine consistently throughout.

Comment on lines +349 to +366
    def _fetch_new_requests_sm_disagg(
            self, num_active_requests: int,
            num_active_requests_on_engine: int) -> List[LlmRequest]:
        """Handle SM-level disaggregation request fetching."""
        total_max_num_active_requests = (self.max_num_active_requests +
                                         num_active_requests -
                                         num_active_requests_on_engine)

        # fetch and process requests into waiting queue
        new_requests = self._fetch_and_process_requests(
            num_active_requests_on_engine,
            total_max_num_active_requests,
            enable_attention_dp=False)

        # Merge requests and add to active list
        merged_requests = self._merge_requests(new_requests)
        return merged_requests

coderabbitai bot commented Nov 8, 2025

⚠️ Potential issue | 🔴 Critical

🧩 Analysis chain

Bug: SM-disagg capacity calculation can oversubscribe total active requests.

As written, total_max_num_active_requests = max + num_active_requests - num_active_requests_on_engine allows this engine to claim more than the global remaining capacity if num_active_requests represents a cross-engine total. The safe pattern is:

  • remaining = max(0, max_num_active_requests - total_active_across_both_engines)
  • allow this engine to fetch at most ‘remaining’ new requests: set total_max_for_this_engine = current_on_this_engine + remaining

This keeps the combined active count bounded by max_num_active_requests.

Apply:

-    def _fetch_new_requests_sm_disagg(
-            self, num_active_requests: int,
-            num_active_requests_on_engine: int) -> List[LlmRequest]:
+    def _fetch_new_requests_sm_disagg(
+            self, num_active_requests: int,
+            num_active_requests_on_engine: int) -> List[LlmRequest]:
         """Handle SM-level disaggregation request fetching."""
-        total_max_num_active_requests = (self.max_num_active_requests +
-                                         num_active_requests -
-                                         num_active_requests_on_engine)
+        # num_active_requests is expected to be the cross-engine total when available.
+        # Compute remaining global slots and cap this engine accordingly.
+        remaining_global_slots = max(
+            0, self.max_num_active_requests - num_active_requests)
+        total_max_num_active_requests = (num_active_requests_on_engine +
+                                         remaining_global_slots)
 
         # fetch and process requests into waiting queue
         new_requests = self._fetch_and_process_requests(
-            num_active_requests_on_engine,
-            total_max_num_active_requests,
-            enable_attention_dp=False)
+            num_active_requests_on_engine,
+            total_max_num_active_requests,
+            enable_attention_dp=False)
 
         # Merge requests and add to active list
         merged_requests = self._merge_requests(new_requests)
         return merged_requests

If num_active_requests actually equals num_active_requests_on_engine at call sites today, this change is a no-op; it future-proofs the logic when passing cross-engine totals.

Confirm the intended meaning of num_active_requests (engine-local vs cross-engine) at call sites.


I'll verify this review comment by examining the call sites and parameter semantics to understand whether the capacity calculation bug exists.


🏁 Script executed:

#!/bin/bash
# Find all call sites of _fetch_new_requests_sm_disagg
echo "=== Call sites of _fetch_new_requests_sm_disagg ==="
rg -n "_fetch_new_requests_sm_disagg" --type=py -C 5

echo ""
echo "=== Full method implementation context ==="
# Get the file and surrounding context
fd -e py "executor_request_queue.py" --exec head -n 400 {} \; | tail -n 100

Length of output: 5320


I need to understand the semantics of activate_requests and who calls fetch_new_requests. Let me investigate further.


🏁 Script executed:

#!/bin/bash
# Find all callers of fetch_new_requests
echo "=== Callers of fetch_new_requests ==="
rg -n "\.fetch_new_requests\(" --type=py -B 3 -A 3

echo ""
echo "=== Class definition and __init__ ==="
rg -n "class ExecutorRequestQueue" --type=py -A 30

echo ""
echo "=== Method _fetch_and_process_requests signature ==="
rg -n "def _fetch_and_process_requests" --type=py -B 2 -A 15

Length of output: 6186


I need to understand the SM-disagg semantics and what get_context_requests does to confirm the bug:


🏁 Script executed:

#!/bin/bash
# Find get_context_requests definition and usage
echo "=== get_context_requests definition ==="
rg -n "def get_context_requests" --type=py -A 10

echo ""
echo "=== SM-disagg context and comments ==="
rg -n "sm_disagg\|SM.disagg" --type=py -B 2 -A 2

echo ""
echo "=== Complete method _fetch_new_requests_sm_disagg ==="
rg -n "_fetch_new_requests_sm_disagg" --type=py -A 25 | head -n 50

Length of output: 4249


Confirm the SM-disagg capacity calculation fix is necessary.

Based on verification:

  • num_active_requests passed to _fetch_new_requests_sm_disagg equals len(self.active_requests) — the total active requests across all phases.
  • num_active_requests_on_engine equals the count of context-phase requests only (subset obtained via get_context_requests).
  • The current formula max + num_active_requests - num_active_requests_on_engine can exceed the global limit. For example: with max=100, total_active=80, context_only=30, the formula yields 150.
  • The proposed fix correctly computes remaining global capacity and ensures the engine cannot claim more than its fair share.
    def _fetch_new_requests_sm_disagg(
            self, num_active_requests: int,
            num_active_requests_on_engine: int) -> List[LlmRequest]:
         """Handle SM-level disaggregation request fetching."""
-        total_max_num_active_requests = (self.max_num_active_requests +
-                                         num_active_requests -
-                                         num_active_requests_on_engine)
+        # num_active_requests is the total active across all phases.
+        # Compute remaining global slots and cap this engine accordingly.
+        remaining_global_slots = max(
+            0, self.max_num_active_requests - num_active_requests)
+        total_max_num_active_requests = (num_active_requests_on_engine +
+                                         remaining_global_slots)
 
         # fetch and process requests into waiting queue
         new_requests = self._fetch_and_process_requests(
             num_active_requests_on_engine,
             total_max_num_active_requests,
             enable_attention_dp=False)
 
         # Merge requests and add to active list
         merged_requests = self._merge_requests(new_requests)
         return merged_requests
🤖 Prompt for AI Agents
In tensorrt_llm/_torch/pyexecutor/executor_request_queue.py around lines 349 to
366, the current SM-disagg capacity calculation total_max_num_active_requests =
(self.max_num_active_requests + num_active_requests -
num_active_requests_on_engine) can exceed the global limit; replace it with a
computation of remaining global capacity: remaining =
self.max_num_active_requests - (num_active_requests -
num_active_requests_on_engine); clamp remaining to >=0; then set
total_max_num_active_requests = num_active_requests_on_engine + remaining so the
engine only claims up to the remaining global capacity and the global max is
never exceeded.

@pcastonguay (Collaborator) commented:

This looks like reasonable feedback.

coderabbitai bot replied:

@pcastonguay, great! The suggested fix correctly addresses the capacity oversubscription issue by computing the remaining global slots and ensuring the engine-specific limit respects the global maximum.

The key improvement is:

remaining_global_slots = max(0, self.max_num_active_requests - num_active_requests)
total_max_num_active_requests = num_active_requests_on_engine + remaining_global_slots

This ensures that even when num_active_requests represents the cross-engine total, the combined active count will never exceed max_num_active_requests. Let me know if you need any clarification on the implementation! 🐰

Comment on lines +139 to 145
        is_sm_disagg_ctx_phase: bool = False,
        is_draft_model: bool = False,
        drafting_loop_wrapper: Optional[Callable[[torch.nn.Module],
                                                  torch.nn.Module]] = None,
        model: Optional[torch.nn.Module] = None,
        weight_sharing_model: Optional[torch.nn.Module] = None,
    ):

🛠️ Refactor suggestion | 🟠 Major

Constructor additions look good; assert config presence when ctx-phase is enabled.

Prevent AttributeError if is_sm_disagg_ctx_phase=True but sm_disagg_config is None.

         spec_config: Optional["DecodingBaseConfig"] = None,
         is_sm_disagg_ctx_phase: bool = False,
         is_draft_model: bool = False,
@@
         ) = llm_args.get_runtime_sizes()
-        if is_sm_disagg_ctx_phase:
-            max_num_tokens = llm_args.sm_disagg_config.context_max_num_tokens
-            max_batch_size = llm_args.sm_disagg_config.context_max_batch_size
+        if is_sm_disagg_ctx_phase:
+            if llm_args.sm_disagg_config is None:
+                raise ValueError(
+                    "is_sm_disagg_ctx_phase=True requires sm_disagg_config"
+                )
+            max_num_tokens = llm_args.sm_disagg_config.context_max_num_tokens
+            max_batch_size = llm_args.sm_disagg_config.context_max_batch_size

Committable suggestion skipped: line range outside the PR's diff.

🤖 Prompt for AI Agents
In tensorrt_llm/_torch/pyexecutor/model_engine.py around lines 139 to 145, when
is_sm_disagg_ctx_phase=True the constructor may later access sm_disagg_config
and raise AttributeError if it's None; add an explicit check at construction
start that if is_sm_disagg_ctx_phase is True then sm_disagg_config is not None,
and raise a clear ValueError or use assert with a descriptive message indicating
sm_disagg_config is required for SM disaggregation context phase so callers get
an immediate, informative failure.

Comment on lines +312 to +316
        if self.weight_sharing_model is not None:
            model.load_state_dict(self.weight_sharing_model.state_dict(),
                                  assign=True)
            # Free up duplicate model weights allocated before weight sharing
            torch.cuda.empty_cache()

⚠️ Potential issue | 🔴 Critical

Keep shared weights on-device when assigning

state_dict() without keep_vars=True produces detached CPU tensors. With assign=True, those CPU tensors replace the module’s CUDA parameters, so this branch forces the newly built engine to run with CPU weights and immediately triggers device-mismatch failures instead of sharing memory. Please grab the on-device Parameter objects before assigning.

-            model.load_state_dict(self.weight_sharing_model.state_dict(),
-                                      assign=True)
+            shared_state = self.weight_sharing_model.state_dict(
+                keep_vars=True)
+            model.load_state_dict(shared_state, assign=True)
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
         if self.weight_sharing_model is not None:
-            model.load_state_dict(self.weight_sharing_model.state_dict(),
-                                  assign=True)
+            shared_state = self.weight_sharing_model.state_dict(
+                keep_vars=True)
+            model.load_state_dict(shared_state, assign=True)
             # Free up duplicate model weights allocated before weight sharing
             torch.cuda.empty_cache()
🤖 Prompt for AI Agents
In tensorrt_llm/_torch/pyexecutor/model_loader.py around lines 312 to 316, the
code calls self.weight_sharing_model.state_dict() which returns detached CPU
tensors and then uses assign=True, causing CPU tensors to replace CUDA
parameters; instead obtain the on-device Parameter objects by calling
state_dict(keep_vars=True) (or otherwise capture the
weight_sharing_model.parameters()/buffers as Variables on their current device)
and pass that mapping into model.load_state_dict(..., assign=True); ensure any
torch.cuda.empty_cache() call happens after assignment if needed.

Comment on lines +352 to +380
    if llm_args.sm_disagg_config is not None:
        if llm_args.cache_transceiver_config is not None:
            raise ValueError(
                "SM-level disaggregation is not compatible with disaggregated serving."
            )
        if llm_args.parallel_config.world_size > 1:
            raise NotImplementedError(
                "SM-level disaggregation is not supported with parallelism.")
        if scheduler_config.capacity_scheduler_policy != CapacitySchedulerPolicy.GUARANTEED_NO_EVICT:
            raise NotImplementedError(
                "SM-level disaggregation is only supported with guaranteed no evict scheduler policy."
            )

        with allocation_scope(ExecutorMemoryType.MODEL_ENGINE_CTX,
                              RestoreMode.PINNED):
            ctx_llm_args = copy.copy(llm_args)
            ctx_llm_args.cuda_graph_config = None
            ctx_model_engine = PyTorchModelEngine(
                model_path=checkpoint_dir,
                llm_args=ctx_llm_args,
                mapping=mapping,
                attn_runtime_features=attn_runtime_features,
                dist=dist,
                spec_config=spec_config,
                weight_sharing_model=model_engine.model,
            )
    else:
        ctx_model_engine = None


⚠️ Potential issue | 🔴 Critical

Critical: Missing is_sm_disagg_ctx_phase=True flag and potential config sharing issue.

Two critical issues:

  1. Missing phase flag: Line 369-376 should pass is_sm_disagg_ctx_phase=True to PyTorchModelEngine. According to model_engine.py (lines 127-210 in snippets), this flag controls whether to use sm_disagg_config.context_max_num_tokens and sm_disagg_config.context_max_batch_size. Without it, the context model will use the wrong capacity limits.

  2. Shallow copy risk: Line 368 uses copy.copy(llm_args) which creates a shallow copy. If llm_args contains mutable objects (e.g., nested configs), modifications to those objects will affect both llm_args and ctx_llm_args. While setting cuda_graph_config = None on line 369 is safe if it replaces a reference, other mutable fields could be problematic.

Apply this diff:

             ctx_llm_args = copy.copy(llm_args)
             ctx_llm_args.cuda_graph_config = None
             ctx_model_engine = PyTorchModelEngine(
                 model_path=checkpoint_dir,
                 llm_args=ctx_llm_args,
                 mapping=mapping,
                 attn_runtime_features=attn_runtime_features,
                 dist=dist,
                 spec_config=spec_config,
+                is_sm_disagg_ctx_phase=True,
                 weight_sharing_model=model_engine.model,
             )

Consider using copy.deepcopy(llm_args) instead of copy.copy(llm_args) at line 368 to avoid unintended config sharing between the main and context engines.

🧰 Tools
🪛 Ruff (0.14.3)

354-356: Avoid specifying long messages outside the exception class

(TRY003)

🤖 Prompt for AI Agents
tensorrt_llm/_torch/pyexecutor/py_executor_creator.py around lines 352 to 380:
the context model engine is created without the is_sm_disagg_ctx_phase=True flag
and uses a shallow copy of llm_args; change copy.copy(llm_args) to
copy.deepcopy(llm_args) to avoid shared mutable state, set
ctx_llm_args.cuda_graph_config = None as before, and pass
is_sm_disagg_ctx_phase=True into the PyTorchModelEngine constructor so the
context engine uses sm_disagg_config context limits.

Comment on lines +1707 to +1741
def split_device_green_ctx(self):
device = torch.device("cuda", self.device_id)
device_properties = torch.cuda.get_device_properties(device)
sm_count = device_properties.multi_processor_count
if device_properties.major >= 9:
sm_min = 8
sm_align = 8
else:
sm_min = 4 if device_properties.major == 8 else 2
sm_align = 2

from flashinfer import green_ctx

def split_device_green_ctx_aligned(sm_s1):
sm_s1 = round(sm_s1 / sm_align) * sm_align
sm_s1 = min(max(sm_s1, sm_min), sm_count - sm_min)
return green_ctx.split_device_green_ctx_by_sm_count(device, [sm_s1])

sm_ctx = round(sm_count * self.sm_disagg_ctx_sm_percent)
sm_gen = sm_count - sm_ctx
# Choose the split closer to user-specified percentage when sm_count is not divisible by sm_align
sm_ctx_dist = min(sm_ctx % sm_align, sm_align - (sm_ctx % sm_align))
sm_gen_dist = min(sm_gen % sm_align, sm_align - (sm_gen % sm_align))
if sm_gen_dist < sm_ctx_dist:
(stream_gen,
stream_ctx), (res_gen,
res_ctx) = split_device_green_ctx_aligned(sm_gen)
else:
(stream_ctx,
stream_gen), (res_ctx,
res_gen) = split_device_green_ctx_aligned(sm_ctx)
logger.info(
f"Green contexts allocated {res_ctx.sm.smCount} SMs for context phase and {res_gen.sm.smCount} SMs for generation phase."
)
return stream_ctx, stream_gen
Contributor

🛠️ Refactor suggestion | 🟠 Major

Move flashinfer import to module level for consistency and early failure detection.

Line 1718 imports flashinfer.green_ctx inside the split_device_green_ctx method. This is problematic because:

  1. Late failure: If flashinfer is not available or incompatible, the error won't surface until this method is called during execution, making debugging harder.
  2. Inconsistent style: Other imports in this file are at the module level.
  3. Repeated import overhead: Although Python caches imports, it's better to import once at module level.

Move the import to the top of the file:

 from tensorrt_llm._torch.pyexecutor.resource_manager import (
     ResourceManagerType, request_context)
+try:
+    from flashinfer import green_ctx
+except ImportError:
+    green_ctx = None
 from tensorrt_llm._utils import (customized_gc_thresholds, is_trace_enabled,
                                  mpi_disabled, nvtx_range, trace_func)

Then add a validation check in __init__ or at the start of split_device_green_ctx:

def split_device_green_ctx(self):
    if green_ctx is None:
        raise ImportError("flashinfer is required for SM-level disaggregation but is not available")
    device = torch.device("cuda", self.device_id)
    # ... rest of the method
🤖 Prompt for AI Agents
tensorrt_llm/_torch/pyexecutor/py_executor.py around lines 1707 to 1741: the
function currently imports flashinfer.green_ctx inside split_device_green_ctx
causing late failures and inconsistent style; move "from flashinfer import
green_ctx" to the module-level imports at top of the file, then add a validation
(in __init__ or at the start of split_device_green_ctx) that raises ImportError
with a clear message if green_ctx is None or missing so callers fail fast;
ensure no other behavior changes and remove the local import from this method.
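For what it's worth, the alignment math in the snippet above reduces to a few lines of pure Python. This sketch only models the rounding and clamping (it always aligns the context share, whereas the PR additionally aligns whichever side lands closer to the requested percentage, and it creates no green contexts):

def aligned_sm_split(sm_count: int, ctx_percent: float,
                     sm_align: int, sm_min: int) -> tuple:
    """Round the requested context share to the alignment granularity and keep
    both partitions at or above sm_min. Returns (ctx_sms, gen_sms)."""
    sm_ctx = round(sm_count * ctx_percent)
    sm_ctx = round(sm_ctx / sm_align) * sm_align
    sm_ctx = min(max(sm_ctx, sm_min), sm_count - sm_min)
    return sm_ctx, sm_count - sm_ctx

# Example: a 132-SM device with 30% of SMs requested for the context phase,
# using the Hopper-and-newer thresholds from the snippet (align=8, min=8).
print(aligned_sm_split(132, 0.3, sm_align=8, sm_min=8))  # (40, 92)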

Comment on lines +2825 to +2841
@model_validator(mode='after')
def validate_and_sync_sm_disagg_config(self) -> 'TorchLlmArgs':
"""Validate SM-level disaggregation configuration."""
if self.sm_disagg_config is None:
return self

config = self.sm_disagg_config
if not 0 < config.context_sm_percent < 1:
raise ValueError(
"sm_disagg_config.context_sm_percent must be in the range (0, 1)"
)
if config.context_max_num_tokens <= 0:
config.context_max_num_tokens = self.max_num_tokens
if config.context_max_batch_size <= 0:
config.context_max_batch_size = self.max_batch_size
return self

Contributor

⚠️ Potential issue | 🟠 Major

Allow 1.0 for context_sm_percent and clamp ctx limits to runtime maxima.

  • The current check excludes 1.0; allowing full SM allocation to context is reasonable.
  • Also clamp context_max_num_tokens/batch_size to runtime maxima to avoid illegal overprovisioning.

Apply:

     def validate_and_sync_sm_disagg_config(self) -> 'TorchLlmArgs':
         """Validate SM-level disaggregation configuration."""
         if self.sm_disagg_config is None:
             return self
 
         config = self.sm_disagg_config
-        if not 0 < config.context_sm_percent < 1:
+        if not (0 < config.context_sm_percent <= 1):
             raise ValueError(
-                "sm_disagg_config.context_sm_percent must be in the range (0, 1)"
+                "sm_disagg_config.context_sm_percent must be in the range (0, 1]"
             )
         if config.context_max_num_tokens <= 0:
             config.context_max_num_tokens = self.max_num_tokens
         if config.context_max_batch_size <= 0:
             config.context_max_batch_size = self.max_batch_size
+        # Clamp to runtime maxima if provided
+        if self.max_num_tokens is not None and config.context_max_num_tokens is not None:
+            config.context_max_num_tokens = min(config.context_max_num_tokens,
+                                                self.max_num_tokens)
+        if self.max_batch_size is not None and config.context_max_batch_size is not None:
+            config.context_max_batch_size = min(config.context_max_batch_size,
+                                                self.max_batch_size)
         return self
🧰 Tools
🪛 Ruff (0.14.3)

2833-2835: Avoid specifying long messages outside the exception class

(TRY003)

🤖 Prompt for AI Agents
tensorrt_llm/llmapi/llm_args.py lines 2825-2841: update the validator to allow
context_sm_percent == 1.0 (change the range check to require 0 <
context_sm_percent <= 1) and ensure context_max_num_tokens and
context_max_batch_size are clamped to runtime maxima — if each is <= 0 set it to
self.max_num_tokens / self.max_batch_size respectively, otherwise set it to
min(their value, self.max_num_tokens/self.max_batch_size) to prevent
overprovisioning.
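A standalone sketch of the normalization the suggested diff performs, with the clamping made explicit (the helper itself is hypothetical; only the field semantics are taken from the snippet above):

from typing import Optional, Tuple

def normalize_ctx_limits(context_sm_percent: float,
                         context_max_num_tokens: int,
                         context_max_batch_size: int,
                         max_num_tokens: Optional[int],
                         max_batch_size: Optional[int]
                         ) -> Tuple[Optional[int], Optional[int]]:
    """Validate the context SM share and fill/clamp the context-phase limits."""
    if not (0 < context_sm_percent <= 1):
        raise ValueError("context_sm_percent must be in the range (0, 1]")
    # Non-positive values mean "inherit the runtime maximum".
    if context_max_num_tokens <= 0:
        context_max_num_tokens = max_num_tokens
    if context_max_batch_size <= 0:
        context_max_batch_size = max_batch_size
    # Never let the context phase exceed the runtime maxima.
    if max_num_tokens is not None and context_max_num_tokens is not None:
        context_max_num_tokens = min(context_max_num_tokens, max_num_tokens)
    if max_batch_size is not None and context_max_batch_size is not None:
        context_max_batch_size = min(context_max_batch_size, max_batch_size)
    return context_max_num_tokens, context_max_batch_size

# 0 means "inherit"; oversized values are clamped to the runtime maxima.
assert normalize_ctx_limits(0.5, 0, 0, 8192, 256) == (8192, 256)
assert normalize_ctx_limits(1.0, 100_000, 512, 8192, 256) == (8192, 256)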

@qiangxu1996
Author

/bot run

@tensorrt-cicd
Collaborator

PR_Github #23881 [ run ] triggered by Bot. Commit: 3e96074

@chang-l
Collaborator

chang-l commented Nov 8, 2025

Does this PR support dynamically allocating SM resources between the ctx and gen engines based on the per-batch workload? If so, is there a cost model used to determine the SM partitioning?

@qiangxu1996
Author

Does this PR support dynamically allocating SM resources between the ctx and gen engines based on the per-batch workload? If so, is there a cost model used to determine the SM partitioning?

No. A fixed SM ratio is passed through llm args. The fixed ratio seems to be good enough for the workloads we looked into (including OpenOrca, which has varying ISL/OSL per request), so dynamic SM allocation is not planned unless there's a compelling use case.

@tensorrt-cicd
Collaborator

PR_Github #23881 [ run ] completed with state SUCCESS. Commit: 3e96074
/LLM/main/L0_MergeRequest_PR pipeline #17977 completed with status: 'FAILURE'

sm_min = 4 if device_properties.major == 8 else 2
sm_align = 2

from flashinfer import green_ctx
Collaborator

Though TRT-LLM already integrates FlashInfer, would it be better to integrate directly with the low-level green context API provided by CUDA, to reduce the level of abstraction here?


self._kv_connector_terminate_requests()

def _executor_loop_sm_disagg_ctx(self, stream):
Collaborator

There is a large amount of duplicated code from the main executor loop. Could we refactor to avoid so much duplication? This duplicated code makes it very hard to maintain.

iter_stats=iter_stats,
iter_start_time=iter_start_time))

def _executor_loop_sm_disagg_gen_overlap(self, stream):
Collaborator

There is a large amount of duplicated code from the main executor loop with overlap. Could we refactor to avoid so much duplication? This duplicated code makes it very hard to maintain.
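One possible shape for that refactor, sketched with hypothetical names rather than the PR's actual methods: extract the per-iteration body into a helper parameterized by engine, stream, and scheduling callbacks, so the main, context, and generation loops differ only in what they pass in:

# Hypothetical sketch (not the PR's actual structure): the body shared by
# _executor_loop, _executor_loop_sm_disagg_ctx and _executor_loop_sm_disagg_gen_overlap
# lives in one place, and each loop only supplies its phase-specific pieces.
from dataclasses import dataclass
from typing import Any, Callable, List

@dataclass
class LoopPhase:
    name: str                               # "ctx" or "gen"
    engine: Any                             # model engine used by this phase
    stream: Any                             # green-context stream for this phase
    fetch_requests: Callable[[], List[Any]]
    schedule: Callable[[List[Any]], List[Any]]

def run_executor_loop(phase: LoopPhase, should_stop: Callable[[], bool]) -> None:
    """Shared per-iteration body: fetch, schedule, then forward on the phase stream."""
    while not should_stop():
        requests = phase.fetch_requests()
        batch = phase.schedule(requests)
        if batch:
            phase.engine.forward(batch, stream=phase.stream)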

else:
return self._fetch_new_requests_attention_tp(len(activate_requests))

def _fetch_new_requests_sm_disagg(
Collaborator

We should add unit tests for this new function.

new_requests = self._fetch_and_process_requests(
num_active_requests_on_engine,
total_max_num_active_requests,
enable_attention_dp=False)
Collaborator

Why are we simply setting this to False? If we don't support attention dp, we should raise an error.

Comment on lines +349 to +366
def _fetch_new_requests_sm_disagg(
self, num_active_requests: int,
num_active_requests_on_engine: int) -> List[LlmRequest]:
"""Handle SM-level disaggregation request fetching."""
total_max_num_active_requests = (self.max_num_active_requests +
num_active_requests -
num_active_requests_on_engine)

# fetch and process requests into waiting queue
new_requests = self._fetch_and_process_requests(
num_active_requests_on_engine,
total_max_num_active_requests,
enable_attention_dp=False)

# Merge requests and add to active list
merged_requests = self._merge_requests(new_requests)
return merged_requests

Collaborator

This looks like reasonable feedback.
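A possible shape for such a unit test, exercising only the capacity arithmetic through a mock executor (the PyExecutor class name and import path are assumed from the file path above, and the request placeholders are hypothetical):

from unittest.mock import MagicMock

from tensorrt_llm._torch.pyexecutor.py_executor import PyExecutor

def test_fetch_new_requests_sm_disagg_passes_global_headroom():
    executor = MagicMock()
    executor.max_num_active_requests = 8
    executor._fetch_and_process_requests.return_value = ["req_a", "req_b"]
    executor._merge_requests.side_effect = lambda reqs: reqs

    # 10 active requests in total, 4 of them on this engine: the headroom handed
    # to _fetch_and_process_requests should be 8 + 10 - 4 = 14.
    merged = PyExecutor._fetch_new_requests_sm_disagg(
        executor, num_active_requests=10, num_active_requests_on_engine=4)

    executor._fetch_and_process_requests.assert_called_once_with(
        4, 14, enable_attention_dp=False)
    assert merged == ["req_a", "req_b"]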

Comment on lines +217 to +219
self.sm_disagg_lock = threading.Lock()
self.ctx_request_cv = threading.Condition(self.sm_disagg_lock)
self.gen_request_cv = threading.Condition(self.sm_disagg_lock)
Collaborator

Should those get initialized only if ctx_model_engine is not None?

Collaborator

We will need to add integration/unit tests to validate the new code.
