diff --git a/tensorrt_llm/_torch/pyexecutor/_util.py b/tensorrt_llm/_torch/pyexecutor/_util.py index 8da982aba2b..eabd9b22e8e 100644 --- a/tensorrt_llm/_torch/pyexecutor/_util.py +++ b/tensorrt_llm/_torch/pyexecutor/_util.py @@ -14,7 +14,7 @@ EagleDecodingConfig, KvCacheConfig, MTPDecodingConfig, PeftCacheConfig, SamplerType, SchedulerConfig, - SparseAttentionConfig, + SmDisaggConfig, SparseAttentionConfig, SpeculativeConfig, TorchLlmArgs) from tensorrt_llm.logger import logger from tensorrt_llm.lora_helper import (LoraConfig, @@ -29,7 +29,7 @@ from .guided_decoder import GuidedDecoder from .kv_cache_connector import KvCacheConnectorManager from .kv_cache_transceiver import AttentionTypeCpp, create_kv_cache_transceiver -from .llm_request import ExecutorResponse +from .llm_request import ExecutorResponse, LlmRequestState from .mamba_cache_manager import MambaHybridCacheManager from .model_engine import PyTorchModelEngine from .py_executor import PyExecutor @@ -665,6 +665,8 @@ def create_py_executor_instance( max_batch_size: Optional[int] = None, max_beam_width: Optional[int] = None, max_num_tokens: Optional[int] = None, + ctx_model_engine: Optional[PyTorchModelEngine] = None, + sm_disagg_config: Optional[SmDisaggConfig] = None, peft_cache_config: Optional[PeftCacheConfig] = None, scheduler_config: Optional[SchedulerConfig] = None, cache_transceiver_config: Optional[CacheTransceiverConfig] = None, @@ -789,6 +791,23 @@ def create_py_executor_instance( ctx_chunk_config) scheduler = SimpleScheduler(capacity_scheduler, mb_scheduler) + if sm_disagg_config is not None: + scheduler_capacity += sm_disagg_config.context_max_batch_size * mapping.pp_size + capacity_scheduler = BindCapacityScheduler( + scheduler_capacity, + kv_cache_manager.impl if kv_cache_manager is not None else None, + peft_cache_manager.impl if peft_cache_manager is not None else None, + scheduler_config.capacity_scheduler_policy, + two_step_lookahead=mapping.has_pp()) + mb_scheduler = BindMicroBatchScheduler( + sm_disagg_config.context_max_batch_size, + sm_disagg_config.context_max_num_tokens, + ctx_chunk_config, + no_schedule_after_state=LlmRequestState.GENERATION_IN_PROGRESS) + ctx_scheduler = SimpleScheduler(capacity_scheduler, mb_scheduler) + else: + ctx_scheduler = None + config = model_engine.model.model_config.pretrained_config attention_type = AttentionTypeCpp.MLA if is_mla( config) else AttentionTypeCpp.DEFAULT @@ -801,6 +820,8 @@ def create_py_executor_instance( model_engine=model_engine, sampler=sampler, drafter=drafter, + ctx_scheduler=ctx_scheduler, + ctx_model_engine=ctx_model_engine, dist=dist, max_num_sequences=max_num_sequences, disable_overlap_scheduler=llm_args.disable_overlap_scheduler, diff --git a/tensorrt_llm/_torch/pyexecutor/executor_request_queue.py b/tensorrt_llm/_torch/pyexecutor/executor_request_queue.py index 8c506da6b54..2d51561de05 100644 --- a/tensorrt_llm/_torch/pyexecutor/executor_request_queue.py +++ b/tensorrt_llm/_torch/pyexecutor/executor_request_queue.py @@ -47,10 +47,15 @@ def is_control_request(self): class ExecutorRequestQueue: """Handles fetching and processing of new requests from the request queue.""" - def __init__(self, dist: Distributed, enable_attention_dp: bool, - max_batch_size: int, max_beam_width: int, - max_num_active_requests: int, enable_iter_perf_stats: bool, - batch_wait_timeout_ms: float): + def __init__(self, + dist: Distributed, + enable_attention_dp: bool, + max_batch_size: int, + max_beam_width: int, + max_num_active_requests: int, + enable_iter_perf_stats: bool, + 
batch_wait_timeout_ms: float, + is_sm_disagg: bool = False): self.dist = dist self.request_queue: queue.Queue[RequestQueueItem] = queue.Queue() self.waiting_queue: deque[RequestQueueItem] = deque() @@ -59,6 +64,7 @@ def __init__(self, dist: Distributed, enable_attention_dp: bool, self.max_batch_size = max_batch_size self.max_beam_width = max_beam_width self.max_num_active_requests = max_num_active_requests + self.is_sm_disagg = is_sm_disagg self.enqueue_lock = threading.Lock() self.next_request_id = max_batch_size self.enable_iter_perf_stats = enable_iter_perf_stats @@ -333,12 +339,15 @@ def _fetch_and_process_requests( @nvtx_range("_fetch_new_requests") def fetch_new_requests( - self, activate_requests: List[LlmRequest]) -> List[LlmRequest]: + self, activate_requests: List[LlmRequest], + num_active_requests_on_engine: int) -> List[LlmRequest]: if self.enable_attention_dp: return self._fetch_new_requests_attention_dp(activate_requests) else: - return self._fetch_new_requests_attention_tp(len(activate_requests)) + num_active_requests = num_active_requests_on_engine if self.is_sm_disagg else len( + activate_requests) + return self._fetch_new_requests_attention_tp(num_active_requests) def _fetch_new_requests_attention_tp( self, num_active_requests: int) -> List[LlmRequest]: diff --git a/tensorrt_llm/_torch/pyexecutor/green_ctx.py b/tensorrt_llm/_torch/pyexecutor/green_ctx.py new file mode 100644 index 00000000000..68a074d42cb --- /dev/null +++ b/tensorrt_llm/_torch/pyexecutor/green_ctx.py @@ -0,0 +1,82 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
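The green_ctx.py module introduced just below carves a GPU's SMs into two disjoint partitions via CUDA green contexts and hands each partition back as a torch CUDA stream, so the context and generation phases can run concurrently on dedicated SM budgets. A minimal usage sketch of the green_ctx_split_percent helper it defines (the 30% split and device id are illustrative, not values from this PR):

import torch

from tensorrt_llm._torch.pyexecutor.green_ctx import green_ctx_split_percent

# Give roughly 30% of the SMs to the context phase; the remainder goes to generation.
# The helper rounds the split to the hardware's SM-group alignment.
(ctx_stream, gen_stream), (ctx_res, gen_res) = green_ctx_split_percent(0.3, device_id=0)
print(f"context SMs: {ctx_res.sm.smCount}, generation SMs: {gen_res.sm.smCount}")

with torch.cuda.stream(ctx_stream):
    ...  # kernels launched here are confined to the context-phase SM partition

with torch.cuda.stream(gen_stream):
    ...  # kernels launched here are confined to the generation-phase SM partition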
+ + +import torch +from cuda.bindings import driver + +from tensorrt_llm.runtime.generation import CUASSERT + + +def green_ctx_create_streams(res_list, device): + streams = [] + for res in res_list: + desc = CUASSERT(driver.cuDevResourceGenerateDesc([res], 1))[0] + green_ctx = CUASSERT( + driver.cuGreenCtxCreate( + desc, device, driver.CUgreenCtxCreate_flags.CU_GREEN_CTX_DEFAULT_STREAM + ) + )[0] + stream = CUASSERT( + driver.cuGreenCtxStreamCreate( + green_ctx, driver.CUstream_flags.CU_STREAM_NON_BLOCKING, 0 + ) + )[0] + stream = torch.cuda.get_stream_from_external(stream, device) + streams.append(stream) + return streams + + +def green_ctx_split_percent(sm_percent: float, device_id: int = 0): + device = CUASSERT(driver.cuDeviceGet(device_id))[0] + + res = CUASSERT( + driver.cuDeviceGetDevResource(device, driver.CUdevResourceType.CU_DEV_RESOURCE_TYPE_SM) + )[0] + sm_count = res.sm.smCount + + major = CUASSERT( + driver.cuDeviceGetAttribute( + driver.CUdevice_attribute.CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MAJOR, device + ) + )[0] + if major >= 9: + sm_min = 8 + sm_align = 8 + else: + sm_min = 4 if major == 8 else 2 + sm_align = 2 + + def green_ctx_split_aligned(sm_g1): + sm_g1 = round(sm_g1 / sm_align) * sm_align + sm_g1 = min(max(sm_g1, sm_min), sm_count - sm_min) + result = CUASSERT( + driver.cuDevSmResourceSplitByCount( + 1, # nbGroups + res, + 0, # useFlags + sm_g1, + ) + ) + res_split = (result[0][0], result[2]) + streams = green_ctx_create_streams(res_split, device) + return streams, res_split + + sm_g1 = round(sm_count * sm_percent) + sm_g2 = sm_count - sm_g1 + # Choose the split closer to sm_percent when sm_count is not divisible by sm_align + sm_g1_dist = min(sm_g1 % sm_align, sm_align - (sm_g1 % sm_align)) + sm_g2_dist = min(sm_g2 % sm_align, sm_align - (sm_g2 % sm_align)) + if sm_g1_dist <= sm_g2_dist: + (stream_g1, stream_g2), (res_g1, res_g2) = green_ctx_split_aligned(sm_g1) + else: + (stream_g2, stream_g1), (res_g2, res_g1) = green_ctx_split_aligned(sm_g2) + return (stream_g1, stream_g2), (res_g1, res_g2) diff --git a/tensorrt_llm/_torch/pyexecutor/llm_request.py b/tensorrt_llm/_torch/pyexecutor/llm_request.py index c525481fee3..e8a59b0e22d 100644 --- a/tensorrt_llm/_torch/pyexecutor/llm_request.py +++ b/tensorrt_llm/_torch/pyexecutor/llm_request.py @@ -797,3 +797,11 @@ def get_draft_token_length(request: LlmRequest) -> int: if request.py_draft_tokens is not None: return len(request.py_draft_tokens) return 0 + + +def get_context_requests(requests: List[LlmRequest]): + return [req for req in requests if req.is_context_init_state] + + +def get_generation_requests(requests: List[LlmRequest]): + return [req for req in requests if not req.is_context_init_state] diff --git a/tensorrt_llm/_torch/pyexecutor/model_engine.py b/tensorrt_llm/_torch/pyexecutor/model_engine.py index 42f71181b11..f3471bf3bc5 100644 --- a/tensorrt_llm/_torch/pyexecutor/model_engine.py +++ b/tensorrt_llm/_torch/pyexecutor/model_engine.py @@ -135,10 +135,12 @@ def __init__( attn_runtime_features: Optional[AttentionRuntimeFeatures] = None, dist: Optional[MPIDist] = None, spec_config: Optional["DecodingBaseConfig"] = None, + is_sm_disagg_ctx_phase: bool = False, is_draft_model: bool = False, drafting_loop_wrapper: Optional[Callable[[torch.nn.Module], torch.nn.Module]] = None, model: Optional[torch.nn.Module] = None, + weight_sharing_model: Optional[torch.nn.Module] = None, ): self.forward_pass_callable = None self.ub_buffers = None @@ -148,6 +150,9 @@ def __init__( max_seq_len, max_batch_size, ) = 
llm_args.get_runtime_sizes() + if is_sm_disagg_ctx_phase: + max_num_tokens = llm_args.sm_disagg_config.context_max_num_tokens + max_batch_size = llm_args.sm_disagg_config.context_max_batch_size self.batch_size = max_batch_size self.max_num_tokens = max_num_tokens @@ -165,6 +170,7 @@ def __init__( if dist is not None: ExpertStatistic.create(self.dist.rank) self.llm_args = llm_args + self.sm_disagg_enabled = llm_args.sm_disagg_config is not None self.original_max_draft_len = spec_config.max_draft_len if spec_config is not None else 0 self.original_max_total_draft_tokens = spec_config.max_total_draft_tokens if spec_config is not None else 0 @@ -195,6 +201,7 @@ def __init__( max_num_tokens=self.max_num_tokens, max_seq_len=self.max_seq_len, lora_config=lora_config, + weight_sharing_model=weight_sharing_model, ) self.model, moe_load_balancer = loader.load( checkpoint_dir=model_path, checkpoint_loader=checkpoint_loader) @@ -1434,8 +1441,10 @@ def _prepare_tp_inputs( # the request has no previous tensor: # (1) next_draft_tokens_device is None, which means overlap scheduler is disabled; or # (2) a dummy request; or - # (3) the first step in the generation server of disaggregated serving - if next_draft_tokens_device is None or request.is_dummy or request.py_batch_idx is None: + # (3) the first step in the generation server of disaggregated serving; or + # (4) the first step in the generation phase of SM-level disaggregation + if next_draft_tokens_device is None or request.is_dummy or request.py_batch_idx is None \ + or self.sm_disagg_enabled and request.max_num_generated_tokens == 0: # get token ids, including input token ids and draft token ids. For these dummy requests, # no need to copy the token ids. if not (request.is_attention_dp_dummy @@ -1559,8 +1568,10 @@ def _prepare_tp_inputs( # the request has no previous tensor: # (1) new_tokens_device is None, which means overlap scheduler is disabled; or # (2) a dummy request; or - # (3) the first step in the generation server of disaggregated serving - if new_tokens_device is None or request.is_dummy or request.py_batch_idx is None: + # (3) the first step in the generation server of disaggregated serving; or + # (4) the first step in the generation phase of SM-level disaggregation + if new_tokens_device is None or request.is_dummy or request.py_batch_idx is None \ + or self.sm_disagg_enabled and request.max_num_generated_tokens == 0: # skip adding input_ids of CUDA graph dummy requests so that new_tokens_device # can be aligned to the correct positions. if not request.is_cuda_graph_dummy: diff --git a/tensorrt_llm/_torch/pyexecutor/model_loader.py b/tensorrt_llm/_torch/pyexecutor/model_loader.py index b9c1377cd98..329ca42bdac 100644 --- a/tensorrt_llm/_torch/pyexecutor/model_loader.py +++ b/tensorrt_llm/_torch/pyexecutor/model_loader.py @@ -191,7 +191,8 @@ def __init__(self, sparse_attention_config: Optional["SparseAttentionConfig"], max_num_tokens: int, max_seq_len: Optional[int], - lora_config: Optional[LoraConfig] = None): + lora_config: Optional[LoraConfig] = None, + weight_sharing_model: Optional[torch.nn.Module] = None): """ Initializes the ModelLoader. 
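The new weight_sharing_model argument lets the SM-disagg context engine reuse the generation engine's already-loaded weights instead of holding a second copy; the loader does this with load_state_dict(..., assign=True), as the hunk below shows. A standalone sketch of that sharing mechanism, with torch.nn.Linear standing in for the real model (illustrative only, not TRT-LLM code):

import torch

src = torch.nn.Linear(16, 16, device="cuda")  # already-loaded "generation" model
dst = torch.nn.Linear(16, 16, device="cuda")  # freshly constructed "context" model

# assign=True replaces dst's parameters with the tensors from src's state_dict,
# so both modules reference the same underlying storage (no second copy of the weights).
dst.load_state_dict(src.state_dict(), assign=True)
assert dst.weight.data_ptr() == src.weight.data_ptr()

# The parameters originally allocated for dst are now unreferenced and can be released.
torch.cuda.empty_cache()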
@@ -210,6 +211,7 @@ def __init__(self, self.max_num_tokens = max_num_tokens self.max_seq_len = max_seq_len self.lora_config = lora_config + self.weight_sharing_model = weight_sharing_model def load( self, @@ -307,6 +309,12 @@ def init_meta_tensor(t: torch.Tensor): moe_load_balancer.finalize_model() logger.info("moe_load_balancer finalize model done") + if self.weight_sharing_model is not None: + model.load_state_dict(self.weight_sharing_model.state_dict(), + assign=True) + # Free up duplicate model weights allocated before weight sharing + torch.cuda.empty_cache() + torch.cuda.current_stream().synchronize() return model, moe_load_balancer diff --git a/tensorrt_llm/_torch/pyexecutor/py_executor.py b/tensorrt_llm/_torch/pyexecutor/py_executor.py index 23ab0dbfa07..874714ac836 100644 --- a/tensorrt_llm/_torch/pyexecutor/py_executor.py +++ b/tensorrt_llm/_torch/pyexecutor/py_executor.py @@ -1,5 +1,6 @@ import dataclasses import datetime +import enum import functools import os import pickle # nosec B403 @@ -42,13 +43,15 @@ from ..speculative.mtp import SampleStateTensorsMTP from ..speculative.speculation_gate import SpeculationGate from .executor_request_queue import ExecutorRequestQueue, RequestQueueItem +from .green_ctx import green_ctx_split_percent from .guided_decoder import GuidedDecoder from .handle_additional_outputs import HandleAdditionalOutputs from .handle_logits import HandleLogits from .kv_cache_connector import KvCacheConnectorManager from .kv_cache_transceiver import KvCacheTransceiver from .llm_request import (ExecutorRequest, LlmRequest, LlmRequestState, - LlmResponse, get_draft_token_length) + LlmResponse, get_context_requests, + get_draft_token_length, get_generation_requests) from .model_engine import ModelEngine from .resource_manager import ResourceManager from .sampler import Sampler, SampleState, SampleStateTensors @@ -58,6 +61,10 @@ # Format: "start1-stop1,start2-stop2,..." or single iterations "iter1,iter2,..." PROFILE_START_STOP_ENV_VAR_NAME = "TLLM_PROFILE_START_STOP" +# Environment variable to switch between context/generation profiling ranges as specified in PROFILE_START_STOP_ENV_VAR_NAME +# Set to "1" for the specified ranges to be treated as context phase ranges +PROFILE_SM_DISAGG_CTX_RANGE_ENV_VAR_NAME = "TLLM_PROFILE_SM_DISAGG_CTX_RANGE" + # Environment variable to enable PyTorch profiler tracing. # Set to a path to save detailed tracing of PyTorch operations. 
PROFILE_TRACE_ENV_VAR_NAME = "TLLM_TORCH_PROFILE_TRACE" @@ -107,6 +114,12 @@ class BatchStatePP(BatchState): scheduled_ctx_reqs: list[LlmRequest] = None +class ExecutorLoopPhase(enum.StrEnum): + IFB = "IFB" + SM_DISAGG_CTX = "SM_DISAGG_CTX" + SM_DISAGG_GEN = "SM_DISAGG_GEN" + + class PyExecutor: def __init__(self, @@ -117,6 +130,8 @@ def __init__(self, dist: Distributed, max_num_sequences: int, drafter: Optional[Drafter] = None, + ctx_scheduler: Optional[RequestScheduler] = None, + ctx_model_engine: Optional[ModelEngine] = None, disable_overlap_scheduler: bool = False, max_input_len: int = 2048, max_batch_size: int = 8, @@ -155,6 +170,9 @@ def __init__(self, self.disable_overlap_scheduler = disable_overlap_scheduler self.virtual_memory_pools = virtual_memory_pools + self.ctx_scheduler = ctx_scheduler + self.ctx_model_engine = ctx_model_engine + # enqueue and _fetch_new_requests used data self.active = True self.max_beam_width = max_beam_width @@ -172,6 +190,8 @@ def __init__(self, if self.attention_dp_enable_balance: self.attention_dp_time_out_iters = self.llm_args.attention_dp_config.timeout_iters self.attention_dp_batching_wait_iters = self.llm_args.attention_dp_config.batching_wait_iters + if self.ctx_model_engine is not None: + self.sm_disagg_ctx_sm_percent = self.llm_args.sm_disagg_config.context_sm_percent self.batch_wait_timeout_ms = self.llm_args.batch_wait_timeout_ms self.batch_wait_timeout_iters = self.llm_args.batch_wait_timeout_iters self.batch_wait_max_tokens_ratio = self.llm_args.batch_wait_max_tokens_ratio @@ -201,6 +221,10 @@ def __init__(self, self.responses = {} self.result_wait_queues = {} + self.executor_lock = threading.Lock() + self.ctx_request_cv = threading.Condition(self.executor_lock) + self.gen_request_cv = threading.Condition(self.executor_lock) + # kv cache events self.kv_cache_manager = self.resource_manager.resource_managers.get( ResourceManagerType.KV_CACHE_MANAGER) @@ -211,6 +235,9 @@ def __init__(self, self.max_input_len = max_input_len # _executor_loop private data self.max_num_active_requests = model_engine.get_max_num_sequences() + if self.ctx_model_engine is not None: + self.max_num_active_requests += ctx_model_engine.get_max_num_sequences( + ) self.active_requests: List[LlmRequest] = [] self.expected_num_active_requests = 0 self.ctx_in_transmission_requests = dict() @@ -219,7 +246,6 @@ def __init__(self, 0) self.previous_batch: Optional[BatchState] = None self.has_previous_draft_tokens = False - self.num_scheduled_requests: int = 0 self.benchmark_req_queues_size = int( os.environ.get("TLLM_BENCHMARK_REQ_QUEUES_SIZE", 0)) self._disable_mpi = mpi_disabled() @@ -234,9 +260,12 @@ def __init__(self, # During warmup, we don't enable the profiler self.is_warmup = True - self.model_engine.warmup(self.resource_manager) - if self.draft_model_engine is not None: - self.draft_model_engine.warmup(self.resource_manager) + for eng in [ + self.model_engine, self.ctx_model_engine, + self.draft_model_engine + ]: + if eng is not None: + eng.warmup(self.resource_manager) self.is_warmup = False self.is_shutdown = False @@ -255,6 +284,7 @@ def __init__(self, max_num_active_requests=self.max_num_active_requests, enable_iter_perf_stats=self.enable_iter_perf_stats, batch_wait_timeout_ms=self.batch_wait_timeout_ms, + is_sm_disagg=ctx_model_engine is not None, ) self.executor_request_queue.set_exclude_last_generation_logits( self.disable_overlap_scheduler, self.dist.pp_size) @@ -275,6 +305,8 @@ def __init__(self, if self.dist.pp_size > 1: self.event_loop = 
self._executor_loop_pp + elif ctx_model_engine is not None: + self.event_loop = self._executor_loop_sm_disagg else: self.event_loop = self._executor_loop if self.disable_overlap_scheduler else self._executor_loop_overlap if is_trace_enabled("TLLM_TRACE_EXECUTOR_LOOP"): @@ -346,9 +378,12 @@ def is_warmup(self) -> bool: def is_warmup(self, value: bool): self._is_warmup = value # Set warmup flag in model engine to trigger torch compile and avoid moe load balancer statistics update - self.model_engine.is_warmup = value - if self.draft_model_engine is not None: - self.draft_model_engine.is_warmup = value + for eng in [ + self.model_engine, self.ctx_model_engine, + self.draft_model_engine + ]: + if eng is not None: + eng.is_warmup = value def start_worker(self): with self.worker_lock: @@ -443,6 +478,8 @@ def shutdown(self): if manager: manager.shutdown() del self.model_engine + if self.ctx_model_engine is not None: + del self.ctx_model_engine if self.draft_model_engine is not None: del self.draft_model_engine if self.virtual_memory_pools is not None: @@ -505,7 +542,13 @@ def should_stop_processing(self): self.executor_request_queue.get_waiting_queue_size() == 0 @contextmanager - def _profiler(self): + def _profiler(self, + model_engine: Optional[ModelEngine] = None, + stream: Optional[torch.cuda.Stream] = None, + phase: Optional[ExecutorLoopPhase] = ExecutorLoopPhase.IFB): + if model_engine is None: + model_engine = self.model_engine + it = -1 enabled = False start_time = None @@ -522,7 +565,13 @@ def _profiler(self): torch_trace_path = os.environ.get(PROFILE_TRACE_ENV_VAR_NAME, None) profile_start_stop = os.environ.get(PROFILE_START_STOP_ENV_VAR_NAME, None) - enable_torch_trace = bool(torch_trace_path and profile_start_stop) + profile_sm_disagg_ctx_range = bool( + os.environ.get(PROFILE_SM_DISAGG_CTX_RANGE_ENV_VAR_NAME, None)) + enable_profiler = phase == ExecutorLoopPhase.IFB or ( + (phase == ExecutorLoopPhase.SM_DISAGG_GEN) + ^ profile_sm_disagg_ctx_range) + enable_torch_trace = bool(enable_profiler and torch_trace_path + and profile_start_stop) if torch_trace_path and profile_start_stop is None: logger.warning( f"{PROFILE_START_STOP_ENV_VAR_NAME} environment variable " @@ -540,9 +589,9 @@ def _profiler(self): record_shapes=True, with_modules=True) - def profile_step(): + def profile_step(iter_stats: Optional[IterationStats] = None): nonlocal it, enabled, start_time, start_event_1, end_event_1, start_event_2, end_event_2, prev_device_step_time - if it in self.profile_stop_iters and not self.is_warmup: + if it in self.profile_stop_iters and enable_profiler and not self.is_warmup: assert enabled, "Inconsistent CUDA profiling state" if enable_torch_trace: torch_profiler.stop() @@ -554,18 +603,19 @@ def profile_step(): if start_time is not None and self.print_log and self.dist.rank == 0: end_time = time.time() - if it % 2 == 0: - end_event_1.record() - if start_event_2 is not None: - end_event_2.synchronize() - prev_device_step_time = start_event_2.elapsed_time( - end_event_2) - else: - end_event_2.record() - if start_event_1 is not None: - end_event_1.synchronize() - prev_device_step_time = start_event_1.elapsed_time( - end_event_1) + with torch.cuda.stream(stream): + if it % 2 == 0: + end_event_1.record() + if start_event_2 is not None: + end_event_2.synchronize() + prev_device_step_time = start_event_2.elapsed_time( + end_event_2) + else: + end_event_2.record() + if start_event_1 is not None: + end_event_1.synchronize() + prev_device_step_time = start_event_1.elapsed_time( + end_event_1) if 
prev_device_step_time is None: prev_device_step_time = "N/A" # Handle first iteration @@ -575,7 +625,8 @@ def profile_step(): formatted_timestamp = datetime.datetime.now().strftime( "%Y-%m-%d %H:%M:%S") logger.info( - f"iter = {self.model_engine.iter_counter}, " + f"phase = {phase}, " + f"iter = {model_engine.iter_counter}, " f"global_rank = {self.global_rank}, " f"rank = {self.dist.rank}, " f"currank_total_requests = {self.executor_request_queue.num_fetch_requests_cur_rank}/" @@ -583,12 +634,12 @@ def profile_step(): f"host_step_time = {host_step_time}ms, " f"prev_device_step_time = {prev_device_step_time}, " f"timestamp = {formatted_timestamp}, " - f"num_scheduled_requests: {self.num_scheduled_requests}, " - f"states = {self.model_engine.iter_states}") + f"num_scheduled_requests: {iter_stats.inflight_batching_stats.num_scheduled_requests if iter_stats else 'N/A'}, " + f"states = {model_engine.iter_states}") it += 1 - if it in self.profile_start_iters and not self.is_warmup: + if it in self.profile_start_iters and enable_profiler and not self.is_warmup: assert not enabled, "Inconsistent CUDA profiling state" torch.cuda.cudart().cudaProfilerStart() if enable_torch_trace: @@ -596,14 +647,15 @@ def profile_step(): logger.info(f"Profiling started at iteration {it}.") enabled = True start_time = time.time() - if it % 2 == 0: - if start_event_1 is None: - start_event_1 = torch.cuda.Event(enable_timing=True) - start_event_1.record() - else: - if start_event_2 is None: - start_event_2 = torch.cuda.Event(enable_timing=True) - start_event_2.record() + with torch.cuda.stream(stream): + if it % 2 == 0: + if start_event_1 is None: + start_event_1 = torch.cuda.Event(enable_timing=True) + start_event_1.record() + else: + if start_event_2 is None: + start_event_2 = torch.cuda.Event(enable_timing=True) + start_event_2.record() try: yield profile_step @@ -705,8 +757,6 @@ def _update_iter_stats(self, stats, iter_latency_ms, num_completed_requests, stats.cpu_mem_usage = 0 stats.pinned_mem_usage = 0 - stats.iter = self.model_engine.iter_counter - kv_cache_manager = self.resource_manager.resource_managers.get( ResourceManagerType.KV_CACHE_MANAGER) if kv_cache_manager is not None: @@ -789,7 +839,7 @@ def _executor_loop_pp(self): iter_start_time = time.time() iter_stats = None while True: - profile_step() + profile_step(iter_stats) if self.enable_iter_perf_stats: iter_start_time = time.time() new_requests = self._fetch_and_activate_new_requests() @@ -823,8 +873,6 @@ def _executor_loop_pp(self): ) self._check_disagg_ctx_cache_transfer_status(1) - self.num_scheduled_requests = scheduled_batch.batch_size - logger.debug( f'has {len(self.active_requests)} active_request, ' f'scheduled {len(scheduled_batch.context_requests)} context requests and ' @@ -888,6 +936,7 @@ def _executor_loop_pp(self): self._update_request_states(scheduled_batch) if self.enable_iter_perf_stats: + iter_stats.iter_counter = self.model_engine.iter_counter iter_stats.inflight_batching_stats.num_ctx_tokens = self.model_engine.iter_states[ 'num_ctx_tokens'] batch_state = BatchStatePP( @@ -1083,13 +1132,66 @@ def _prepare_and_schedule_batch(self): ) self._check_disagg_ctx_cache_transfer_status(1) - self.num_scheduled_requests = scheduled_batch.batch_size logger.debug( f'has {len(self.active_requests)} active_request, ' f'scheduled {len(scheduled_batch.context_requests)} context requests and ' f'{len(scheduled_batch.generation_requests)} generation requests') return scheduled_batch, iter_stats + def _prepare_and_schedule_batch_sm_disagg_ctx(self): + 
new_requests = self._fetch_and_activate_new_requests() + + iter_stats = None + if self.enable_iter_perf_stats: + iter_stats = self._get_init_iter_stats( + len(new_requests), + self.executor_request_queue. + get_new_active_requests_queue_latency()) + + with self.executor_lock: + ctx_requests = get_context_requests(self.active_requests) + if self.is_shutdown and len(ctx_requests) == 0 \ + and self.executor_request_queue.get_waiting_queue_size() == 0: + self.ctx_request_cv.notify() + return None, None + + scheduled_batch, _, _ = self._schedule(scheduler=self.ctx_scheduler) + + logger.debug( + f'has {len(self.active_requests)} active_request, ' + f'scheduled {len(scheduled_batch.context_requests)} context requests' + ) + + if scheduled_batch.batch_size == 0 \ + and (len(ctx_requests) > 0 or self.executor_request_queue.get_waiting_queue_size() > 0): + self.gen_request_cv.wait() + + return scheduled_batch, iter_stats + + def _prepare_and_schedule_batch_sm_disagg_gen(self): + if self.should_stop_processing: + return None, None + + iter_stats = None + if self.enable_iter_perf_stats: + iter_stats = self._get_init_iter_stats( + num_new_active_requests=0, + new_active_requests_queue_latency_ms=0) + + with self.executor_lock: + gen_requests = get_generation_requests(self.active_requests) + scheduled_batch, _, _ = self._schedule(active_requests=gen_requests) + + logger.debug( + f'has {len(self.active_requests)} active_request, ' + f'scheduled {len(scheduled_batch.generation_requests)} generation requests' + ) + + if scheduled_batch.batch_size == 0: + self.ctx_request_cv.wait() + + return scheduled_batch, iter_stats + def _kv_connector_start_batch(self, scheduled_batch): if self.kv_connector_manager: self.kv_connector_manager.take_scheduled_requests_pending_load( @@ -1116,45 +1218,60 @@ def _kv_connector_wait_for_save(self): self.kv_connector_manager.worker.wait_for_save( torch.cuda.current_stream()) - def _executor_loop(self): + def _executor_loop(self, + stream: Optional[torch.cuda.Stream] = None, + phase: ExecutorLoopPhase = ExecutorLoopPhase.IFB): torch.cuda.set_device(self.device_id) # ensure the context is created, otherwise, some MPI calls will fail. 
CUASSERT(cudart.cudaSetDevice(self.device_id)) - with self._profiler() as profile_step: + + model_engine = self.ctx_model_engine if phase == ExecutorLoopPhase.SM_DISAGG_CTX else self.model_engine + + with self._profiler(model_engine, stream, phase) as profile_step: sample_state = None iter_start_time = time.time() iter_stats = None while True: - profile_step() + profile_step(iter_stats) if self.enable_iter_perf_stats: iter_start_time = time.time() - scheduled_batch, iter_stats = self._prepare_and_schedule_batch() + if phase == ExecutorLoopPhase.SM_DISAGG_CTX: + scheduled_batch, iter_stats = self._prepare_and_schedule_batch_sm_disagg_ctx( + ) + elif phase == ExecutorLoopPhase.SM_DISAGG_GEN: + scheduled_batch, iter_stats = self._prepare_and_schedule_batch_sm_disagg_gen( + ) + else: + scheduled_batch, iter_stats = self._prepare_and_schedule_batch( + ) + self._handle_control_request() if scheduled_batch is None: break - self._pause_requests(scheduled_batch.paused_requests) + with self.executor_lock: + self._pause_requests(scheduled_batch.paused_requests) - finished_requests = [] + finished_requests = [] - can_queue = self._can_queue(scheduled_batch) - if can_queue: - if self.kv_cache_transceiver: - # For generation requests which have completed KV cache transfer - self._prepare_disagg_gen_transmission_complete( - scheduled_batch) + can_queue = self._can_queue(scheduled_batch) + if can_queue: + if self.kv_cache_transceiver: + # For generation requests which have completed KV cache transfer + self._prepare_disagg_gen_transmission_complete( + scheduled_batch) - # Return the first token to the client - self._handle_first_token_response(scheduled_batch) - self.resource_manager.prepare_resources(scheduled_batch) + # Return the first token to the client + self._handle_first_token_response(scheduled_batch) + self.resource_manager.prepare_resources(scheduled_batch) - self._kv_connector_start_batch(scheduled_batch) + self._kv_connector_start_batch(scheduled_batch) - # if using a kv connector, we need to call can_queue again since scheduled_batch might have changed - if self.kv_connector_manager: - can_queue = self._can_queue(scheduled_batch) + # if using a kv connector, we need to call can_queue again since scheduled_batch might have changed + if self.kv_connector_manager: + can_queue = self._can_queue(scheduled_batch) if can_queue: # init_disagg_gen_requests must be before drafter loop, otherwise draft requests do not have initialized matchers. 
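In the SM-disagg path, the context and generation loops run on separate threads that share self.executor_lock; a loop with nothing to schedule sleeps on its condition variable and is woken when the other phase makes progress (the ctx loop waits on gen_request_cv, the gen loop on ctx_request_cv). A stripped-down sketch of that handshake, using generic names rather than the executor's API:

import threading

class PhaseCoordinator:
    """Toy version of the ctx/gen handshake: one shared lock, one condition per phase."""

    def __init__(self):
        self.lock = threading.Lock()
        self.ctx_progress = threading.Condition(self.lock)  # context thread notifies after each prefill
        self.gen_progress = threading.Condition(self.lock)  # generation thread notifies after each decode
        self.prefilled = 0  # requests that have finished the context phase
        self.finished = 0   # requests that have finished the generation phase

    def context_loop(self, total, max_pending=2):
        for _ in range(total):
            with self.lock:
                while self.prefilled - self.finished >= max_pending:
                    self.gen_progress.wait()  # back off until generation drains some requests
            # ... run one prefill iteration on the context-phase stream ...
            with self.lock:
                self.prefilled += 1
                self.ctx_progress.notify()

    def generation_loop(self, total):
        while self.finished < total:
            with self.lock:
                while self.prefilled == self.finished:
                    self.ctx_progress.wait()  # idle until the context loop hands over a request
            # ... run one decode iteration on the generation-phase stream ...
            with self.lock:
                self.finished += 1
                self.gen_progress.notify()

coord = PhaseCoordinator()
ctx_thread = threading.Thread(target=coord.context_loop, args=(4,), daemon=True)
ctx_thread.start()
coord.generation_loop(4)  # mirrors _executor_loop_sm_disagg: gen runs on the calling thread
ctx_thread.join()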
@@ -1181,49 +1298,61 @@ def _executor_loop(self): if hasattr(self.drafter, "guided_decoder"): self.guided_decoder.rollback_draft_tokens() - batch_outputs = self._forward_step(scheduled_batch) - if self.guided_decoder is not None: - self.guided_decoder.execute(batch_outputs['logits']) + with torch.cuda.stream(stream): + batch_outputs = self._forward_step( + scheduled_batch, model_engine=model_engine) + if self.guided_decoder is not None: + self.guided_decoder.execute(batch_outputs['logits']) + + sample_state = self._sample_async( + scheduled_batch, batch_outputs) + if self.drafter is not None: + self.drafter.run_drafter_post( + scheduled_batch, self.resource_manager, + self.is_warmup) + # To avoid long sync time in critical section below + if phase == ExecutorLoopPhase.SM_DISAGG_CTX or phase == ExecutorLoopPhase.SM_DISAGG_GEN: + sample_state.sampler_event.synchronize() + + with self.executor_lock: + self._update_request_states(scheduled_batch) + self._update_requests(sample_state, + self.resource_manager) + if self.block_reuse_enabled and not self.kv_cache_manager.is_vswa and self.kv_cache_transceiver: + for req in scheduled_batch.context_requests: + if req.is_context_only_request and ( + req.is_context_finished + or req.is_finished_due_to_length): + block_id = self.kv_cache_manager.store_blocks_for_reuse( + req, True) + self.ctx_in_transmission_requests[ + req.py_request_id] = ( + (req, block_id, + self.ctx_in_transmission_counter)) - sample_state = self._sample_async(scheduled_batch, - batch_outputs) - if self.drafter is not None: - self.drafter.run_drafter_post(scheduled_batch, - self.resource_manager, - self.is_warmup) - - self._update_request_states(scheduled_batch) - self._update_requests(sample_state, self.resource_manager) - if self.block_reuse_enabled and not self.kv_cache_manager.is_vswa and self.kv_cache_transceiver: - for req in scheduled_batch.context_requests: - if req.is_context_only_request and ( - req.is_context_finished - or req.is_finished_due_to_length): - block_id = self.kv_cache_manager.store_blocks_for_reuse( - req, True) - self.ctx_in_transmission_requests[ - req.py_request_id] = ( - (req, block_id, - self.ctx_in_transmission_counter)) + if self.kv_cache_transceiver: + ctx_transmission_reqs = self._send_disagg_ctx_cache( + scheduled_batch.context_requests) + # For context only req in transmission, we reset the state since sampler might have changed it + for req in ctx_transmission_reqs: + req.state = LlmRequestState.DISAGG_CONTEXT_TRANS_IN_PROGRESS - if self.kv_cache_transceiver: - ctx_transmission_reqs = self._send_disagg_ctx_cache( - scheduled_batch.context_requests) - # For context only req in transmission, we reset the state since sampler might have changed it - for req in ctx_transmission_reqs: - req.state = LlmRequestState.DISAGG_CONTEXT_TRANS_IN_PROGRESS - - self._handle_canceled_requests() - finished_requests = self._handle_responses() - attn_metadata = getattr(self.model_engine, 'attn_metadata', - None) - kv_cache_dtype_byte_size = getattr( - self.model_engine, 'kv_cache_dtype_byte_size', None) - self.resource_manager.update_resources( - scheduled_batch, attn_metadata, - kv_cache_dtype_byte_size) - if self.enable_kv_cache_events: - self._add_kv_cache_events() + self._handle_canceled_requests() + finished_requests = self._handle_responses() + attn_metadata = getattr(model_engine, 'attn_metadata', + None) + kv_cache_dtype_byte_size = getattr( + model_engine, 'kv_cache_dtype_byte_size', None) + self.resource_manager.update_resources( + scheduled_batch, 
attn_metadata, + kv_cache_dtype_byte_size) + if self.enable_kv_cache_events: + self._add_kv_cache_events() + + if phase == ExecutorLoopPhase.SM_DISAGG_CTX: + self.ctx_request_cv.notify() + elif phase == ExecutorLoopPhase.SM_DISAGG_GEN: + self.gen_request_cv.notify() if self.kv_cache_transceiver and self.ctx_in_transmission_requests: self._check_kv_transfer_timeout() @@ -1232,7 +1361,8 @@ def _executor_loop(self): self._kv_connector_terminate_requests() if self.enable_iter_perf_stats and sample_state is not None: - iter_stats.inflight_batching_stats.num_ctx_tokens = self.model_engine.iter_states[ + iter_stats.iter_counter = model_engine.iter_counter + iter_stats.inflight_batching_stats.num_ctx_tokens = model_engine.iter_states[ 'num_ctx_tokens'] self._process_iter_stats( finished_requests, self.active_requests, @@ -1306,22 +1436,35 @@ def control_action(self): self.control_action_done.set() self.control_request_barrier.clear() - def _executor_loop_overlap(self): + def _executor_loop_overlap( + self, + stream: Optional[torch.cuda.Stream] = None, + phase: ExecutorLoopPhase = ExecutorLoopPhase.IFB): + assert phase != ExecutorLoopPhase.SM_DISAGG_CTX, ( + "SM disagg context phase only supports non-overlapping executor loop." + ) torch.cuda.set_device(self.device_id) # ensure the context is created, otherwise, some MPI calls will fail. CUASSERT(cudart.cudaSetDevice(self.device_id)) - with self._profiler() as profile_step: + + with self._profiler(self.model_engine, stream, phase) as profile_step: iter_start_time = time.time() iter_stats = None target_inputs = None previous_tensors_device = None can_forward = False if self.benchmark_req_queues_size > 0 and self.kv_cache_transceiver else True while True: - profile_step() + profile_step(iter_stats) if self.enable_iter_perf_stats: iter_start_time = time.time() - scheduled_batch, iter_stats = self._prepare_and_schedule_batch() + if phase == ExecutorLoopPhase.SM_DISAGG_GEN: + scheduled_batch, iter_stats = self._prepare_and_schedule_batch_sm_disagg_gen( + ) + else: + scheduled_batch, iter_stats = self._prepare_and_schedule_batch( + ) + self._handle_control_request() if scheduled_batch is None: @@ -1357,32 +1500,36 @@ def _executor_loop_overlap(self): else: can_forward = True - self._pause_requests(scheduled_batch.paused_requests) + with self.executor_lock: + self._pause_requests(scheduled_batch.paused_requests) - can_queue = self._can_queue(scheduled_batch) - if can_queue: - if self.kv_cache_transceiver: - # For generation requests which have completed KV cache transfer - self._prepare_disagg_gen_transmission_complete( - scheduled_batch) - self.resource_manager.prepare_resources(scheduled_batch) + can_queue = self._can_queue(scheduled_batch) + if can_queue: + if self.kv_cache_transceiver: + # For generation requests which have completed KV cache transfer + self._prepare_disagg_gen_transmission_complete( + scheduled_batch) + self.resource_manager.prepare_resources(scheduled_batch) - self._kv_connector_start_batch(scheduled_batch) + self._kv_connector_start_batch(scheduled_batch) - # if using a kv connector, we need to call can_queue again since scheduled_batch might have changed - if self.kv_connector_manager: - can_queue = self._can_queue(scheduled_batch) + # if using a kv connector, we need to call can_queue again since scheduled_batch might have changed + if self.kv_connector_manager: + can_queue = self._can_queue(scheduled_batch) if can_queue: - # The generation requests that are do not have batch_idx, + # The generation requests that are do not have 
batch_idx (for disaggregated serving), + # or those just finished context phase (for SM-level disaggregation), # needs to be in front of the batch due to the assumptions - # made in model_engine.py::_forward_step. This is only important - # for disaggregated serving. For non-disaggregated serving, + # made in model_engine.py::_forward_step. For non-disaggregated serving, # the generation requests always have batch_idx. scheduled_batch.generation_requests = sorted( # stable sort scheduled_batch.generation_requests, - key=lambda req: int(req.py_batch_idx is not None), + key=lambda req: + int(req.py_batch_idx is not None or phase == + ExecutorLoopPhase.SM_DISAGG_GEN and req. + max_num_generated_tokens > 0), ) if self.kv_cache_transceiver: @@ -1414,11 +1561,18 @@ def _executor_loop_overlap(self): else: previous_tensors_device = self.previous_batch and self.previous_batch.sample_state and self.previous_batch.sample_state.device - batch_outputs = self._forward_step(scheduled_batch, - previous_tensors_device) + with torch.cuda.stream(stream): + batch_outputs = self._forward_step( + scheduled_batch, previous_tensors_device) + # To avoid long sync time in critical section below + if phase == ExecutorLoopPhase.SM_DISAGG_GEN and self.previous_batch is not None: + self.previous_batch.sample_state.sampler_event.synchronize( + ) if self.previous_batch is not None: - self._update_requests(self.previous_batch.sample_state) + with self.executor_lock: + self._update_requests( + self.previous_batch.sample_state) if self.block_reuse_enabled and not self.kv_cache_manager.is_vswa and self.kv_cache_transceiver: for req in self.previous_batch.sample_state.scheduled_requests.context_requests: @@ -1441,20 +1595,26 @@ def _executor_loop_overlap(self): self.guided_decoder.add_batch(scheduled_batch) self.guided_decoder.execute(batch_outputs['logits']) - sample_state = self._sample_async(scheduled_batch, - batch_outputs) - assert sample_state is not None, "Sampling failed" + with torch.cuda.stream(stream): + sample_state = self._sample_async( + scheduled_batch, batch_outputs) + assert sample_state is not None, "Sampling failed" - self._update_request_states(scheduled_batch) + with self.executor_lock: + self._update_request_states(scheduled_batch) - ctx_transmission_reqs = self._send_disagg_ctx_cache( - scheduled_batch.context_requests - ) if self.kv_cache_transceiver else [] + ctx_transmission_reqs = self._send_disagg_ctx_cache( + scheduled_batch.context_requests + ) if self.kv_cache_transceiver else [] - if self.previous_batch is not None: - self._process_previous_batch() + if self.previous_batch is not None: + self._process_previous_batch() + + if phase == ExecutorLoopPhase.SM_DISAGG_GEN: + self.gen_request_cv.notify() if self.enable_iter_perf_stats: + iter_stats.iter_counter = self.model_engine.iter_counter iter_stats.inflight_batching_stats.num_ctx_tokens = self.model_engine.iter_states[ 'num_ctx_tokens'] @@ -1470,6 +1630,28 @@ def _executor_loop_overlap(self): self._kv_connector_terminate_requests() + def _executor_loop_sm_disagg(self): + (stream_ctx, stream_gen), (res_ctx, res_gen) = green_ctx_split_percent( + self.sm_disagg_ctx_sm_percent, self.device_id) + logger.info( + f"Green contexts allocated {res_ctx.sm.smCount} SMs for context phase and {res_gen.sm.smCount} SMs for generation phase." 
+ ) + + # Context phase only supports non-overlapping executor loop + thread_ctx = threading.Thread(target=self._executor_loop, + args=(stream_ctx, + ExecutorLoopPhase.SM_DISAGG_CTX), + daemon=True) + thread_ctx.start() + + if self.disable_overlap_scheduler: + self._executor_loop(stream_gen, ExecutorLoopPhase.SM_DISAGG_GEN) + else: + self._executor_loop_overlap(stream_gen, + ExecutorLoopPhase.SM_DISAGG_GEN) + + thread_ctx.join() + def _accept_draft_tokens( self, scheduled_batch: ScheduledRequests, target_outputs: SampleStateTensors, @@ -1629,8 +1811,13 @@ def _respond_if_invalid(request: LlmRequest) -> bool: self._handle_errors(str(e), requests=[request]) return True + if self.ctx_model_engine is not None: + num_active_requests_on_engine = len( + get_context_requests(self.active_requests)) + else: + num_active_requests_on_engine = len(self.active_requests) new_requests_cur_rank = self.executor_request_queue.fetch_new_requests( - self.active_requests) + self.active_requests, num_active_requests_on_engine) self.is_shutdown = self.executor_request_queue.is_shutdown self.expected_num_active_requests = self.executor_request_queue.get_expected_num_active_requests( ) @@ -1724,9 +1911,15 @@ def _waiting_requests(self, context_requests: list[LlmRequest], return waited_context_requests @nvtx_range("_schedule") - def _schedule(self): - scheduler_output = self.scheduler.schedule_request( - self.active_requests, self.inflight_req_ids) + def _schedule(self, + scheduler: Optional[RequestScheduler] = None, + active_requests: Optional[List[LlmRequest]] = None): + if scheduler is None: + scheduler = self.scheduler + if active_requests is None: + active_requests = self.active_requests + scheduler_output = scheduler.schedule_request(active_requests, + self.inflight_req_ids) scheduled_context_requests = scheduler_output.context_requests if self.enable_attention_dp and self.attention_dp_enable_balance: scheduled_context_requests = self._balance_adp_requests( @@ -1960,14 +2153,17 @@ def _check_disagg_gen_cache_transfer_status(self, atLeastNum: int = 0): def _forward_step(self, scheduled_requests, - new_tensors_device: Optional[SampleStateTensors] = None): + new_tensors_device: Optional[SampleStateTensors] = None, + model_engine: Optional[ModelEngine] = None): + if model_engine is None: + model_engine = self.model_engine @nvtx_range( - f"[Executor] _forward_step {self.model_engine.iter_counter + 1}: {len(scheduled_requests.context_requests)} ctx reqs, {len(scheduled_requests.generation_requests)} gen reqs" + f"[Executor] _forward_step {model_engine.iter_counter + 1}: {len(scheduled_requests.context_requests)} ctx reqs, {len(scheduled_requests.generation_requests)} gen reqs" ) def forward(scheduled_requests, resource_manager, new_tensors_device, gather_context_logits, cache_indirection_buffer): - return self.model_engine.forward( + return model_engine.forward( scheduled_requests, resource_manager, new_tensors_device, diff --git a/tensorrt_llm/_torch/pyexecutor/py_executor_creator.py b/tensorrt_llm/_torch/pyexecutor/py_executor_creator.py index 008efc0405c..5019dd26296 100644 --- a/tensorrt_llm/_torch/pyexecutor/py_executor_creator.py +++ b/tensorrt_llm/_torch/pyexecutor/py_executor_creator.py @@ -79,6 +79,8 @@ def _bytes_to_gib(bytes: int) -> float: "KV cache", ExecutorMemoryType.MODEL_ENGINE_MAIN: "Model", + ExecutorMemoryType.MODEL_ENGINE_CTX: + "Context model for SM-level disaggregation", ExecutorMemoryType.MODEL_ENGINE_DRAFT: "Draft model for speculative decoding", } @@ -96,6 +98,9 @@ def _bytes_to_gib(bytes: 
int) -> float: ExecutorMemoryType.INIT_KV_CACHE: "reduce max_num_tokens", ExecutorMemoryType.MODEL_ENGINE_MAIN: + ("reduce max_num_tokens and/or shard the model weights across GPUs by enabling " + "pipeline and/or tensor parallelism"), + ExecutorMemoryType.MODEL_ENGINE_CTX: ("reduce max_num_tokens and/or shard the model weights across GPUs by enabling " "pipeline and/or tensor parallelism"), ExecutorMemoryType.MODEL_ENGINE_DRAFT: @@ -345,6 +350,36 @@ def allocation_scope(current_stage: ExecutorMemoryType, validate_feature_combination(llm_args, model_engine, llm_args.sampler_type) + if llm_args.sm_disagg_config is not None: + if llm_args.cache_transceiver_config is not None: + raise ValueError( + "SM-level disaggregation is not compatible with disaggregated serving." + ) + if llm_args.parallel_config.world_size > 1: + raise NotImplementedError( + "SM-level disaggregation is not supported with parallelism.") + if scheduler_config.capacity_scheduler_policy != CapacitySchedulerPolicy.GUARANTEED_NO_EVICT: + raise NotImplementedError( + "SM-level disaggregation is only supported with guaranteed no evict scheduler policy." + ) + + with allocation_scope(ExecutorMemoryType.MODEL_ENGINE_CTX, + RestoreMode.PINNED): + ctx_llm_args = copy.copy(llm_args) + ctx_llm_args.cuda_graph_config = None + ctx_model_engine = PyTorchModelEngine( + model_path=checkpoint_dir, + llm_args=ctx_llm_args, + mapping=mapping, + attn_runtime_features=attn_runtime_features, + dist=dist, + spec_config=spec_config, + is_sm_disagg_ctx_phase=True, + weight_sharing_model=model_engine.model, + ) + else: + ctx_model_engine = None + if has_draft_model_engine: with allocation_scope(ExecutorMemoryType.MODEL_ENGINE_DRAFT, RestoreMode.PINNED): @@ -447,9 +482,9 @@ def drafting_loop_wrapper(model): "Chunked Prefill for MLA can only be enabled on SM90/SM100/SM103/SM120, " f"disable enable_chunked_context for SM{sm_version}") enable_chunked_context = False - model_engine.attn_runtime_features.chunked_prefill = False - if draft_model_engine is not None: - draft_model_engine.attn_runtime_features.chunked_prefill = False + for eng in [model_engine, ctx_model_engine, draft_model_engine]: + if eng is not None: + eng.attn_runtime_features.chunked_prefill = False if enable_chunked_context: chunk_unit_size = tokens_per_block @@ -642,6 +677,8 @@ def drafting_loop_wrapper(model): max_batch_size=max_batch_size, max_beam_width=max_beam_width, max_num_tokens=max_num_tokens, + ctx_model_engine=ctx_model_engine, + sm_disagg_config=llm_args.sm_disagg_config, peft_cache_config=peft_cache_config, scheduler_config=scheduler_config, cache_transceiver_config=cache_transceiver_config, @@ -670,11 +707,11 @@ def drafting_loop_wrapper(model): max_seq_len = kv_cache_creator._max_seq_len update_sampler_max_seq_len(max_seq_len, sampler) - for eng in [model_engine, draft_model_engine]: + for eng in [model_engine, ctx_model_engine, draft_model_engine]: if eng is None: continue if eng.attn_metadata is not None: - if llm_args.cuda_graph_config is not None: + if eng.cuda_graph_runner.enabled: eng._release_cuda_graphs() eng.attn_metadata = None @@ -699,6 +736,8 @@ def drafting_loop_wrapper(model): max_batch_size=max_batch_size, max_beam_width=max_beam_width, max_num_tokens=max_num_tokens, + ctx_model_engine=ctx_model_engine, + sm_disagg_config=llm_args.sm_disagg_config, peft_cache_config=peft_cache_config, scheduler_config=scheduler_config, cache_transceiver_config=cache_transceiver_config, diff --git a/tensorrt_llm/_torch/pyexecutor/scheduler.py 
b/tensorrt_llm/_torch/pyexecutor/scheduler.py index c71c4596ed7..bfc9eed8fed 100644 --- a/tensorrt_llm/_torch/pyexecutor/scheduler.py +++ b/tensorrt_llm/_torch/pyexecutor/scheduler.py @@ -79,6 +79,9 @@ def __init__( scheduler_policy: CapacitySchedulerPolicy = CapacitySchedulerPolicy. GUARANTEED_NO_EVICT, two_step_lookahead: bool = False, + no_schedule_until_state: LlmRequestState = LlmRequestState.CONTEXT_INIT, + no_schedule_after_state: LlmRequestState = LlmRequestState. + GENERATION_COMPLETE, ): super(BindCapacityScheduler, self).__init__() self.kv_cache_manager = kv_cache_manager @@ -89,8 +92,8 @@ def __init__( capacity_scheduler_policy=scheduler_policy._to_pybind(), has_kv_cache_manager=kv_cache_manager is not None, two_step_lookahead=two_step_lookahead, - no_schedule_until_state=LlmRequestState.CONTEXT_INIT, - no_schedule_after_state=LlmRequestState.GENERATION_COMPLETE) + no_schedule_until_state=no_schedule_until_state, + no_schedule_after_state=no_schedule_after_state) def schedule_request( self, active_requests: RequestList @@ -175,6 +178,9 @@ def __init__( max_batch_size: int, max_num_tokens: int = None, ctx_chunk_config: Optional[Tuple[StrEnum, int]] = None, + no_schedule_until_state: LlmRequestState = LlmRequestState.CONTEXT_INIT, + no_schedule_after_state: LlmRequestState = LlmRequestState. + GENERATION_COMPLETE, ) -> None: super(BindMicroBatchScheduler, self).__init__() self.max_batch_size = max_batch_size @@ -186,7 +192,8 @@ def __init__( ctx_chunk_config[0]._to_pybind(), ctx_chunk_config[1]) self.impl = tb_internal.algorithms.MicroBatchScheduler( - ctx_chunk_config_cpp, max_num_tokens) + ctx_chunk_config_cpp, max_num_tokens, no_schedule_until_state, + no_schedule_after_state) def schedule( self, active_requests: RequestList, inflight_request_ids: set[int] diff --git a/tensorrt_llm/_torch/virtual_memory.py b/tensorrt_llm/_torch/virtual_memory.py index 3702d732539..4ab9c21a214 100644 --- a/tensorrt_llm/_torch/virtual_memory.py +++ b/tensorrt_llm/_torch/virtual_memory.py @@ -78,6 +78,7 @@ class ExecutorMemoryType(StrEnum): EXTRA_RESOURCES = "executor_extra" KV_CACHE = "kv_cache" MODEL_ENGINE_MAIN = "model" + MODEL_ENGINE_CTX = "ctx_model" MODEL_ENGINE_DRAFT = "draft_model" diff --git a/tensorrt_llm/llmapi/__init__.py b/tensorrt_llm/llmapi/__init__.py index cb868d8d068..9c2459eeda9 100644 --- a/tensorrt_llm/llmapi/__init__.py +++ b/tensorrt_llm/llmapi/__init__.py @@ -14,8 +14,8 @@ MedusaDecodingConfig, MoeConfig, MTPDecodingConfig, NGramDecodingConfig, RocketSparseAttentionConfig, SaveHiddenStatesDecodingConfig, SchedulerConfig, - TorchCompileConfig, TorchLlmArgs, TrtLlmArgs, - UserProvidedDecodingConfig) + SmDisaggConfig, TorchCompileConfig, TorchLlmArgs, + TrtLlmArgs, UserProvidedDecodingConfig) from .llm_utils import (BuildConfig, KvCacheRetentionConfig, QuantAlgo, QuantConfig) from .mm_encoder import MultimodalEncoder @@ -62,6 +62,7 @@ 'AttentionDpConfig', 'LoRARequest', 'SaveHiddenStatesDecodingConfig', + 'SmDisaggConfig', 'RocketSparseAttentionConfig', 'DeepSeekSparseAttentionConfig', ] diff --git a/tensorrt_llm/llmapi/llm_args.py b/tensorrt_llm/llmapi/llm_args.py index 40d967a2cb8..4554189b4c3 100644 --- a/tensorrt_llm/llmapi/llm_args.py +++ b/tensorrt_llm/llmapi/llm_args.py @@ -424,6 +424,29 @@ def from_dict(cls, data: dict): return cls(**data) +class SmDisaggConfig(StrictBaseModel): + """ + Configuration for SM-level disaggregation. 
+ """ + context_sm_percent: float = Field( + default=0.5, + description="Percentage of SMs allocated to context phase.") + context_max_num_tokens: int = Field( + default=0, + description= + "The maximum number of tokens for context phase. If less than or equal to 0, the same value as generation phase is used." + ) + context_max_batch_size: int = Field( + default=0, + description= + "The maximum batch size for context phase. If less than or equal to 0, the same value as generation phase is used." + ) + + @classmethod + def from_dict(cls, data: dict): + return cls(**data) + + class _ParallelConfig(StrictBaseModel): """The model distribution configs for LLM.""" tp_size: int = 1 @@ -2556,6 +2579,11 @@ class TorchLlmArgs(BaseLlmArgs): description="Disable the overlap scheduler.", status="beta") + sm_disagg_config: Optional[SmDisaggConfig] = Field( + default=None, + description="SM-level disaggregation config.", + status="prototype") + moe_config: MoeConfig = Field(default_factory=MoeConfig, description="MoE config.", status="beta") @@ -2914,6 +2942,23 @@ def validate_attention_dp_config(self) -> 'TorchLlmArgs': ) return self + @model_validator(mode='after') + def validate_and_sync_sm_disagg_config(self) -> 'TorchLlmArgs': + """Validate SM-level disaggregation configuration.""" + if self.sm_disagg_config is None: + return self + + config = self.sm_disagg_config + if not 0 < config.context_sm_percent < 1: + raise ValueError( + "sm_disagg_config.context_sm_percent must be in the range (0, 1)" + ) + if config.context_max_num_tokens <= 0: + config.context_max_num_tokens = self.max_num_tokens + if config.context_max_batch_size <= 0: + config.context_max_batch_size = self.max_batch_size + return self + @model_validator(mode='after') def validate_batch_wait_timeout_ms(self) -> 'TorchLlmArgs': """Validate batch wait timeout."""