25 changes: 23 additions & 2 deletions tensorrt_llm/_torch/pyexecutor/_util.py
@@ -14,7 +14,7 @@
EagleDecodingConfig, KvCacheConfig,
MTPDecodingConfig, PeftCacheConfig,
SamplerType, SchedulerConfig,
SparseAttentionConfig,
SmDisaggConfig, SparseAttentionConfig,
SpeculativeConfig, TorchLlmArgs)
from tensorrt_llm.logger import logger
from tensorrt_llm.lora_helper import (LoraConfig,
@@ -29,7 +29,7 @@
from .guided_decoder import GuidedDecoder
from .kv_cache_connector import KvCacheConnectorManager
from .kv_cache_transceiver import AttentionTypeCpp, create_kv_cache_transceiver
from .llm_request import ExecutorResponse
from .llm_request import ExecutorResponse, LlmRequestState
from .mamba_cache_manager import MambaHybridCacheManager
from .model_engine import PyTorchModelEngine
from .py_executor import PyExecutor
@@ -665,6 +665,8 @@ def create_py_executor_instance(
max_batch_size: Optional[int] = None,
max_beam_width: Optional[int] = None,
max_num_tokens: Optional[int] = None,
ctx_model_engine: Optional[PyTorchModelEngine] = None,
sm_disagg_config: Optional[SmDisaggConfig] = None,
peft_cache_config: Optional[PeftCacheConfig] = None,
scheduler_config: Optional[SchedulerConfig] = None,
cache_transceiver_config: Optional[CacheTransceiverConfig] = None,
@@ -789,6 +791,23 @@
ctx_chunk_config)
scheduler = SimpleScheduler(capacity_scheduler, mb_scheduler)

if sm_disagg_config is not None:
scheduler_capacity += sm_disagg_config.context_max_batch_size * mapping.pp_size
capacity_scheduler = BindCapacityScheduler(
scheduler_capacity,
kv_cache_manager.impl if kv_cache_manager is not None else None,
peft_cache_manager.impl if peft_cache_manager is not None else None,
scheduler_config.capacity_scheduler_policy,
two_step_lookahead=mapping.has_pp())
mb_scheduler = BindMicroBatchScheduler(
sm_disagg_config.context_max_batch_size,
sm_disagg_config.context_max_num_tokens,
ctx_chunk_config,
no_schedule_after_state=LlmRequestState.GENERATION_IN_PROGRESS)
ctx_scheduler = SimpleScheduler(capacity_scheduler, mb_scheduler)
else:
ctx_scheduler = None

config = model_engine.model.model_config.pretrained_config
attention_type = AttentionTypeCpp.MLA if is_mla(
config) else AttentionTypeCpp.DEFAULT
@@ -801,6 +820,8 @@
model_engine=model_engine,
sampler=sampler,
drafter=drafter,
ctx_scheduler=ctx_scheduler,
ctx_model_engine=ctx_model_engine,
dist=dist,
max_num_sequences=max_num_sequences,
disable_overlap_scheduler=llm_args.disable_overlap_scheduler,
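For orientation, here is a minimal sketch of the SmDisaggConfig fields this hunk reads when it builds the separate context-phase scheduler. The field names come straight from the diff above; the dataclass form and any other details of the real class (which is defined alongside the other LLM-args config types) are only assumptions for illustration.

from dataclasses import dataclass

@dataclass
class SmDisaggConfigSketch:  # hypothetical stand-in for the real SmDisaggConfig
    context_max_batch_size: int   # caps the context-phase capacity/micro-batch schedulers
    context_max_num_tokens: int   # caps tokens per context-phase micro-batch
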
21 changes: 15 additions & 6 deletions tensorrt_llm/_torch/pyexecutor/executor_request_queue.py
@@ -47,10 +47,15 @@ def is_control_request(self):
class ExecutorRequestQueue:
"""Handles fetching and processing of new requests from the request queue."""

def __init__(self, dist: Distributed, enable_attention_dp: bool,
max_batch_size: int, max_beam_width: int,
max_num_active_requests: int, enable_iter_perf_stats: bool,
batch_wait_timeout_ms: float):
def __init__(self,
dist: Distributed,
enable_attention_dp: bool,
max_batch_size: int,
max_beam_width: int,
max_num_active_requests: int,
enable_iter_perf_stats: bool,
batch_wait_timeout_ms: float,
is_sm_disagg: bool = False):
self.dist = dist
self.request_queue: queue.Queue[RequestQueueItem] = queue.Queue()
self.waiting_queue: deque[RequestQueueItem] = deque()
@@ -59,6 +64,7 @@ def __init__(self, dist: Distributed, enable_attention_dp: bool,
self.max_batch_size = max_batch_size
self.max_beam_width = max_beam_width
self.max_num_active_requests = max_num_active_requests
self.is_sm_disagg = is_sm_disagg
self.enqueue_lock = threading.Lock()
self.next_request_id = max_batch_size
self.enable_iter_perf_stats = enable_iter_perf_stats
@@ -333,12 +339,15 @@ def _fetch_and_process_requests(

@nvtx_range("_fetch_new_requests")
def fetch_new_requests(
self, activate_requests: List[LlmRequest]) -> List[LlmRequest]:
self, activate_requests: List[LlmRequest],
num_active_requests_on_engine: int) -> List[LlmRequest]:

if self.enable_attention_dp:
return self._fetch_new_requests_attention_dp(activate_requests)
else:
return self._fetch_new_requests_attention_tp(len(activate_requests))
num_active_requests = num_active_requests_on_engine if self.is_sm_disagg else len(
activate_requests)
return self._fetch_new_requests_attention_tp(num_active_requests)

def _fetch_new_requests_attention_tp(
self, num_active_requests: int) -> List[LlmRequest]:
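A hedged sketch of how a caller might use the widened signature; the variable names and the surrounding loop are illustrative, not the real PyExecutor code. The point of the extra argument is that under SM-level disaggregation the requests still being prefilled by the context engine also occupy engine capacity, so the queue is told the engine-wide count instead of inferring it from the generation-side list alone.

# Illustrative caller (assumed names), not the actual executor loop:
gen_requests = [r for r in active_requests if not r.is_context_init_state]
new_requests = request_queue.fetch_new_requests(
    gen_requests,
    num_active_requests_on_engine=len(active_requests),  # context + generation
)
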
82 changes: 82 additions & 0 deletions tensorrt_llm/_torch/pyexecutor/green_ctx.py
@@ -0,0 +1,82 @@
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import torch
from cuda.bindings import driver

from tensorrt_llm.runtime.generation import CUASSERT


def green_ctx_create_streams(res_list, device):
streams = []
for res in res_list:
desc = CUASSERT(driver.cuDevResourceGenerateDesc([res], 1))[0]
green_ctx = CUASSERT(
driver.cuGreenCtxCreate(
desc, device, driver.CUgreenCtxCreate_flags.CU_GREEN_CTX_DEFAULT_STREAM
)
)[0]
stream = CUASSERT(
driver.cuGreenCtxStreamCreate(
green_ctx, driver.CUstream_flags.CU_STREAM_NON_BLOCKING, 0
)
)[0]
stream = torch.cuda.get_stream_from_external(stream, device)
streams.append(stream)
return streams


def green_ctx_split_percent(sm_percent: float, device_id: int = 0):
device = CUASSERT(driver.cuDeviceGet(device_id))[0]

res = CUASSERT(
driver.cuDeviceGetDevResource(device, driver.CUdevResourceType.CU_DEV_RESOURCE_TYPE_SM)
)[0]
sm_count = res.sm.smCount

major = CUASSERT(
driver.cuDeviceGetAttribute(
driver.CUdevice_attribute.CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MAJOR, device
)
)[0]
if major >= 9:
sm_min = 8
sm_align = 8
else:
sm_min = 4 if major == 8 else 2
sm_align = 2

def green_ctx_split_aligned(sm_g1):
sm_g1 = round(sm_g1 / sm_align) * sm_align
sm_g1 = min(max(sm_g1, sm_min), sm_count - sm_min)
result = CUASSERT(
driver.cuDevSmResourceSplitByCount(
1, # nbGroups
res,
0, # useFlags
sm_g1,
)
)
res_split = (result[0][0], result[2])
streams = green_ctx_create_streams(res_split, device)
return streams, res_split

sm_g1 = round(sm_count * sm_percent)
sm_g2 = sm_count - sm_g1
# Choose the split closer to sm_percent when sm_count is not divisible by sm_align
sm_g1_dist = min(sm_g1 % sm_align, sm_align - (sm_g1 % sm_align))
sm_g2_dist = min(sm_g2 % sm_align, sm_align - (sm_g2 % sm_align))
if sm_g1_dist <= sm_g2_dist:
(stream_g1, stream_g2), (res_g1, res_g2) = green_ctx_split_aligned(sm_g1)
else:
(stream_g2, stream_g1), (res_g2, res_g1) = green_ctx_split_aligned(sm_g2)
return (stream_g1, stream_g2), (res_g1, res_g2)
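A usage sketch for the new helper, assuming the module path introduced by this PR; the 30/70 split and the two kernels are arbitrary, and the smCount reads mirror how the helper itself inspects SM resources. Each returned stream is bound to a green context, so work launched on it stays within that SM partition.

import torch

from tensorrt_llm._torch.pyexecutor.green_ctx import green_ctx_split_percent

# Give ~30% of the device's SMs to the first partition (e.g. context phase)
# and the rest to the second (e.g. generation phase).
(ctx_stream, gen_stream), (ctx_res, gen_res) = green_ctx_split_percent(0.3, device_id=0)
print(f"context SMs: {ctx_res.sm.smCount}, generation SMs: {gen_res.sm.smCount}")

x = torch.randn(4096, 4096, device="cuda:0")
with torch.cuda.stream(ctx_stream):   # kernels here run only on the first SM group
    y = x @ x
with torch.cuda.stream(gen_stream):   # kernels here run only on the second SM group
    z = x + 1.0
torch.cuda.synchronize()
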
8 changes: 8 additions & 0 deletions tensorrt_llm/_torch/pyexecutor/llm_request.py
@@ -797,3 +797,11 @@ def get_draft_token_length(request: LlmRequest) -> int:
if request.py_draft_tokens is not None:
return len(request.py_draft_tokens)
return 0


def get_context_requests(requests: List[LlmRequest]):
return [req for req in requests if req.is_context_init_state]


def get_generation_requests(requests: List[LlmRequest]):
return [req for req in requests if not req.is_context_init_state]
19 changes: 15 additions & 4 deletions tensorrt_llm/_torch/pyexecutor/model_engine.py
@@ -135,10 +135,12 @@ def __init__(
attn_runtime_features: Optional[AttentionRuntimeFeatures] = None,
dist: Optional[MPIDist] = None,
spec_config: Optional["DecodingBaseConfig"] = None,
is_sm_disagg_ctx_phase: bool = False,
is_draft_model: bool = False,
drafting_loop_wrapper: Optional[Callable[[torch.nn.Module],
torch.nn.Module]] = None,
model: Optional[torch.nn.Module] = None,
weight_sharing_model: Optional[torch.nn.Module] = None,
):
Comment on lines +138 to 144

🛠️ Refactor suggestion | 🟠 Major

Constructor additions look good; assert config presence when ctx-phase is enabled.

Prevent AttributeError if is_sm_disagg_ctx_phase=True but sm_disagg_config is None.

         spec_config: Optional["DecodingBaseConfig"] = None,
         is_sm_disagg_ctx_phase: bool = False,
         is_draft_model: bool = False,
@@
         ) = llm_args.get_runtime_sizes()
-        if is_sm_disagg_ctx_phase:
-            max_num_tokens = llm_args.sm_disagg_config.context_max_num_tokens
-            max_batch_size = llm_args.sm_disagg_config.context_max_batch_size
+        if is_sm_disagg_ctx_phase:
+            if llm_args.sm_disagg_config is None:
+                raise ValueError(
+                    "is_sm_disagg_ctx_phase=True requires sm_disagg_config"
+                )
+            max_num_tokens = llm_args.sm_disagg_config.context_max_num_tokens
+            max_batch_size = llm_args.sm_disagg_config.context_max_batch_size

Committable suggestion skipped: line range outside the PR's diff.

🤖 Prompt for AI Agents
In tensorrt_llm/_torch/pyexecutor/model_engine.py around lines 139 to 145, when
is_sm_disagg_ctx_phase=True the constructor may later access sm_disagg_config
and raise AttributeError if it's None; add an explicit check at construction
start that if is_sm_disagg_ctx_phase is True then sm_disagg_config is not None,
and raise a clear ValueError or use assert with a descriptive message indicating
sm_disagg_config is required for SM disaggregation context phase so callers get
an immediate, informative failure.

self.forward_pass_callable = None
self.ub_buffers = None
@@ -148,6 +150,9 @@ def __init__(
max_seq_len,
max_batch_size,
) = llm_args.get_runtime_sizes()
if is_sm_disagg_ctx_phase:
max_num_tokens = llm_args.sm_disagg_config.context_max_num_tokens
max_batch_size = llm_args.sm_disagg_config.context_max_batch_size

self.batch_size = max_batch_size
self.max_num_tokens = max_num_tokens
@@ -165,6 +170,7 @@ def __init__(
if dist is not None:
ExpertStatistic.create(self.dist.rank)
self.llm_args = llm_args
self.sm_disagg_enabled = llm_args.sm_disagg_config is not None
self.original_max_draft_len = spec_config.max_draft_len if spec_config is not None else 0
self.original_max_total_draft_tokens = spec_config.max_total_draft_tokens if spec_config is not None else 0

@@ -195,6 +201,7 @@ def __init__(
max_num_tokens=self.max_num_tokens,
max_seq_len=self.max_seq_len,
lora_config=lora_config,
weight_sharing_model=weight_sharing_model,
)
self.model, moe_load_balancer = loader.load(
checkpoint_dir=model_path, checkpoint_loader=checkpoint_loader)
@@ -1434,8 +1441,10 @@ def _prepare_tp_inputs(
# the request has no previous tensor:
# (1) next_draft_tokens_device is None, which means overlap scheduler is disabled; or
# (2) a dummy request; or
# (3) the first step in the generation server of disaggregated serving
if next_draft_tokens_device is None or request.is_dummy or request.py_batch_idx is None:
# (3) the first step in the generation server of disaggregated serving; or
# (4) the first step in the generation phase of SM-level disaggregation
if next_draft_tokens_device is None or request.is_dummy or request.py_batch_idx is None \
or self.sm_disagg_enabled and request.max_num_generated_tokens == 0:
# get token ids, including input token ids and draft token ids. For these dummy requests,
# no need to copy the token ids.
if not (request.is_attention_dp_dummy
@@ -1559,8 +1568,10 @@ def _prepare_tp_inputs(
# the request has no previous tensor:
# (1) new_tokens_device is None, which means overlap scheduler is disabled; or
# (2) a dummy request; or
# (3) the first step in the generation server of disaggregated serving
if new_tokens_device is None or request.is_dummy or request.py_batch_idx is None:
# (3) the first step in the generation server of disaggregated serving; or
# (4) the first step in the generation phase of SM-level disaggregation
if new_tokens_device is None or request.is_dummy or request.py_batch_idx is None \
or self.sm_disagg_enabled and request.max_num_generated_tokens == 0:
# skip adding input_ids of CUDA graph dummy requests so that new_tokens_device
# can be aligned to the correct positions.
if not request.is_cuda_graph_dummy:
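Because the new clause combines and with the existing chain of or conditions, the grouping is easy to misread. The sketch below restates the combined condition with explicit parentheses; the standalone helper and its name are illustrative only, as the real check stays inline in _prepare_tp_inputs.

def needs_fresh_input_tokens(request, next_draft_tokens_device, sm_disagg_enabled):
    # Illustrative restatement: Python binds `and` tighter than `or`, so the
    # SM-disagg clause only fires as a single unit.
    return (
        next_draft_tokens_device is None                  # (1) overlap scheduler disabled
        or request.is_dummy                               # (2) dummy request
        or request.py_batch_idx is None                   # (3) first step on the generation server (disagg serving)
        or (sm_disagg_enabled
            and request.max_num_generated_tokens == 0)    # (4) first generation step under SM-level disaggregation
    )
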
10 changes: 9 additions & 1 deletion tensorrt_llm/_torch/pyexecutor/model_loader.py
@@ -191,7 +191,8 @@ def __init__(self,
sparse_attention_config: Optional["SparseAttentionConfig"],
max_num_tokens: int,
max_seq_len: Optional[int],
lora_config: Optional[LoraConfig] = None):
lora_config: Optional[LoraConfig] = None,
weight_sharing_model: Optional[torch.nn.Module] = None):
"""
Initializes the ModelLoader.

@@ -210,6 +211,7 @@ def __init__(self,
self.max_num_tokens = max_num_tokens
self.max_seq_len = max_seq_len
self.lora_config = lora_config
self.weight_sharing_model = weight_sharing_model

def load(
self,
@@ -307,6 +309,12 @@ def init_meta_tensor(t: torch.Tensor):
moe_load_balancer.finalize_model()
logger.info("moe_load_balancer finalize model done")

if self.weight_sharing_model is not None:
model.load_state_dict(self.weight_sharing_model.state_dict(),
assign=True)
# Free up duplicate model weights allocated before weight sharing
torch.cuda.empty_cache()
Comment on lines +312 to +316

⚠️ Potential issue | 🔴 Critical

Keep shared weights on-device when assigning

state_dict() without keep_vars=True produces detached CPU tensors. With assign=True, those CPU tensors replace the module’s CUDA parameters, so this branch forces the newly built engine to run with CPU weights and immediately triggers device-mismatch failures instead of sharing memory. Please grab the on-device Parameter objects before assigning.

-            model.load_state_dict(self.weight_sharing_model.state_dict(),
-                                      assign=True)
+            shared_state = self.weight_sharing_model.state_dict(
+                keep_vars=True)
+            model.load_state_dict(shared_state, assign=True)
🤖 Prompt for AI Agents
In tensorrt_llm/_torch/pyexecutor/model_loader.py around lines 312 to 316, the
code calls self.weight_sharing_model.state_dict() which returns detached CPU
tensors and then uses assign=True, causing CPU tensors to replace CUDA
parameters; instead obtain the on-device Parameter objects by calling
state_dict(keep_vars=True) (or otherwise capture the
weight_sharing_model.parameters()/buffers as Variables on their current device)
and pass that mapping into model.load_state_dict(..., assign=True); ensure any
torch.cuda.empty_cache() call happens after assignment if needed.


torch.cuda.current_stream().synchronize()

return model, moe_load_balancer
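
A minimal standalone sketch of the keep_vars=True / assign=True pattern suggested in the review above (it needs PyTorch 2.1+ for assign and a CUDA device). The two nn.Linear modules are stand-ins for the real engines; the only point is that the second module ends up referencing the first module's on-device Parameters rather than copies.

import torch
import torch.nn as nn

# Standalone sketch (not TRT-LLM code): share one set of CUDA weights between
# two module instances. keep_vars=True hands back the live nn.Parameter objects,
# and assign=True makes dst reference them directly instead of copying.
src = nn.Linear(16, 16).cuda()
dst = nn.Linear(16, 16).cuda()

shared_state = src.state_dict(keep_vars=True)   # live, on-device Parameters
dst.load_state_dict(shared_state, assign=True)  # dst now aliases src's weights

assert dst.weight.data_ptr() == src.weight.data_ptr()  # same underlying storage
torch.cuda.empty_cache()  # reclaim dst's original, now-unreferenced weight memory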