From 057360d9d32915f97b04d6cf930277836fb8c368 Mon Sep 17 00:00:00 2001 From: Qiang Xu Date: Wed, 29 Oct 2025 18:05:32 +0000 Subject: [PATCH 1/6] duplicate executor loops for sm disagg Signed-off-by: Qiang Xu --- tensorrt_llm/_torch/pyexecutor/py_executor.py | 272 ++++++++++++++++++ 1 file changed, 272 insertions(+) diff --git a/tensorrt_llm/_torch/pyexecutor/py_executor.py b/tensorrt_llm/_torch/pyexecutor/py_executor.py index 23ab0dbfa07..3223d2dc877 100644 --- a/tensorrt_llm/_torch/pyexecutor/py_executor.py +++ b/tensorrt_llm/_torch/pyexecutor/py_executor.py @@ -1470,6 +1470,278 @@ def _executor_loop_overlap(self): self._kv_connector_terminate_requests() + def _executor_loop_sm_disagg_ctx(self): + torch.cuda.set_device(self.device_id) + # ensure the context is created, otherwise, some MPI calls will fail. + CUASSERT(cudart.cudaSetDevice(self.device_id)) + with self._profiler() as profile_step: + sample_state = None + iter_start_time = time.time() + iter_stats = None + while True: + profile_step() + if self.enable_iter_perf_stats: + iter_start_time = time.time() + + scheduled_batch, iter_stats = self._prepare_and_schedule_batch() + if scheduled_batch is None: + break + + self._pause_requests(scheduled_batch.paused_requests) + + finished_requests = [] + + if scheduled_batch.batch_size > 0 or ( + self.enable_attention_dp and self.dist.tp_size > 1): + if self.kv_cache_transceiver: + # For generation requests which have completed KV cache transfer + self._prepare_disagg_gen_transmission_complete( + scheduled_batch) + + # Return the first token to the client + self._handle_first_token_response(scheduled_batch) + self.resource_manager.prepare_resources(scheduled_batch) + + self._kv_connector_start_batch(scheduled_batch) + + if scheduled_batch.batch_size > 0 or ( + self.enable_attention_dp and self.dist.tp_size > 1): + # init_disagg_gen_requests must be before drafter loop, otherwise draft requests do not have initialized matchers. + # init_disagg_gen_requests must be before engine forward, where the prev_seq_slot is updated. + if self.guided_decoder is not None: + self.guided_decoder.add_batch(scheduled_batch) + if self.kv_cache_transceiver: + self.guided_decoder.init_disagg_gen_requests() + + if self.drafter is not None and self.use_spec_decode: + if self.guided_decoder is not None: + self.guided_decoder.rollback_rejected_tokens() + with request_context( + is_draft=self.draft_model_engine is not None, + scheduled_requests=scheduled_batch): + self.drafter.prepare_draft_tokens( + scheduled_batch, self.resource_manager) + # Pad draft tokens to the max draft length. This is for CUDA graph compatibility. + self.drafter.pad_draft_tokens_for_cuda_graph( + scheduled_batch) + # add_batch must be called again to restore to target requests with updated draft tokens. 
+ if self.guided_decoder is not None: + self.guided_decoder.add_batch(scheduled_batch) + if hasattr(self.drafter, "guided_decoder"): + self.guided_decoder.rollback_draft_tokens() + + batch_outputs = self._forward_step(scheduled_batch) + if self.guided_decoder is not None: + self.guided_decoder.execute(batch_outputs['logits']) + + sample_state = self._sample_async(scheduled_batch, + batch_outputs) + if self.drafter is not None: + self.drafter.run_drafter_post(scheduled_batch, + self.resource_manager, + self.is_warmup) + + self._update_request_states(scheduled_batch) + self._update_requests(sample_state, self.resource_manager) + if self.block_reuse_enabled and not self.kv_cache_manager.is_vswa and self.kv_cache_transceiver: + for req in scheduled_batch.context_requests: + if req.is_context_only_request and ( + req.is_context_finished + or req.is_finished_due_to_length): + block_id = self.kv_cache_manager.store_blocks_for_reuse( + req, True) + self.ctx_in_transmission_requests[ + req.py_request_id] = ( + (req, block_id, + self.ctx_in_transmission_counter)) + + if self.kv_cache_transceiver: + ctx_transmission_reqs = self._send_disagg_ctx_cache( + scheduled_batch.context_requests) + # For context only req in transmission, we reset the state since sampler might have changed it + for req in ctx_transmission_reqs: + req.state = LlmRequestState.DISAGG_CONTEXT_TRANS_IN_PROGRESS + + self._handle_canceled_requests() + finished_requests = self._handle_responses() + attn_metadata = getattr(self.model_engine, 'attn_metadata', + None) + kv_cache_dtype_byte_size = getattr( + self.model_engine, 'kv_cache_dtype_byte_size', None) + self.resource_manager.update_resources( + scheduled_batch, attn_metadata, + kv_cache_dtype_byte_size) + if self.enable_kv_cache_events: + self._add_kv_cache_events() + + if self.kv_cache_transceiver and self.ctx_in_transmission_requests: + self._check_kv_transfer_timeout() + self._terminate_disagg_ctx_finished_requests() + + self._kv_connector_terminate_requests() + + if self.enable_iter_perf_stats and sample_state is not None: + iter_stats.inflight_batching_stats.num_ctx_tokens = self.model_engine.iter_states[ + 'num_ctx_tokens'] + self._process_iter_stats( + finished_requests, self.active_requests, + BatchState(sample_state=sample_state, + iter_stats=iter_stats, + iter_start_time=iter_start_time)) + + def _executor_loop_sm_disagg_gen_overlap(self): + torch.cuda.set_device(self.device_id) + # ensure the context is created, otherwise, some MPI calls will fail. + CUASSERT(cudart.cudaSetDevice(self.device_id)) + with self._profiler() as profile_step: + iter_start_time = time.time() + iter_stats = None + can_forward = False if self.benchmark_req_queues_size > 0 and self.kv_cache_transceiver else True + while True: + profile_step() + if self.enable_iter_perf_stats: + iter_start_time = time.time() + + scheduled_batch, iter_stats = self._prepare_and_schedule_batch() + if scheduled_batch is None: + break + # In gen-only benchmarking mode, wait until the number of scheduled generation + # requests reaches the required threshold before starting forward pass, + # to ensure consistent batch sizes for accurate performance measurement. 
+ if not self.is_warmup and not can_forward: + if self.enable_attention_dp: + local_can_forward = self.executor_request_queue.num_fetch_requests + \ + len(scheduled_batch.generation_requests) >= self.benchmark_req_queues_size + all_can_forward = self.dist.tp_allgather( + local_can_forward) + if all(all_can_forward): + can_forward = True + time.sleep(10) + else: + if self.dist.rank == 0: + logger.info( + f"sleep 10 seconds, num_fetched_requests: {self.executor_request_queue.num_fetch_requests}, scheduled_gen_batch: {len(scheduled_batch.generation_requests)}" + ) + time.sleep(10) + continue + else: + if len(scheduled_batch.generation_requests + ) < self.benchmark_req_queues_size: + if self.dist.rank == 0: + logger.info( + f"sleep 10 seconds, scheduled_gen_batch: {len(scheduled_batch.generation_requests)}" + ) + time.sleep(10) + continue + else: + can_forward = True + + self._pause_requests(scheduled_batch.paused_requests) + + if scheduled_batch.batch_size > 0: + if self.kv_cache_transceiver: + # For generation requests which have completed KV cache transfer + self._prepare_disagg_gen_transmission_complete( + scheduled_batch) + self.resource_manager.prepare_resources(scheduled_batch) + + self._kv_connector_start_batch(scheduled_batch) + + if scheduled_batch.batch_size > 0: + + # The generation requests that are do not have batch_idx, + # needs to be in front of the batch due to the assumptions + # made in model_engine.py::_forward_step. This is only important + # for disaggregated serving. For non-disaggregated serving, + # the generation requests always have batch_idx. + scheduled_batch.generation_requests = sorted( # stable sort + scheduled_batch.generation_requests, + key=lambda req: int(req.py_batch_idx is not None), + ) + + if self.kv_cache_transceiver: + # Return the first token to the client + self._handle_first_token_response(scheduled_batch) + + # init_disagg_gen_requests must be before engine forward, where the prev_seq_slot is updated. + if self.guided_decoder is not None and self.kv_cache_transceiver: + self.guided_decoder.add_batch(scheduled_batch) + self.guided_decoder.init_disagg_gen_requests() + + previous_tensors = self.previous_batch and self.previous_batch.sample_state + target_inputs = None + draft_outputs = None + # If there are previous draft tokens, we need to update the target requests to accept some draft tokens. + # When there's any accepted tokens, we can't directly use the previous batch's outputs in this iteration for the target model, + # so we'll set the target model's input to None and skip updating the target requests after target model forward. + use_previous_draft_tokens = self.has_previous_draft_tokens + if self.drafter is not None and (self.use_spec_decode or + use_previous_draft_tokens): + target_inputs, draft_outputs, draft_batch = self._handle_speculative_decoding( + scheduled_batch, previous_tensors) + + # Use the draft_model's outputs if we've launched the draft model. + # Otherwise, use the previous batch's outputs. 
+ if target_inputs is not None or use_previous_draft_tokens: + previous_tensors_device = target_inputs + else: + previous_tensors_device = self.previous_batch and self.previous_batch.sample_state and self.previous_batch.sample_state.device + + batch_outputs = self._forward_step(scheduled_batch, + previous_tensors_device) + + if target_inputs is not None: + self._process_draft_results(scheduled_batch, + draft_outputs, draft_batch) + elif self.previous_batch is not None and not use_previous_draft_tokens: + self._update_requests(self.previous_batch.sample_state) + + if self.block_reuse_enabled and not self.kv_cache_manager.is_vswa and self.kv_cache_transceiver: + for req in self.previous_batch.sample_state.scheduled_requests.context_requests: + if req.is_context_only_request and ( + req.is_context_finished + or req.is_finished_due_to_length): + block_id = self.kv_cache_manager.store_blocks_for_reuse( + req, True) + self.ctx_in_transmission_requests[ + req.py_request_id] = ( + (req, block_id, + self.ctx_in_transmission_counter)) + + if self.guided_decoder is not None: + # add_batch must be called again to have updated new tokens. + self.guided_decoder.add_batch(scheduled_batch) + self.guided_decoder.execute(batch_outputs['logits']) + + sample_state = self._sample_async(scheduled_batch, + batch_outputs) + assert sample_state is not None, "Sampling failed" + + self._update_request_states(scheduled_batch) + + ctx_transmission_reqs = self._send_disagg_ctx_cache( + scheduled_batch.context_requests + ) if self.kv_cache_transceiver else [] + + if self.previous_batch is not None: + self._process_previous_batch() + + if self.enable_iter_perf_stats: + iter_stats.inflight_batching_stats.num_ctx_tokens = self.model_engine.iter_states[ + 'num_ctx_tokens'] + + self.previous_batch = BatchState( + sample_state=sample_state, + iter_start_time=iter_start_time, + iter_stats=iter_stats, + ctx_transmission_reqs=ctx_transmission_reqs) + + if self.kv_cache_transceiver and self.ctx_in_transmission_requests: + self._check_kv_transfer_timeout() + self._terminate_disagg_ctx_finished_requests() + + self._kv_connector_terminate_requests() + def _accept_draft_tokens( self, scheduled_batch: ScheduledRequests, target_outputs: SampleStateTensors, From e2dc09c6ffc3bba3adfbc3556cf7d1a918408ee8 Mon Sep 17 00:00:00 2001 From: Qiang Xu Date: Thu, 6 Nov 2025 00:50:33 +0000 Subject: [PATCH 2/6] sm disagg implementation Signed-off-by: Qiang Xu --- tensorrt_llm/_torch/pyexecutor/_util.py | 23 +- .../pyexecutor/executor_request_queue.py | 40 +- tensorrt_llm/_torch/pyexecutor/llm_request.py | 8 + .../_torch/pyexecutor/model_engine.py | 19 +- .../_torch/pyexecutor/model_loader.py | 10 +- tensorrt_llm/_torch/pyexecutor/py_executor.py | 491 +++++++++--------- .../_torch/pyexecutor/py_executor_creator.py | 53 +- tensorrt_llm/_torch/pyexecutor/scheduler.py | 27 +- tensorrt_llm/_torch/virtual_memory.py | 1 + tensorrt_llm/llmapi/__init__.py | 5 +- tensorrt_llm/llmapi/llm_args.py | 45 ++ 11 files changed, 460 insertions(+), 262 deletions(-) diff --git a/tensorrt_llm/_torch/pyexecutor/_util.py b/tensorrt_llm/_torch/pyexecutor/_util.py index 8da982aba2b..a79e6401703 100644 --- a/tensorrt_llm/_torch/pyexecutor/_util.py +++ b/tensorrt_llm/_torch/pyexecutor/_util.py @@ -14,7 +14,7 @@ EagleDecodingConfig, KvCacheConfig, MTPDecodingConfig, PeftCacheConfig, SamplerType, SchedulerConfig, - SparseAttentionConfig, + SmDisaggConfig, SparseAttentionConfig, SpeculativeConfig, TorchLlmArgs) from tensorrt_llm.logger import logger from tensorrt_llm.lora_helper 
import (LoraConfig, @@ -38,7 +38,7 @@ from .sampler import (EarlyStopSampler, EarlyStopWithMMResult, TorchSampler, TRTLLMSampler) from .scheduler import (BindCapacityScheduler, BindMicroBatchScheduler, - SimpleScheduler) + SimpleScheduler, SmDisaggCtxScheduler) from .seq_slot_manager import SeqSlotManager GB = 1 << 30 @@ -665,6 +665,8 @@ def create_py_executor_instance( max_batch_size: Optional[int] = None, max_beam_width: Optional[int] = None, max_num_tokens: Optional[int] = None, + ctx_model_engine: Optional[PyTorchModelEngine] = None, + sm_disagg_config: Optional[SmDisaggConfig] = None, peft_cache_config: Optional[PeftCacheConfig] = None, scheduler_config: Optional[SchedulerConfig] = None, cache_transceiver_config: Optional[CacheTransceiverConfig] = None, @@ -789,6 +791,21 @@ def create_py_executor_instance( ctx_chunk_config) scheduler = SimpleScheduler(capacity_scheduler, mb_scheduler) + if sm_disagg_config is not None: + scheduler_capacity += sm_disagg_config.context_max_batch_size * mapping.pp_size + capacity_scheduler = BindCapacityScheduler( + scheduler_capacity, + kv_cache_manager.impl if kv_cache_manager is not None else None, + peft_cache_manager.impl if peft_cache_manager is not None else None, + scheduler_config.capacity_scheduler_policy, + two_step_lookahead=mapping.has_pp()) + mb_scheduler = BindMicroBatchScheduler( + sm_disagg_config.context_max_batch_size, + sm_disagg_config.context_max_num_tokens, ctx_chunk_config) + ctx_scheduler = SmDisaggCtxScheduler(capacity_scheduler, mb_scheduler) + else: + ctx_scheduler = None + config = model_engine.model.model_config.pretrained_config attention_type = AttentionTypeCpp.MLA if is_mla( config) else AttentionTypeCpp.DEFAULT @@ -801,6 +818,8 @@ def create_py_executor_instance( model_engine=model_engine, sampler=sampler, drafter=drafter, + ctx_scheduler=ctx_scheduler, + ctx_model_engine=ctx_model_engine, dist=dist, max_num_sequences=max_num_sequences, disable_overlap_scheduler=llm_args.disable_overlap_scheduler, diff --git a/tensorrt_llm/_torch/pyexecutor/executor_request_queue.py b/tensorrt_llm/_torch/pyexecutor/executor_request_queue.py index 8c506da6b54..5479bdbb858 100644 --- a/tensorrt_llm/_torch/pyexecutor/executor_request_queue.py +++ b/tensorrt_llm/_torch/pyexecutor/executor_request_queue.py @@ -47,10 +47,15 @@ def is_control_request(self): class ExecutorRequestQueue: """Handles fetching and processing of new requests from the request queue.""" - def __init__(self, dist: Distributed, enable_attention_dp: bool, - max_batch_size: int, max_beam_width: int, - max_num_active_requests: int, enable_iter_perf_stats: bool, - batch_wait_timeout_ms: float): + def __init__(self, + dist: Distributed, + enable_attention_dp: bool, + max_batch_size: int, + max_beam_width: int, + max_num_active_requests: int, + enable_iter_perf_stats: bool, + batch_wait_timeout_ms: float, + is_sm_disagg: bool = False): self.dist = dist self.request_queue: queue.Queue[RequestQueueItem] = queue.Queue() self.waiting_queue: deque[RequestQueueItem] = deque() @@ -59,6 +64,7 @@ def __init__(self, dist: Distributed, enable_attention_dp: bool, self.max_batch_size = max_batch_size self.max_beam_width = max_beam_width self.max_num_active_requests = max_num_active_requests + self.is_sm_disagg = is_sm_disagg self.enqueue_lock = threading.Lock() self.next_request_id = max_batch_size self.enable_iter_perf_stats = enable_iter_perf_stats @@ -333,13 +339,35 @@ def _fetch_and_process_requests( @nvtx_range("_fetch_new_requests") def fetch_new_requests( - self, activate_requests: 
List[LlmRequest]) -> List[LlmRequest]: + self, activate_requests: List[LlmRequest], + num_active_requests_on_engine: int) -> List[LlmRequest]: - if self.enable_attention_dp: + if self.is_sm_disagg: + return self._fetch_new_requests_sm_disagg( + len(activate_requests), num_active_requests_on_engine) + elif self.enable_attention_dp: return self._fetch_new_requests_attention_dp(activate_requests) else: return self._fetch_new_requests_attention_tp(len(activate_requests)) + def _fetch_new_requests_sm_disagg( + self, num_active_requests: int, + num_active_requests_on_engine: int) -> List[LlmRequest]: + """Handle SM-level disaggregation request fetching.""" + total_max_num_active_requests = (self.max_num_active_requests + + num_active_requests - + num_active_requests_on_engine) + + # fetch and process requests into waiting queue + new_requests = self._fetch_and_process_requests( + num_active_requests_on_engine, + total_max_num_active_requests, + enable_attention_dp=False) + + # Merge requests and add to active list + merged_requests = self._merge_requests(new_requests) + return merged_requests + def _fetch_new_requests_attention_tp( self, num_active_requests: int) -> List[LlmRequest]: """Handle standard (non-attention DP) request fetching.""" diff --git a/tensorrt_llm/_torch/pyexecutor/llm_request.py b/tensorrt_llm/_torch/pyexecutor/llm_request.py index c525481fee3..e8a59b0e22d 100644 --- a/tensorrt_llm/_torch/pyexecutor/llm_request.py +++ b/tensorrt_llm/_torch/pyexecutor/llm_request.py @@ -797,3 +797,11 @@ def get_draft_token_length(request: LlmRequest) -> int: if request.py_draft_tokens is not None: return len(request.py_draft_tokens) return 0 + + +def get_context_requests(requests: List[LlmRequest]): + return [req for req in requests if req.is_context_init_state] + + +def get_generation_requests(requests: List[LlmRequest]): + return [req for req in requests if not req.is_context_init_state] diff --git a/tensorrt_llm/_torch/pyexecutor/model_engine.py b/tensorrt_llm/_torch/pyexecutor/model_engine.py index 42f71181b11..f3471bf3bc5 100644 --- a/tensorrt_llm/_torch/pyexecutor/model_engine.py +++ b/tensorrt_llm/_torch/pyexecutor/model_engine.py @@ -135,10 +135,12 @@ def __init__( attn_runtime_features: Optional[AttentionRuntimeFeatures] = None, dist: Optional[MPIDist] = None, spec_config: Optional["DecodingBaseConfig"] = None, + is_sm_disagg_ctx_phase: bool = False, is_draft_model: bool = False, drafting_loop_wrapper: Optional[Callable[[torch.nn.Module], torch.nn.Module]] = None, model: Optional[torch.nn.Module] = None, + weight_sharing_model: Optional[torch.nn.Module] = None, ): self.forward_pass_callable = None self.ub_buffers = None @@ -148,6 +150,9 @@ def __init__( max_seq_len, max_batch_size, ) = llm_args.get_runtime_sizes() + if is_sm_disagg_ctx_phase: + max_num_tokens = llm_args.sm_disagg_config.context_max_num_tokens + max_batch_size = llm_args.sm_disagg_config.context_max_batch_size self.batch_size = max_batch_size self.max_num_tokens = max_num_tokens @@ -165,6 +170,7 @@ def __init__( if dist is not None: ExpertStatistic.create(self.dist.rank) self.llm_args = llm_args + self.sm_disagg_enabled = llm_args.sm_disagg_config is not None self.original_max_draft_len = spec_config.max_draft_len if spec_config is not None else 0 self.original_max_total_draft_tokens = spec_config.max_total_draft_tokens if spec_config is not None else 0 @@ -195,6 +201,7 @@ def __init__( max_num_tokens=self.max_num_tokens, max_seq_len=self.max_seq_len, lora_config=lora_config, + 
weight_sharing_model=weight_sharing_model, ) self.model, moe_load_balancer = loader.load( checkpoint_dir=model_path, checkpoint_loader=checkpoint_loader) @@ -1434,8 +1441,10 @@ def _prepare_tp_inputs( # the request has no previous tensor: # (1) next_draft_tokens_device is None, which means overlap scheduler is disabled; or # (2) a dummy request; or - # (3) the first step in the generation server of disaggregated serving - if next_draft_tokens_device is None or request.is_dummy or request.py_batch_idx is None: + # (3) the first step in the generation server of disaggregated serving; or + # (4) the first step in the generation phase of SM-level disaggregation + if next_draft_tokens_device is None or request.is_dummy or request.py_batch_idx is None \ + or self.sm_disagg_enabled and request.max_num_generated_tokens == 0: # get token ids, including input token ids and draft token ids. For these dummy requests, # no need to copy the token ids. if not (request.is_attention_dp_dummy @@ -1559,8 +1568,10 @@ def _prepare_tp_inputs( # the request has no previous tensor: # (1) new_tokens_device is None, which means overlap scheduler is disabled; or # (2) a dummy request; or - # (3) the first step in the generation server of disaggregated serving - if new_tokens_device is None or request.is_dummy or request.py_batch_idx is None: + # (3) the first step in the generation server of disaggregated serving; or + # (4) the first step in the generation phase of SM-level disaggregation + if new_tokens_device is None or request.is_dummy or request.py_batch_idx is None \ + or self.sm_disagg_enabled and request.max_num_generated_tokens == 0: # skip adding input_ids of CUDA graph dummy requests so that new_tokens_device # can be aligned to the correct positions. if not request.is_cuda_graph_dummy: diff --git a/tensorrt_llm/_torch/pyexecutor/model_loader.py b/tensorrt_llm/_torch/pyexecutor/model_loader.py index b9c1377cd98..329ca42bdac 100644 --- a/tensorrt_llm/_torch/pyexecutor/model_loader.py +++ b/tensorrt_llm/_torch/pyexecutor/model_loader.py @@ -191,7 +191,8 @@ def __init__(self, sparse_attention_config: Optional["SparseAttentionConfig"], max_num_tokens: int, max_seq_len: Optional[int], - lora_config: Optional[LoraConfig] = None): + lora_config: Optional[LoraConfig] = None, + weight_sharing_model: Optional[torch.nn.Module] = None): """ Initializes the ModelLoader. 
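Illustrative sketch (not part of the patch): the weight_sharing_model argument threaded through PyTorchModelEngine and ModelLoader above is consumed in ModelLoader.load() below via load_state_dict(..., assign=True), which rebinds the context engine's parameters to the generation engine's existing tensors instead of copying them. A minimal standalone example of that sharing behavior, using plain torch.nn modules rather than the TRT-LLM classes:

    import torch
    import torch.nn as nn

    src = nn.Linear(4, 4)   # stands in for the already-loaded generation engine model
    dst = nn.Linear(4, 4)   # stands in for the freshly built context engine model

    # assign=True makes dst's parameters wrap the very tensors from src's state dict,
    # so both modules reference one set of weights and the duplicate allocation can be freed.
    dst.load_state_dict(src.state_dict(), assign=True)

    assert dst.weight.data_ptr() == src.weight.data_ptr()
    with torch.no_grad():
        src.weight.fill_(1.0)
    assert torch.equal(dst.weight, src.weight)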
@@ -210,6 +211,7 @@ def __init__(self, self.max_num_tokens = max_num_tokens self.max_seq_len = max_seq_len self.lora_config = lora_config + self.weight_sharing_model = weight_sharing_model def load( self, @@ -307,6 +309,12 @@ def init_meta_tensor(t: torch.Tensor): moe_load_balancer.finalize_model() logger.info("moe_load_balancer finalize model done") + if self.weight_sharing_model is not None: + model.load_state_dict(self.weight_sharing_model.state_dict(), + assign=True) + # Free up duplicate model weights allocated before weight sharing + torch.cuda.empty_cache() + torch.cuda.current_stream().synchronize() return model, moe_load_balancer diff --git a/tensorrt_llm/_torch/pyexecutor/py_executor.py b/tensorrt_llm/_torch/pyexecutor/py_executor.py index 3223d2dc877..443d4e0c2af 100644 --- a/tensorrt_llm/_torch/pyexecutor/py_executor.py +++ b/tensorrt_llm/_torch/pyexecutor/py_executor.py @@ -48,7 +48,8 @@ from .kv_cache_connector import KvCacheConnectorManager from .kv_cache_transceiver import KvCacheTransceiver from .llm_request import (ExecutorRequest, LlmRequest, LlmRequestState, - LlmResponse, get_draft_token_length) + LlmResponse, get_context_requests, + get_draft_token_length, get_generation_requests) from .model_engine import ModelEngine from .resource_manager import ResourceManager from .sampler import Sampler, SampleState, SampleStateTensors @@ -58,6 +59,10 @@ # Format: "start1-stop1,start2-stop2,..." or single iterations "iter1,iter2,..." PROFILE_START_STOP_ENV_VAR_NAME = "TLLM_PROFILE_START_STOP" +# Environment variable to switch between context/generation profiling ranges as specified in PROFILE_START_STOP_ENV_VAR_NAME +# Set to "1" for the specified ranges to be treated as context phase ranges +PROFILE_SM_DISAGG_CTX_RANGE_ENV_VAR_NAME = "TLLM_PROFILE_SM_DISAGG_CTX_RANGE" + # Environment variable to enable PyTorch profiler tracing. # Set to a path to save detailed tracing of PyTorch operations. 
PROFILE_TRACE_ENV_VAR_NAME = "TLLM_TORCH_PROFILE_TRACE" @@ -117,6 +122,8 @@ def __init__(self, dist: Distributed, max_num_sequences: int, drafter: Optional[Drafter] = None, + ctx_scheduler: Optional[RequestScheduler] = None, + ctx_model_engine: Optional[ModelEngine] = None, disable_overlap_scheduler: bool = False, max_input_len: int = 2048, max_batch_size: int = 8, @@ -155,6 +162,11 @@ def __init__(self, self.disable_overlap_scheduler = disable_overlap_scheduler self.virtual_memory_pools = virtual_memory_pools + self.ctx_scheduler = ctx_scheduler + self.ctx_model_engine = ctx_model_engine + self.profile_sm_disagg_ctx_range = bool( + os.environ.get(PROFILE_SM_DISAGG_CTX_RANGE_ENV_VAR_NAME, None)) + # enqueue and _fetch_new_requests used data self.active = True self.max_beam_width = max_beam_width @@ -172,6 +184,8 @@ def __init__(self, if self.attention_dp_enable_balance: self.attention_dp_time_out_iters = self.llm_args.attention_dp_config.timeout_iters self.attention_dp_batching_wait_iters = self.llm_args.attention_dp_config.batching_wait_iters + if self.ctx_model_engine is not None: + self.sm_disagg_ctx_sm_percent = self.llm_args.sm_disagg_config.context_sm_percent self.batch_wait_timeout_ms = self.llm_args.batch_wait_timeout_ms self.batch_wait_timeout_iters = self.llm_args.batch_wait_timeout_iters self.batch_wait_max_tokens_ratio = self.llm_args.batch_wait_max_tokens_ratio @@ -201,6 +215,10 @@ def __init__(self, self.responses = {} self.result_wait_queues = {} + self.sm_disagg_request_lock = threading.Lock() + self.ctx_request_cv = threading.Condition(self.sm_disagg_request_lock) + self.gen_request_cv = threading.Condition(self.sm_disagg_request_lock) + # kv cache events self.kv_cache_manager = self.resource_manager.resource_managers.get( ResourceManagerType.KV_CACHE_MANAGER) @@ -234,9 +252,12 @@ def __init__(self, # During warmup, we don't enable the profiler self.is_warmup = True - self.model_engine.warmup(self.resource_manager) - if self.draft_model_engine is not None: - self.draft_model_engine.warmup(self.resource_manager) + for eng in [ + self.model_engine, self.ctx_model_engine, + self.draft_model_engine + ]: + if eng is not None: + eng.warmup(self.resource_manager) self.is_warmup = False self.is_shutdown = False @@ -255,6 +276,7 @@ def __init__(self, max_num_active_requests=self.max_num_active_requests, enable_iter_perf_stats=self.enable_iter_perf_stats, batch_wait_timeout_ms=self.batch_wait_timeout_ms, + is_sm_disagg=ctx_model_engine is not None, ) self.executor_request_queue.set_exclude_last_generation_logits( self.disable_overlap_scheduler, self.dist.pp_size) @@ -275,6 +297,8 @@ def __init__(self, if self.dist.pp_size > 1: self.event_loop = self._executor_loop_pp + elif ctx_model_engine is not None: + self.event_loop = self._executor_loop_sm_disagg else: self.event_loop = self._executor_loop if self.disable_overlap_scheduler else self._executor_loop_overlap if is_trace_enabled("TLLM_TRACE_EXECUTOR_LOOP"): @@ -346,9 +370,12 @@ def is_warmup(self) -> bool: def is_warmup(self, value: bool): self._is_warmup = value # Set warmup flag in model engine to trigger torch compile and avoid moe load balancer statistics update - self.model_engine.is_warmup = value - if self.draft_model_engine is not None: - self.draft_model_engine.is_warmup = value + for eng in [ + self.model_engine, self.ctx_model_engine, + self.draft_model_engine + ]: + if eng is not None: + eng.is_warmup = value def start_worker(self): with self.worker_lock: @@ -443,6 +470,8 @@ def shutdown(self): if manager: 
manager.shutdown() del self.model_engine + if self.ctx_model_engine is not None: + del self.ctx_model_engine if self.draft_model_engine is not None: del self.draft_model_engine if self.virtual_memory_pools is not None: @@ -505,7 +534,14 @@ def should_stop_processing(self): self.executor_request_queue.get_waiting_queue_size() == 0 @contextmanager - def _profiler(self): + def _profiler(self, + model_engine: Optional[ModelEngine] = None, + stream: Optional[torch.cuda.Stream] = None, + phase_name: Optional[str] = None, + enable_profiler: bool = True): + if model_engine is None: + model_engine = self.model_engine + it = -1 enabled = False start_time = None @@ -522,7 +558,8 @@ def _profiler(self): torch_trace_path = os.environ.get(PROFILE_TRACE_ENV_VAR_NAME, None) profile_start_stop = os.environ.get(PROFILE_START_STOP_ENV_VAR_NAME, None) - enable_torch_trace = bool(torch_trace_path and profile_start_stop) + enable_torch_trace = bool(enable_profiler and torch_trace_path + and profile_start_stop) if torch_trace_path and profile_start_stop is None: logger.warning( f"{PROFILE_START_STOP_ENV_VAR_NAME} environment variable " @@ -542,7 +579,7 @@ def _profiler(self): def profile_step(): nonlocal it, enabled, start_time, start_event_1, end_event_1, start_event_2, end_event_2, prev_device_step_time - if it in self.profile_stop_iters and not self.is_warmup: + if it in self.profile_stop_iters and enable_profiler and not self.is_warmup: assert enabled, "Inconsistent CUDA profiling state" if enable_torch_trace: torch_profiler.stop() @@ -554,18 +591,19 @@ def profile_step(): if start_time is not None and self.print_log and self.dist.rank == 0: end_time = time.time() - if it % 2 == 0: - end_event_1.record() - if start_event_2 is not None: - end_event_2.synchronize() - prev_device_step_time = start_event_2.elapsed_time( - end_event_2) - else: - end_event_2.record() - if start_event_1 is not None: - end_event_1.synchronize() - prev_device_step_time = start_event_1.elapsed_time( - end_event_1) + with torch.cuda.stream(stream): + if it % 2 == 0: + end_event_1.record() + if start_event_2 is not None: + end_event_2.synchronize() + prev_device_step_time = start_event_2.elapsed_time( + end_event_2) + else: + end_event_2.record() + if start_event_1 is not None: + end_event_1.synchronize() + prev_device_step_time = start_event_1.elapsed_time( + end_event_1) if prev_device_step_time is None: prev_device_step_time = "N/A" # Handle first iteration @@ -575,7 +613,8 @@ def profile_step(): formatted_timestamp = datetime.datetime.now().strftime( "%Y-%m-%d %H:%M:%S") logger.info( - f"iter = {self.model_engine.iter_counter}, " + (f"phase = {phase_name}, " if phase_name else "") + + f"iter = {model_engine.iter_counter}, " f"global_rank = {self.global_rank}, " f"rank = {self.dist.rank}, " f"currank_total_requests = {self.executor_request_queue.num_fetch_requests_cur_rank}/" @@ -584,11 +623,11 @@ def profile_step(): f"prev_device_step_time = {prev_device_step_time}, " f"timestamp = {formatted_timestamp}, " f"num_scheduled_requests: {self.num_scheduled_requests}, " - f"states = {self.model_engine.iter_states}") + f"states = {model_engine.iter_states}") it += 1 - if it in self.profile_start_iters and not self.is_warmup: + if it in self.profile_start_iters and enable_profiler and not self.is_warmup: assert not enabled, "Inconsistent CUDA profiling state" torch.cuda.cudart().cudaProfilerStart() if enable_torch_trace: @@ -596,14 +635,15 @@ def profile_step(): logger.info(f"Profiling started at iteration {it}.") enabled = True start_time 
= time.time() - if it % 2 == 0: - if start_event_1 is None: - start_event_1 = torch.cuda.Event(enable_timing=True) - start_event_1.record() - else: - if start_event_2 is None: - start_event_2 = torch.cuda.Event(enable_timing=True) - start_event_2.record() + with torch.cuda.stream(stream): + if it % 2 == 0: + if start_event_1 is None: + start_event_1 = torch.cuda.Event(enable_timing=True) + start_event_1.record() + else: + if start_event_2 is None: + start_event_2 = torch.cuda.Event(enable_timing=True) + start_event_2.record() try: yield profile_step @@ -705,8 +745,6 @@ def _update_iter_stats(self, stats, iter_latency_ms, num_completed_requests, stats.cpu_mem_usage = 0 stats.pinned_mem_usage = 0 - stats.iter = self.model_engine.iter_counter - kv_cache_manager = self.resource_manager.resource_managers.get( ResourceManagerType.KV_CACHE_MANAGER) if kv_cache_manager is not None: @@ -888,6 +926,7 @@ def _executor_loop_pp(self): self._update_request_states(scheduled_batch) if self.enable_iter_perf_stats: + iter_stats.iter_counter = self.model_engine.iter_counter iter_stats.inflight_batching_stats.num_ctx_tokens = self.model_engine.iter_states[ 'num_ctx_tokens'] batch_state = BatchStatePP( @@ -1232,6 +1271,7 @@ def _executor_loop(self): self._kv_connector_terminate_requests() if self.enable_iter_perf_stats and sample_state is not None: + iter_stats.iter_counter = self.model_engine.iter_counter iter_stats.inflight_batching_stats.num_ctx_tokens = self.model_engine.iter_states[ 'num_ctx_tokens'] self._process_iter_stats( @@ -1455,6 +1495,7 @@ def _executor_loop_overlap(self): self._process_previous_batch() if self.enable_iter_perf_stats: + iter_stats.iter_counter = self.model_engine.iter_counter iter_stats.inflight_batching_stats.num_ctx_tokens = self.model_engine.iter_states[ 'num_ctx_tokens'] @@ -1470,11 +1511,16 @@ def _executor_loop_overlap(self): self._kv_connector_terminate_requests() - def _executor_loop_sm_disagg_ctx(self): + def _executor_loop_sm_disagg_ctx(self, stream): torch.cuda.set_device(self.device_id) # ensure the context is created, otherwise, some MPI calls will fail. CUASSERT(cudart.cudaSetDevice(self.device_id)) - with self._profiler() as profile_step: + with self._profiler( + model_engine=self.ctx_model_engine, + stream=stream, + phase_name='context', + enable_profiler=self.profile_sm_disagg_ctx_range, + ) as profile_step: sample_state = None iter_start_time = time.time() iter_stats = None @@ -1483,105 +1529,68 @@ def _executor_loop_sm_disagg_ctx(self): if self.enable_iter_perf_stats: iter_start_time = time.time() - scheduled_batch, iter_stats = self._prepare_and_schedule_batch() - if scheduled_batch is None: - break + new_requests = self._fetch_and_activate_new_requests() - self._pause_requests(scheduled_batch.paused_requests) + if self.enable_iter_perf_stats: + iter_stats = self._get_init_iter_stats( + len(new_requests), + self.executor_request_queue. 
+ get_new_active_requests_queue_latency()) - finished_requests = [] + with self.sm_disagg_request_lock: + ctx_requests = get_context_requests(self.active_requests) + if self.is_shutdown and len(ctx_requests) == 0 \ + and self.executor_request_queue.get_waiting_queue_size() == 0: + self.ctx_request_cv.notify() + break - if scheduled_batch.batch_size > 0 or ( - self.enable_attention_dp and self.dist.tp_size > 1): - if self.kv_cache_transceiver: - # For generation requests which have completed KV cache transfer - self._prepare_disagg_gen_transmission_complete( - scheduled_batch) + scheduled_batch, _, _ = self._schedule( + scheduler=self.ctx_scheduler) - # Return the first token to the client - self._handle_first_token_response(scheduled_batch) - self.resource_manager.prepare_resources(scheduled_batch) + logger.debug( + f'has {len(self.active_requests)} active_request, ' + f'scheduled {len(scheduled_batch.context_requests)} context requests' + ) + if scheduled_batch.batch_size == 0 \ + and (len(ctx_requests) > 0 or self.executor_request_queue.get_waiting_queue_size() > 0): + self.gen_request_cv.wait() + continue - self._kv_connector_start_batch(scheduled_batch) + finished_requests = [] if scheduled_batch.batch_size > 0 or ( self.enable_attention_dp and self.dist.tp_size > 1): - # init_disagg_gen_requests must be before drafter loop, otherwise draft requests do not have initialized matchers. - # init_disagg_gen_requests must be before engine forward, where the prev_seq_slot is updated. - if self.guided_decoder is not None: - self.guided_decoder.add_batch(scheduled_batch) - if self.kv_cache_transceiver: - self.guided_decoder.init_disagg_gen_requests() - - if self.drafter is not None and self.use_spec_decode: - if self.guided_decoder is not None: - self.guided_decoder.rollback_rejected_tokens() - with request_context( - is_draft=self.draft_model_engine is not None, - scheduled_requests=scheduled_batch): - self.drafter.prepare_draft_tokens( - scheduled_batch, self.resource_manager) - # Pad draft tokens to the max draft length. This is for CUDA graph compatibility. - self.drafter.pad_draft_tokens_for_cuda_graph( - scheduled_batch) - # add_batch must be called again to restore to target requests with updated draft tokens. 
- if self.guided_decoder is not None: - self.guided_decoder.add_batch(scheduled_batch) - if hasattr(self.drafter, "guided_decoder"): - self.guided_decoder.rollback_draft_tokens() - - batch_outputs = self._forward_step(scheduled_batch) - if self.guided_decoder is not None: - self.guided_decoder.execute(batch_outputs['logits']) - - sample_state = self._sample_async(scheduled_batch, - batch_outputs) - if self.drafter is not None: - self.drafter.run_drafter_post(scheduled_batch, - self.resource_manager, - self.is_warmup) - - self._update_request_states(scheduled_batch) - self._update_requests(sample_state, self.resource_manager) - if self.block_reuse_enabled and not self.kv_cache_manager.is_vswa and self.kv_cache_transceiver: - for req in scheduled_batch.context_requests: - if req.is_context_only_request and ( - req.is_context_finished - or req.is_finished_due_to_length): - block_id = self.kv_cache_manager.store_blocks_for_reuse( - req, True) - self.ctx_in_transmission_requests[ - req.py_request_id] = ( - (req, block_id, - self.ctx_in_transmission_counter)) + self.resource_manager.prepare_resources(scheduled_batch) - if self.kv_cache_transceiver: - ctx_transmission_reqs = self._send_disagg_ctx_cache( - scheduled_batch.context_requests) - # For context only req in transmission, we reset the state since sampler might have changed it - for req in ctx_transmission_reqs: - req.state = LlmRequestState.DISAGG_CONTEXT_TRANS_IN_PROGRESS + with torch.cuda.stream(stream): + batch_outputs = self._forward_step( + scheduled_batch, model_engine=self.ctx_model_engine) + sample_state = self._sample_async( + scheduled_batch, batch_outputs) + # To avoid long sync time in critical section below + sample_state.sampler_event.synchronize() + + with self.sm_disagg_request_lock: + self._update_request_states(scheduled_batch) + self._update_requests(sample_state, + self.resource_manager) + self._handle_canceled_requests() + finished_requests = self._handle_responses() + self.ctx_request_cv.notify() - self._handle_canceled_requests() - finished_requests = self._handle_responses() - attn_metadata = getattr(self.model_engine, 'attn_metadata', - None) + attn_metadata = getattr(self.ctx_model_engine, + 'attn_metadata', None) kv_cache_dtype_byte_size = getattr( - self.model_engine, 'kv_cache_dtype_byte_size', None) + self.ctx_model_engine, 'kv_cache_dtype_byte_size', None) self.resource_manager.update_resources( scheduled_batch, attn_metadata, kv_cache_dtype_byte_size) if self.enable_kv_cache_events: self._add_kv_cache_events() - if self.kv_cache_transceiver and self.ctx_in_transmission_requests: - self._check_kv_transfer_timeout() - self._terminate_disagg_ctx_finished_requests() - - self._kv_connector_terminate_requests() - if self.enable_iter_perf_stats and sample_state is not None: - iter_stats.inflight_batching_stats.num_ctx_tokens = self.model_engine.iter_states[ + iter_stats.iter_counter = self.ctx_model_engine.iter_counter + iter_stats.inflight_batching_stats.num_ctx_tokens = self.ctx_model_engine.iter_states[ 'num_ctx_tokens'] self._process_iter_stats( finished_requests, self.active_requests, @@ -1589,158 +1598,144 @@ def _executor_loop_sm_disagg_ctx(self): iter_stats=iter_stats, iter_start_time=iter_start_time)) - def _executor_loop_sm_disagg_gen_overlap(self): + def _executor_loop_sm_disagg_gen_overlap(self, stream): torch.cuda.set_device(self.device_id) # ensure the context is created, otherwise, some MPI calls will fail. 
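Illustrative sketch (not part of the patch, assumes a CUDA device): the context loop above launches its forward pass and sampling on its own stream and synchronizes the sampler event before re-entering the shared critical section, so the host-side wait happens outside the lock. The same stream-plus-event pattern in isolation:

    import threading
    import torch

    lock = threading.Lock()
    side_stream = torch.cuda.Stream()

    def one_step(x: torch.Tensor) -> torch.Tensor:
        done = torch.cuda.Event()
        with torch.cuda.stream(side_stream):
            y = x @ x              # stand-in for the context-phase forward/sampling work
            done.record()          # recorded on side_stream
        done.synchronize()         # block the host here, before taking the lock
        with lock:                 # critical section now only touches host-side bookkeeping
            return y

    print(one_step(torch.randn(64, 64, device="cuda")).shape)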
CUASSERT(cudart.cudaSetDevice(self.device_id)) - with self._profiler() as profile_step: + with self._profiler( + stream=stream, + phase_name='generation', + enable_profiler=not self.profile_sm_disagg_ctx_range, + ) as profile_step: iter_start_time = time.time() iter_stats = None - can_forward = False if self.benchmark_req_queues_size > 0 and self.kv_cache_transceiver else True while True: profile_step() if self.enable_iter_perf_stats: iter_start_time = time.time() - scheduled_batch, iter_stats = self._prepare_and_schedule_batch() - if scheduled_batch is None: + if self.should_stop_processing: break - # In gen-only benchmarking mode, wait until the number of scheduled generation - # requests reaches the required threshold before starting forward pass, - # to ensure consistent batch sizes for accurate performance measurement. - if not self.is_warmup and not can_forward: - if self.enable_attention_dp: - local_can_forward = self.executor_request_queue.num_fetch_requests + \ - len(scheduled_batch.generation_requests) >= self.benchmark_req_queues_size - all_can_forward = self.dist.tp_allgather( - local_can_forward) - if all(all_can_forward): - can_forward = True - time.sleep(10) - else: - if self.dist.rank == 0: - logger.info( - f"sleep 10 seconds, num_fetched_requests: {self.executor_request_queue.num_fetch_requests}, scheduled_gen_batch: {len(scheduled_batch.generation_requests)}" - ) - time.sleep(10) - continue - else: - if len(scheduled_batch.generation_requests - ) < self.benchmark_req_queues_size: - if self.dist.rank == 0: - logger.info( - f"sleep 10 seconds, scheduled_gen_batch: {len(scheduled_batch.generation_requests)}" - ) - time.sleep(10) - continue - else: - can_forward = True - self._pause_requests(scheduled_batch.paused_requests) + if self.enable_iter_perf_stats: + iter_stats = self._get_init_iter_stats( + num_new_active_requests=0, + new_active_requests_queue_latency_ms=0) - if scheduled_batch.batch_size > 0: - if self.kv_cache_transceiver: - # For generation requests which have completed KV cache transfer - self._prepare_disagg_gen_transmission_complete( - scheduled_batch) - self.resource_manager.prepare_resources(scheduled_batch) + with self.sm_disagg_request_lock: + self._pad_attention_dp_dummy_request() - self._kv_connector_start_batch(scheduled_batch) + gen_requests = get_generation_requests(self.active_requests) + scheduled_batch, _, _ = self._schedule( + active_requests=gen_requests) + + self.num_scheduled_requests = scheduled_batch.batch_size + logger.debug( + f'has {len(self.active_requests)} active_request, ' + f'scheduled {len(scheduled_batch.generation_requests)} generation requests' + ) + + if scheduled_batch.batch_size == 0: + self.ctx_request_cv.wait() + continue + + self._pause_requests(scheduled_batch.paused_requests) if scheduled_batch.batch_size > 0: + self.resource_manager.prepare_resources(scheduled_batch) - # The generation requests that are do not have batch_idx, + # The generation requests that just finished context phase # needs to be in front of the batch due to the assumptions - # made in model_engine.py::_forward_step. This is only important - # for disaggregated serving. For non-disaggregated serving, - # the generation requests always have batch_idx. + # made in model_engine.py::_forward_step. 
scheduled_batch.generation_requests = sorted( # stable sort scheduled_batch.generation_requests, - key=lambda req: int(req.py_batch_idx is not None), + key=lambda req: int(req.max_num_generated_tokens > 0), ) - if self.kv_cache_transceiver: - # Return the first token to the client - self._handle_first_token_response(scheduled_batch) + previous_tensors_device = self.previous_batch and self.previous_batch.sample_state \ + and self.previous_batch.sample_state.device - # init_disagg_gen_requests must be before engine forward, where the prev_seq_slot is updated. - if self.guided_decoder is not None and self.kv_cache_transceiver: - self.guided_decoder.add_batch(scheduled_batch) - self.guided_decoder.init_disagg_gen_requests() + with torch.cuda.stream(stream): + batch_outputs = self._forward_step( + scheduled_batch, previous_tensors_device) + # To avoid long sync time in critical section below + if self.previous_batch is not None: + self.previous_batch.sample_state.sampler_event.synchronize( + ) - previous_tensors = self.previous_batch and self.previous_batch.sample_state - target_inputs = None - draft_outputs = None - # If there are previous draft tokens, we need to update the target requests to accept some draft tokens. - # When there's any accepted tokens, we can't directly use the previous batch's outputs in this iteration for the target model, - # so we'll set the target model's input to None and skip updating the target requests after target model forward. - use_previous_draft_tokens = self.has_previous_draft_tokens - if self.drafter is not None and (self.use_spec_decode or - use_previous_draft_tokens): - target_inputs, draft_outputs, draft_batch = self._handle_speculative_decoding( - scheduled_batch, previous_tensors) + with self.sm_disagg_request_lock: + if self.previous_batch is not None: + self._update_requests( + self.previous_batch.sample_state) - # Use the draft_model's outputs if we've launched the draft model. - # Otherwise, use the previous batch's outputs. - if target_inputs is not None or use_previous_draft_tokens: - previous_tensors_device = target_inputs - else: - previous_tensors_device = self.previous_batch and self.previous_batch.sample_state and self.previous_batch.sample_state.device + with torch.cuda.stream(stream): + sample_state = self._sample_async( + scheduled_batch, batch_outputs) + assert sample_state is not None, "Sampling failed" - batch_outputs = self._forward_step(scheduled_batch, - previous_tensors_device) + with self.sm_disagg_request_lock: + self._update_request_states(scheduled_batch) - if target_inputs is not None: - self._process_draft_results(scheduled_batch, - draft_outputs, draft_batch) - elif self.previous_batch is not None and not use_previous_draft_tokens: - self._update_requests(self.previous_batch.sample_state) + if self.previous_batch is not None: + self._process_previous_batch() - if self.block_reuse_enabled and not self.kv_cache_manager.is_vswa and self.kv_cache_transceiver: - for req in self.previous_batch.sample_state.scheduled_requests.context_requests: - if req.is_context_only_request and ( - req.is_context_finished - or req.is_finished_due_to_length): - block_id = self.kv_cache_manager.store_blocks_for_reuse( - req, True) - self.ctx_in_transmission_requests[ - req.py_request_id] = ( - (req, block_id, - self.ctx_in_transmission_counter)) - - if self.guided_decoder is not None: - # add_batch must be called again to have updated new tokens. 
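Illustrative sketch (not part of the patch): the "stable sort" above relies on sorted() with an int(bool) key acting as a stable partition, so requests with no generated tokens yet (fresh from the context phase) move to the front while each group keeps its relative order. For example:

    reqs = [("a", 3), ("b", 0), ("c", 5), ("d", 0)]   # (request id, max_num_generated_tokens)
    ordered = sorted(reqs, key=lambda r: int(r[1] > 0))
    assert [r[0] for r in ordered] == ["b", "d", "a", "c"]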
- self.guided_decoder.add_batch(scheduled_batch) - self.guided_decoder.execute(batch_outputs['logits']) - - sample_state = self._sample_async(scheduled_batch, - batch_outputs) - assert sample_state is not None, "Sampling failed" - - self._update_request_states(scheduled_batch) - - ctx_transmission_reqs = self._send_disagg_ctx_cache( - scheduled_batch.context_requests - ) if self.kv_cache_transceiver else [] - - if self.previous_batch is not None: - self._process_previous_batch() + self.gen_request_cv.notify() if self.enable_iter_perf_stats: - iter_stats.inflight_batching_stats.num_ctx_tokens = self.model_engine.iter_states[ - 'num_ctx_tokens'] + iter_stats.iter_counter = self.model_engine.iter_counter self.previous_batch = BatchState( sample_state=sample_state, iter_start_time=iter_start_time, - iter_stats=iter_stats, - ctx_transmission_reqs=ctx_transmission_reqs) + iter_stats=iter_stats) - if self.kv_cache_transceiver and self.ctx_in_transmission_requests: - self._check_kv_transfer_timeout() - self._terminate_disagg_ctx_finished_requests() + def _executor_loop_sm_disagg(self): + stream_ctx, stream_gen = self.split_device_green_ctx() - self._kv_connector_terminate_requests() + thread_ctx = threading.Thread(target=self._executor_loop_sm_disagg_ctx, + args=(stream_ctx, ), + daemon=True) + thread_ctx.start() + + self._executor_loop_sm_disagg_gen_overlap(stream_gen) + + thread_ctx.join() + + def split_device_green_ctx(self): + device = torch.device("cuda", self.device_id) + device_properties = torch.cuda.get_device_properties(device) + sm_count = device_properties.multi_processor_count + if device_properties.major >= 9: + sm_min = 8 + sm_align = 8 + else: + sm_min = 4 if device_properties.major == 8 else 2 + sm_align = 2 + + from flashinfer import green_ctx + + def split_device_green_ctx_aligned(sm_s1): + sm_s1 = round(sm_s1 / sm_align) * sm_align + sm_s1 = min(max(sm_s1, sm_min), sm_count - sm_min) + return green_ctx.split_device_green_ctx_by_sm_count(device, [sm_s1]) + + sm_ctx = round(sm_count * self.sm_disagg_ctx_sm_percent) + sm_gen = sm_count - sm_ctx + # Choose the split closer to user-specified percentage when sm_count is not divisible by sm_align + sm_ctx_dist = min(sm_ctx % sm_align, sm_align - (sm_ctx % sm_align)) + sm_gen_dist = min(sm_gen % sm_align, sm_align - (sm_gen % sm_align)) + if sm_gen_dist < sm_ctx_dist: + (stream_gen, + stream_ctx), (res_gen, + res_ctx) = split_device_green_ctx_aligned(sm_gen) + else: + (stream_ctx, + stream_gen), (res_ctx, + res_gen) = split_device_green_ctx_aligned(sm_ctx) + logger.info( + f"Green contexts allocated {res_ctx.sm.smCount} SMs for context phase and {res_gen.sm.smCount} SMs for generation phase." 
+ ) + return stream_ctx, stream_gen def _accept_draft_tokens( self, scheduled_batch: ScheduledRequests, @@ -1901,8 +1896,13 @@ def _respond_if_invalid(request: LlmRequest) -> bool: self._handle_errors(str(e), requests=[request]) return True + if self.ctx_model_engine is not None: + num_active_requests_on_engine = len( + get_context_requests(self.active_requests)) + else: + num_active_requests_on_engine = len(self.active_requests) new_requests_cur_rank = self.executor_request_queue.fetch_new_requests( - self.active_requests) + self.active_requests, num_active_requests_on_engine) self.is_shutdown = self.executor_request_queue.is_shutdown self.expected_num_active_requests = self.executor_request_queue.get_expected_num_active_requests( ) @@ -1996,9 +1996,15 @@ def _waiting_requests(self, context_requests: list[LlmRequest], return waited_context_requests @nvtx_range("_schedule") - def _schedule(self): - scheduler_output = self.scheduler.schedule_request( - self.active_requests, self.inflight_req_ids) + def _schedule(self, + scheduler: Optional[RequestScheduler] = None, + active_requests: Optional[List[LlmRequest]] = None): + if scheduler is None: + scheduler = self.scheduler + if active_requests is None: + active_requests = self.active_requests + scheduler_output = scheduler.schedule_request(active_requests, + self.inflight_req_ids) scheduled_context_requests = scheduler_output.context_requests if self.enable_attention_dp and self.attention_dp_enable_balance: scheduled_context_requests = self._balance_adp_requests( @@ -2232,14 +2238,17 @@ def _check_disagg_gen_cache_transfer_status(self, atLeastNum: int = 0): def _forward_step(self, scheduled_requests, - new_tensors_device: Optional[SampleStateTensors] = None): + new_tensors_device: Optional[SampleStateTensors] = None, + model_engine: Optional[ModelEngine] = None): + if model_engine is None: + model_engine = self.model_engine @nvtx_range( - f"[Executor] _forward_step {self.model_engine.iter_counter + 1}: {len(scheduled_requests.context_requests)} ctx reqs, {len(scheduled_requests.generation_requests)} gen reqs" + f"[Executor] _forward_step {model_engine.iter_counter + 1}: {len(scheduled_requests.context_requests)} ctx reqs, {len(scheduled_requests.generation_requests)} gen reqs" ) def forward(scheduled_requests, resource_manager, new_tensors_device, gather_context_logits, cache_indirection_buffer): - return self.model_engine.forward( + return model_engine.forward( scheduled_requests, resource_manager, new_tensors_device, diff --git a/tensorrt_llm/_torch/pyexecutor/py_executor_creator.py b/tensorrt_llm/_torch/pyexecutor/py_executor_creator.py index 008efc0405c..8dd736cf09f 100644 --- a/tensorrt_llm/_torch/pyexecutor/py_executor_creator.py +++ b/tensorrt_llm/_torch/pyexecutor/py_executor_creator.py @@ -79,6 +79,8 @@ def _bytes_to_gib(bytes: int) -> float: "KV cache", ExecutorMemoryType.MODEL_ENGINE_MAIN: "Model", + ExecutorMemoryType.MODEL_ENGINE_CTX: + "Context model for SM-level disaggregation", ExecutorMemoryType.MODEL_ENGINE_DRAFT: "Draft model for speculative decoding", } @@ -96,6 +98,9 @@ def _bytes_to_gib(bytes: int) -> float: ExecutorMemoryType.INIT_KV_CACHE: "reduce max_num_tokens", ExecutorMemoryType.MODEL_ENGINE_MAIN: + ("reduce max_num_tokens and/or shard the model weights across GPUs by enabling " + "pipeline and/or tensor parallelism"), + ExecutorMemoryType.MODEL_ENGINE_CTX: ("reduce max_num_tokens and/or shard the model weights across GPUs by enabling " "pipeline and/or tensor parallelism"), ExecutorMemoryType.MODEL_ENGINE_DRAFT: 
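Illustrative sketch (not part of the patch): the rounding in split_device_green_ctx above, reduced to plain arithmetic. Green-context partitions must come in multiples of sm_align SMs (8 on compute capability 9.0+ per the code above), so the side whose ideal share is already closer to an aligned count is the one that gets rounded, and the remainder goes to the other phase (the exact remainder partition handed back by flashinfer's green_ctx may differ slightly).

    def plan_sm_split(sm_count: int, ctx_percent: float, sm_align: int = 8, sm_min: int = 8):
        def aligned(n: int) -> int:
            n = round(n / sm_align) * sm_align
            return min(max(n, sm_min), sm_count - sm_min)

        sm_ctx = round(sm_count * ctx_percent)
        sm_gen = sm_count - sm_ctx
        ctx_dist = min(sm_ctx % sm_align, sm_align - sm_ctx % sm_align)
        gen_dist = min(sm_gen % sm_align, sm_align - sm_gen % sm_align)
        if gen_dist < ctx_dist:
            sm_gen = aligned(sm_gen)
            sm_ctx = sm_count - sm_gen
        else:
            sm_ctx = aligned(sm_ctx)
            sm_gen = sm_count - sm_ctx
        return sm_ctx, sm_gen

    # e.g. a 132-SM device with context_sm_percent=0.3 lands on a 40/92 nominal split
    print(plan_sm_split(132, 0.3))   # -> (40, 92)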
@@ -345,6 +350,40 @@ def allocation_scope(current_stage: ExecutorMemoryType, validate_feature_combination(llm_args, model_engine, llm_args.sampler_type) + if llm_args.sm_disagg_config is not None: + with allocation_scope(ExecutorMemoryType.MODEL_ENGINE_CTX, + RestoreMode.PINNED): + ctx_llm_args = copy.copy(llm_args) + ctx_llm_args.cuda_graph_config = None + ctx_model_engine = PyTorchModelEngine( + model_path=checkpoint_dir, + llm_args=ctx_llm_args, + mapping=mapping, + attn_runtime_features=attn_runtime_features, + dist=dist, + spec_config=spec_config, + weight_sharing_model=model_engine.model, + ) + else: + ctx_model_engine = None + + if llm_args.sm_disagg_config is not None: + with allocation_scope(ExecutorMemoryType.MODEL_ENGINE_CTX, + RestoreMode.PINNED): + ctx_backend_config = copy.copy(pytorch_backend_config) + ctx_backend_config.use_cuda_graph = False + ctx_model_engine = PyTorchModelEngine( + model_path=checkpoint_dir, + llm_args=llm_args, + mapping=mapping, + attn_runtime_features=attn_runtime_features, + dist=dist, + spec_config=spec_config, + weight_sharing_model=model_engine.model, + ) + else: + ctx_model_engine = None + if has_draft_model_engine: with allocation_scope(ExecutorMemoryType.MODEL_ENGINE_DRAFT, RestoreMode.PINNED): @@ -447,9 +486,9 @@ def drafting_loop_wrapper(model): "Chunked Prefill for MLA can only be enabled on SM90/SM100/SM103/SM120, " f"disable enable_chunked_context for SM{sm_version}") enable_chunked_context = False - model_engine.attn_runtime_features.chunked_prefill = False - if draft_model_engine is not None: - draft_model_engine.attn_runtime_features.chunked_prefill = False + for eng in [model_engine, ctx_model_engine, draft_model_engine]: + if eng is not None: + eng.attn_runtime_features.chunked_prefill = False if enable_chunked_context: chunk_unit_size = tokens_per_block @@ -642,6 +681,8 @@ def drafting_loop_wrapper(model): max_batch_size=max_batch_size, max_beam_width=max_beam_width, max_num_tokens=max_num_tokens, + ctx_model_engine=ctx_model_engine, + sm_disagg_config=llm_args.sm_disagg_config, peft_cache_config=peft_cache_config, scheduler_config=scheduler_config, cache_transceiver_config=cache_transceiver_config, @@ -670,11 +711,11 @@ def drafting_loop_wrapper(model): max_seq_len = kv_cache_creator._max_seq_len update_sampler_max_seq_len(max_seq_len, sampler) - for eng in [model_engine, draft_model_engine]: + for eng in [model_engine, ctx_model_engine, draft_model_engine]: if eng is None: continue if eng.attn_metadata is not None: - if llm_args.cuda_graph_config is not None: + if eng.cuda_graph_runner.enabled: eng._release_cuda_graphs() eng.attn_metadata = None @@ -699,6 +740,8 @@ def drafting_loop_wrapper(model): max_batch_size=max_batch_size, max_beam_width=max_beam_width, max_num_tokens=max_num_tokens, + ctx_model_engine=ctx_model_engine, + sm_disagg_config=llm_args.sm_disagg_config, peft_cache_config=peft_cache_config, scheduler_config=scheduler_config, cache_transceiver_config=cache_transceiver_config, diff --git a/tensorrt_llm/_torch/pyexecutor/scheduler.py b/tensorrt_llm/_torch/pyexecutor/scheduler.py index c71c4596ed7..e62149b276f 100644 --- a/tensorrt_llm/_torch/pyexecutor/scheduler.py +++ b/tensorrt_llm/_torch/pyexecutor/scheduler.py @@ -7,7 +7,7 @@ from tensorrt_llm.bindings import internal as tb_internal from tensorrt_llm.llmapi.llm_args import CapacitySchedulerPolicy -from .llm_request import LlmRequest, LlmRequestState +from .llm_request import LlmRequest, LlmRequestState, get_context_requests RequestList = list[LlmRequest] @@ -216,3 
+216,28 @@ def schedule_request(self, active_requests: RequestList, list(generation_requests), list(paused_requests), list(fitting_disagg_gen_init_requests), len(fitting_requests)) + + +class SmDisaggCtxScheduler(RequestScheduler): + + def __init__(self, capacity_scheduler: CapacityScheduler, + micro_batch_scheduler: MicroBatchScheduler): + super(SmDisaggCtxScheduler, self).__init__() + self.capacity_scheduler = capacity_scheduler + self.micro_batch_scheduler = micro_batch_scheduler + + def schedule_request(self, active_requests: RequestList, + inflight_request_ids: set[int]) -> SchedulerOutput: + fitting_requests, fitting_disagg_gen_init_requests, paused_requests = self.capacity_scheduler.schedule_request( + active_requests) + + fitting_requests = get_context_requests(fitting_requests) + + context_requests, generation_requests = self.micro_batch_scheduler.schedule( + fitting_requests, inflight_request_ids) + # Convert from binding type RequestVector to list[LlmRequest], + # so Python fields on LlmRequest won't be stripped away + return SchedulerOutput(list(context_requests), + list(generation_requests), list(paused_requests), + list(fitting_disagg_gen_init_requests), + len(fitting_requests)) diff --git a/tensorrt_llm/_torch/virtual_memory.py b/tensorrt_llm/_torch/virtual_memory.py index 3702d732539..4ab9c21a214 100644 --- a/tensorrt_llm/_torch/virtual_memory.py +++ b/tensorrt_llm/_torch/virtual_memory.py @@ -78,6 +78,7 @@ class ExecutorMemoryType(StrEnum): EXTRA_RESOURCES = "executor_extra" KV_CACHE = "kv_cache" MODEL_ENGINE_MAIN = "model" + MODEL_ENGINE_CTX = "ctx_model" MODEL_ENGINE_DRAFT = "draft_model" diff --git a/tensorrt_llm/llmapi/__init__.py b/tensorrt_llm/llmapi/__init__.py index cb868d8d068..9c2459eeda9 100644 --- a/tensorrt_llm/llmapi/__init__.py +++ b/tensorrt_llm/llmapi/__init__.py @@ -14,8 +14,8 @@ MedusaDecodingConfig, MoeConfig, MTPDecodingConfig, NGramDecodingConfig, RocketSparseAttentionConfig, SaveHiddenStatesDecodingConfig, SchedulerConfig, - TorchCompileConfig, TorchLlmArgs, TrtLlmArgs, - UserProvidedDecodingConfig) + SmDisaggConfig, TorchCompileConfig, TorchLlmArgs, + TrtLlmArgs, UserProvidedDecodingConfig) from .llm_utils import (BuildConfig, KvCacheRetentionConfig, QuantAlgo, QuantConfig) from .mm_encoder import MultimodalEncoder @@ -62,6 +62,7 @@ 'AttentionDpConfig', 'LoRARequest', 'SaveHiddenStatesDecodingConfig', + 'SmDisaggConfig', 'RocketSparseAttentionConfig', 'DeepSeekSparseAttentionConfig', ] diff --git a/tensorrt_llm/llmapi/llm_args.py b/tensorrt_llm/llmapi/llm_args.py index 40d967a2cb8..4554189b4c3 100644 --- a/tensorrt_llm/llmapi/llm_args.py +++ b/tensorrt_llm/llmapi/llm_args.py @@ -424,6 +424,29 @@ def from_dict(cls, data: dict): return cls(**data) +class SmDisaggConfig(StrictBaseModel): + """ + Configuration for SM-level disaggregation. + """ + context_sm_percent: float = Field( + default=0.5, + description="Percentage of SMs allocated to context phase.") + context_max_num_tokens: int = Field( + default=0, + description= + "The maximum number of tokens for context phase. If less than or equal to 0, the same value as generation phase is used." + ) + context_max_batch_size: int = Field( + default=0, + description= + "The maximum batch size for context phase. If less than or equal to 0, the same value as generation phase is used." 
+ ) + + @classmethod + def from_dict(cls, data: dict): + return cls(**data) + + class _ParallelConfig(StrictBaseModel): """The model distribution configs for LLM.""" tp_size: int = 1 @@ -2556,6 +2579,11 @@ class TorchLlmArgs(BaseLlmArgs): description="Disable the overlap scheduler.", status="beta") + sm_disagg_config: Optional[SmDisaggConfig] = Field( + default=None, + description="SM-level disaggregation config.", + status="prototype") + moe_config: MoeConfig = Field(default_factory=MoeConfig, description="MoE config.", status="beta") @@ -2914,6 +2942,23 @@ def validate_attention_dp_config(self) -> 'TorchLlmArgs': ) return self + @model_validator(mode='after') + def validate_and_sync_sm_disagg_config(self) -> 'TorchLlmArgs': + """Validate SM-level disaggregation configuration.""" + if self.sm_disagg_config is None: + return self + + config = self.sm_disagg_config + if not 0 < config.context_sm_percent < 1: + raise ValueError( + "sm_disagg_config.context_sm_percent must be in the range (0, 1)" + ) + if config.context_max_num_tokens <= 0: + config.context_max_num_tokens = self.max_num_tokens + if config.context_max_batch_size <= 0: + config.context_max_batch_size = self.max_batch_size + return self + @model_validator(mode='after') def validate_batch_wait_timeout_ms(self) -> 'TorchLlmArgs': """Validate batch wait timeout.""" From 224c8ed9b0d15ceba54d1913772234fa8bf6e75f Mon Sep 17 00:00:00 2001 From: Qiang Xu Date: Thu, 6 Nov 2025 20:56:13 +0000 Subject: [PATCH 3/6] avoid kv cache manager race conditions Signed-off-by: Qiang Xu --- tensorrt_llm/_torch/pyexecutor/py_executor.py | 44 ++++++++++--------- 1 file changed, 24 insertions(+), 20 deletions(-) diff --git a/tensorrt_llm/_torch/pyexecutor/py_executor.py b/tensorrt_llm/_torch/pyexecutor/py_executor.py index 443d4e0c2af..49c44b8a447 100644 --- a/tensorrt_llm/_torch/pyexecutor/py_executor.py +++ b/tensorrt_llm/_torch/pyexecutor/py_executor.py @@ -215,9 +215,9 @@ def __init__(self, self.responses = {} self.result_wait_queues = {} - self.sm_disagg_request_lock = threading.Lock() - self.ctx_request_cv = threading.Condition(self.sm_disagg_request_lock) - self.gen_request_cv = threading.Condition(self.sm_disagg_request_lock) + self.sm_disagg_lock = threading.Lock() + self.ctx_request_cv = threading.Condition(self.sm_disagg_lock) + self.gen_request_cv = threading.Condition(self.sm_disagg_lock) # kv cache events self.kv_cache_manager = self.resource_manager.resource_managers.get( @@ -1537,7 +1537,7 @@ def _executor_loop_sm_disagg_ctx(self, stream): self.executor_request_queue. 
get_new_active_requests_queue_latency()) - with self.sm_disagg_request_lock: + with self.sm_disagg_lock: ctx_requests = get_context_requests(self.active_requests) if self.is_shutdown and len(ctx_requests) == 0 \ and self.executor_request_queue.get_waiting_queue_size() == 0: @@ -1560,7 +1560,8 @@ def _executor_loop_sm_disagg_ctx(self, stream): if scheduled_batch.batch_size > 0 or ( self.enable_attention_dp and self.dist.tp_size > 1): - self.resource_manager.prepare_resources(scheduled_batch) + with self.sm_disagg_lock: + self.resource_manager.prepare_resources(scheduled_batch) with torch.cuda.stream(stream): batch_outputs = self._forward_step( @@ -1570,23 +1571,25 @@ def _executor_loop_sm_disagg_ctx(self, stream): # To avoid long sync time in critical section below sample_state.sampler_event.synchronize() - with self.sm_disagg_request_lock: + with self.sm_disagg_lock: self._update_request_states(scheduled_batch) self._update_requests(sample_state, self.resource_manager) self._handle_canceled_requests() finished_requests = self._handle_responses() - self.ctx_request_cv.notify() - attn_metadata = getattr(self.ctx_model_engine, - 'attn_metadata', None) - kv_cache_dtype_byte_size = getattr( - self.ctx_model_engine, 'kv_cache_dtype_byte_size', None) - self.resource_manager.update_resources( - scheduled_batch, attn_metadata, - kv_cache_dtype_byte_size) - if self.enable_kv_cache_events: - self._add_kv_cache_events() + attn_metadata = getattr(self.ctx_model_engine, + 'attn_metadata', None) + kv_cache_dtype_byte_size = getattr( + self.ctx_model_engine, 'kv_cache_dtype_byte_size', + None) + self.resource_manager.update_resources( + scheduled_batch, attn_metadata, + kv_cache_dtype_byte_size) + if self.enable_kv_cache_events: + self._add_kv_cache_events() + + self.ctx_request_cv.notify() if self.enable_iter_perf_stats and sample_state is not None: iter_stats.iter_counter = self.ctx_model_engine.iter_counter @@ -1622,7 +1625,7 @@ def _executor_loop_sm_disagg_gen_overlap(self, stream): num_new_active_requests=0, new_active_requests_queue_latency_ms=0) - with self.sm_disagg_request_lock: + with self.sm_disagg_lock: self._pad_attention_dp_dummy_request() gen_requests = get_generation_requests(self.active_requests) @@ -1642,7 +1645,8 @@ def _executor_loop_sm_disagg_gen_overlap(self, stream): self._pause_requests(scheduled_batch.paused_requests) if scheduled_batch.batch_size > 0: - self.resource_manager.prepare_resources(scheduled_batch) + with self.sm_disagg_lock: + self.resource_manager.prepare_resources(scheduled_batch) # The generation requests that just finished context phase # needs to be in front of the batch due to the assumptions @@ -1663,7 +1667,7 @@ def _executor_loop_sm_disagg_gen_overlap(self, stream): self.previous_batch.sample_state.sampler_event.synchronize( ) - with self.sm_disagg_request_lock: + with self.sm_disagg_lock: if self.previous_batch is not None: self._update_requests( self.previous_batch.sample_state) @@ -1673,7 +1677,7 @@ def _executor_loop_sm_disagg_gen_overlap(self, stream): scheduled_batch, batch_outputs) assert sample_state is not None, "Sampling failed" - with self.sm_disagg_request_lock: + with self.sm_disagg_lock: self._update_request_states(scheduled_batch) if self.previous_batch is not None: From 651d9c62f072c1db8051b169ded20b79563318d8 Mon Sep 17 00:00:00 2001 From: Qiang Xu Date: Fri, 7 Nov 2025 23:39:11 +0000 Subject: [PATCH 4/6] add feature combo guards Signed-off-by: Qiang Xu --- .../_torch/pyexecutor/py_executor_creator.py | 29 ++++++++----------- 1 file changed, 
12 insertions(+), 17 deletions(-) diff --git a/tensorrt_llm/_torch/pyexecutor/py_executor_creator.py b/tensorrt_llm/_torch/pyexecutor/py_executor_creator.py index 8dd736cf09f..9c22508f082 100644 --- a/tensorrt_llm/_torch/pyexecutor/py_executor_creator.py +++ b/tensorrt_llm/_torch/pyexecutor/py_executor_creator.py @@ -351,6 +351,18 @@ def allocation_scope(current_stage: ExecutorMemoryType, validate_feature_combination(llm_args, model_engine, llm_args.sampler_type) if llm_args.sm_disagg_config is not None: + if llm_args.cache_transceiver_config is not None: + raise ValueError( + "SM-level disaggregation is not compatible with disaggregated serving." + ) + if llm_args.parallel_config.world_size > 1: + raise NotImplementedError( + "SM-level disaggregation is not supported with parallelism.") + if scheduler_config.capacity_scheduler_policy != CapacitySchedulerPolicy.GUARANTEED_NO_EVICT: + raise NotImplementedError( + "SM-level disaggregation is only supported with guaranteed no evict scheduler policy." + ) + with allocation_scope(ExecutorMemoryType.MODEL_ENGINE_CTX, RestoreMode.PINNED): ctx_llm_args = copy.copy(llm_args) @@ -367,23 +379,6 @@ def allocation_scope(current_stage: ExecutorMemoryType, else: ctx_model_engine = None - if llm_args.sm_disagg_config is not None: - with allocation_scope(ExecutorMemoryType.MODEL_ENGINE_CTX, - RestoreMode.PINNED): - ctx_backend_config = copy.copy(pytorch_backend_config) - ctx_backend_config.use_cuda_graph = False - ctx_model_engine = PyTorchModelEngine( - model_path=checkpoint_dir, - llm_args=llm_args, - mapping=mapping, - attn_runtime_features=attn_runtime_features, - dist=dist, - spec_config=spec_config, - weight_sharing_model=model_engine.model, - ) - else: - ctx_model_engine = None - if has_draft_model_engine: with allocation_scope(ExecutorMemoryType.MODEL_ENGINE_DRAFT, RestoreMode.PINNED): From b84b03a6208bc23504cbf3300015f9f08d1c3197 Mon Sep 17 00:00:00 2001 From: Qiang Xu Date: Wed, 12 Nov 2025 01:44:15 +0000 Subject: [PATCH 5/6] fix reviewer commits Signed-off-by: Qiang Xu --- tensorrt_llm/_torch/pyexecutor/_util.py | 10 ++- .../pyexecutor/executor_request_queue.py | 27 +----- tensorrt_llm/_torch/pyexecutor/green_ctx.py | 82 +++++++++++++++++++ tensorrt_llm/_torch/pyexecutor/py_executor.py | 53 +++--------- .../_torch/pyexecutor/py_executor_creator.py | 1 + tensorrt_llm/_torch/pyexecutor/scheduler.py | 40 +++------ 6 files changed, 117 insertions(+), 96 deletions(-) create mode 100644 tensorrt_llm/_torch/pyexecutor/green_ctx.py diff --git a/tensorrt_llm/_torch/pyexecutor/_util.py b/tensorrt_llm/_torch/pyexecutor/_util.py index a79e6401703..eabd9b22e8e 100644 --- a/tensorrt_llm/_torch/pyexecutor/_util.py +++ b/tensorrt_llm/_torch/pyexecutor/_util.py @@ -29,7 +29,7 @@ from .guided_decoder import GuidedDecoder from .kv_cache_connector import KvCacheConnectorManager from .kv_cache_transceiver import AttentionTypeCpp, create_kv_cache_transceiver -from .llm_request import ExecutorResponse +from .llm_request import ExecutorResponse, LlmRequestState from .mamba_cache_manager import MambaHybridCacheManager from .model_engine import PyTorchModelEngine from .py_executor import PyExecutor @@ -38,7 +38,7 @@ from .sampler import (EarlyStopSampler, EarlyStopWithMMResult, TorchSampler, TRTLLMSampler) from .scheduler import (BindCapacityScheduler, BindMicroBatchScheduler, - SimpleScheduler, SmDisaggCtxScheduler) + SimpleScheduler) from .seq_slot_manager import SeqSlotManager GB = 1 << 30 @@ -801,8 +801,10 @@ def create_py_executor_instance( 
two_step_lookahead=mapping.has_pp()) mb_scheduler = BindMicroBatchScheduler( sm_disagg_config.context_max_batch_size, - sm_disagg_config.context_max_num_tokens, ctx_chunk_config) - ctx_scheduler = SmDisaggCtxScheduler(capacity_scheduler, mb_scheduler) + sm_disagg_config.context_max_num_tokens, + ctx_chunk_config, + no_schedule_after_state=LlmRequestState.GENERATION_IN_PROGRESS) + ctx_scheduler = SimpleScheduler(capacity_scheduler, mb_scheduler) else: ctx_scheduler = None diff --git a/tensorrt_llm/_torch/pyexecutor/executor_request_queue.py b/tensorrt_llm/_torch/pyexecutor/executor_request_queue.py index 5479bdbb858..2d51561de05 100644 --- a/tensorrt_llm/_torch/pyexecutor/executor_request_queue.py +++ b/tensorrt_llm/_torch/pyexecutor/executor_request_queue.py @@ -342,31 +342,12 @@ def fetch_new_requests( self, activate_requests: List[LlmRequest], num_active_requests_on_engine: int) -> List[LlmRequest]: - if self.is_sm_disagg: - return self._fetch_new_requests_sm_disagg( - len(activate_requests), num_active_requests_on_engine) - elif self.enable_attention_dp: + if self.enable_attention_dp: return self._fetch_new_requests_attention_dp(activate_requests) else: - return self._fetch_new_requests_attention_tp(len(activate_requests)) - - def _fetch_new_requests_sm_disagg( - self, num_active_requests: int, - num_active_requests_on_engine: int) -> List[LlmRequest]: - """Handle SM-level disaggregation request fetching.""" - total_max_num_active_requests = (self.max_num_active_requests + - num_active_requests - - num_active_requests_on_engine) - - # fetch and process requests into waiting queue - new_requests = self._fetch_and_process_requests( - num_active_requests_on_engine, - total_max_num_active_requests, - enable_attention_dp=False) - - # Merge requests and add to active list - merged_requests = self._merge_requests(new_requests) - return merged_requests + num_active_requests = num_active_requests_on_engine if self.is_sm_disagg else len( + activate_requests) + return self._fetch_new_requests_attention_tp(num_active_requests) def _fetch_new_requests_attention_tp( self, num_active_requests: int) -> List[LlmRequest]: diff --git a/tensorrt_llm/_torch/pyexecutor/green_ctx.py b/tensorrt_llm/_torch/pyexecutor/green_ctx.py new file mode 100644 index 00000000000..68a074d42cb --- /dev/null +++ b/tensorrt_llm/_torch/pyexecutor/green_ctx.py @@ -0,0 +1,82 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ + +import torch +from cuda.bindings import driver + +from tensorrt_llm.runtime.generation import CUASSERT + + +def green_ctx_create_streams(res_list, device): + streams = [] + for res in res_list: + desc = CUASSERT(driver.cuDevResourceGenerateDesc([res], 1))[0] + green_ctx = CUASSERT( + driver.cuGreenCtxCreate( + desc, device, driver.CUgreenCtxCreate_flags.CU_GREEN_CTX_DEFAULT_STREAM + ) + )[0] + stream = CUASSERT( + driver.cuGreenCtxStreamCreate( + green_ctx, driver.CUstream_flags.CU_STREAM_NON_BLOCKING, 0 + ) + )[0] + stream = torch.cuda.get_stream_from_external(stream, device) + streams.append(stream) + return streams + + +def green_ctx_split_percent(sm_percent: float, device_id: int = 0): + device = CUASSERT(driver.cuDeviceGet(device_id))[0] + + res = CUASSERT( + driver.cuDeviceGetDevResource(device, driver.CUdevResourceType.CU_DEV_RESOURCE_TYPE_SM) + )[0] + sm_count = res.sm.smCount + + major = CUASSERT( + driver.cuDeviceGetAttribute( + driver.CUdevice_attribute.CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MAJOR, device + ) + )[0] + if major >= 9: + sm_min = 8 + sm_align = 8 + else: + sm_min = 4 if major == 8 else 2 + sm_align = 2 + + def green_ctx_split_aligned(sm_g1): + sm_g1 = round(sm_g1 / sm_align) * sm_align + sm_g1 = min(max(sm_g1, sm_min), sm_count - sm_min) + result = CUASSERT( + driver.cuDevSmResourceSplitByCount( + 1, # nbGroups + res, + 0, # useFlags + sm_g1, + ) + ) + res_split = (result[0][0], result[2]) + streams = green_ctx_create_streams(res_split, device) + return streams, res_split + + sm_g1 = round(sm_count * sm_percent) + sm_g2 = sm_count - sm_g1 + # Choose the split closer to sm_percent when sm_count is not divisible by sm_align + sm_g1_dist = min(sm_g1 % sm_align, sm_align - (sm_g1 % sm_align)) + sm_g2_dist = min(sm_g2 % sm_align, sm_align - (sm_g2 % sm_align)) + if sm_g1_dist <= sm_g2_dist: + (stream_g1, stream_g2), (res_g1, res_g2) = green_ctx_split_aligned(sm_g1) + else: + (stream_g2, stream_g1), (res_g2, res_g1) = green_ctx_split_aligned(sm_g2) + return (stream_g1, stream_g2), (res_g1, res_g2) diff --git a/tensorrt_llm/_torch/pyexecutor/py_executor.py b/tensorrt_llm/_torch/pyexecutor/py_executor.py index 49c44b8a447..3f8e2641b72 100644 --- a/tensorrt_llm/_torch/pyexecutor/py_executor.py +++ b/tensorrt_llm/_torch/pyexecutor/py_executor.py @@ -42,6 +42,7 @@ from ..speculative.mtp import SampleStateTensorsMTP from ..speculative.speculation_gate import SpeculationGate from .executor_request_queue import ExecutorRequestQueue, RequestQueueItem +from .green_ctx import green_ctx_split_percent from .guided_decoder import GuidedDecoder from .handle_additional_outputs import HandleAdditionalOutputs from .handle_logits import HandleLogits @@ -215,9 +216,10 @@ def __init__(self, self.responses = {} self.result_wait_queues = {} - self.sm_disagg_lock = threading.Lock() - self.ctx_request_cv = threading.Condition(self.sm_disagg_lock) - self.gen_request_cv = threading.Condition(self.sm_disagg_lock) + if self.ctx_model_engine is not None: + self.sm_disagg_lock = threading.Lock() + self.ctx_request_cv = threading.Condition(self.sm_disagg_lock) + self.gen_request_cv = threading.Condition(self.sm_disagg_lock) # kv cache events self.kv_cache_manager = self.resource_manager.resource_managers.get( @@ -229,6 +231,9 @@ def __init__(self, self.max_input_len = max_input_len # _executor_loop private data self.max_num_active_requests = model_engine.get_max_num_sequences() + if self.ctx_model_engine is not None: + self.max_num_active_requests += ctx_model_engine.get_max_num_sequences( + ) 
self.active_requests: List[LlmRequest] = [] self.expected_num_active_requests = 0 self.ctx_in_transmission_requests = dict() @@ -1694,7 +1699,11 @@ def _executor_loop_sm_disagg_gen_overlap(self, stream): iter_stats=iter_stats) def _executor_loop_sm_disagg(self): - stream_ctx, stream_gen = self.split_device_green_ctx() + (stream_ctx, stream_gen), (res_ctx, res_gen) = green_ctx_split_percent( + self.sm_disagg_ctx_sm_percent, self.device_id) + logger.info( + f"Green contexts allocated {res_ctx.sm.smCount} SMs for context phase and {res_gen.sm.smCount} SMs for generation phase." + ) thread_ctx = threading.Thread(target=self._executor_loop_sm_disagg_ctx, args=(stream_ctx, ), @@ -1705,42 +1714,6 @@ def _executor_loop_sm_disagg(self): thread_ctx.join() - def split_device_green_ctx(self): - device = torch.device("cuda", self.device_id) - device_properties = torch.cuda.get_device_properties(device) - sm_count = device_properties.multi_processor_count - if device_properties.major >= 9: - sm_min = 8 - sm_align = 8 - else: - sm_min = 4 if device_properties.major == 8 else 2 - sm_align = 2 - - from flashinfer import green_ctx - - def split_device_green_ctx_aligned(sm_s1): - sm_s1 = round(sm_s1 / sm_align) * sm_align - sm_s1 = min(max(sm_s1, sm_min), sm_count - sm_min) - return green_ctx.split_device_green_ctx_by_sm_count(device, [sm_s1]) - - sm_ctx = round(sm_count * self.sm_disagg_ctx_sm_percent) - sm_gen = sm_count - sm_ctx - # Choose the split closer to user-specified percentage when sm_count is not divisible by sm_align - sm_ctx_dist = min(sm_ctx % sm_align, sm_align - (sm_ctx % sm_align)) - sm_gen_dist = min(sm_gen % sm_align, sm_align - (sm_gen % sm_align)) - if sm_gen_dist < sm_ctx_dist: - (stream_gen, - stream_ctx), (res_gen, - res_ctx) = split_device_green_ctx_aligned(sm_gen) - else: - (stream_ctx, - stream_gen), (res_ctx, - res_gen) = split_device_green_ctx_aligned(sm_ctx) - logger.info( - f"Green contexts allocated {res_ctx.sm.smCount} SMs for context phase and {res_gen.sm.smCount} SMs for generation phase." - ) - return stream_ctx, stream_gen - def _accept_draft_tokens( self, scheduled_batch: ScheduledRequests, target_outputs: SampleStateTensors, diff --git a/tensorrt_llm/_torch/pyexecutor/py_executor_creator.py b/tensorrt_llm/_torch/pyexecutor/py_executor_creator.py index 9c22508f082..5019dd26296 100644 --- a/tensorrt_llm/_torch/pyexecutor/py_executor_creator.py +++ b/tensorrt_llm/_torch/pyexecutor/py_executor_creator.py @@ -374,6 +374,7 @@ def allocation_scope(current_stage: ExecutorMemoryType, attn_runtime_features=attn_runtime_features, dist=dist, spec_config=spec_config, + is_sm_disagg_ctx_phase=True, weight_sharing_model=model_engine.model, ) else: diff --git a/tensorrt_llm/_torch/pyexecutor/scheduler.py b/tensorrt_llm/_torch/pyexecutor/scheduler.py index e62149b276f..bfc9eed8fed 100644 --- a/tensorrt_llm/_torch/pyexecutor/scheduler.py +++ b/tensorrt_llm/_torch/pyexecutor/scheduler.py @@ -7,7 +7,7 @@ from tensorrt_llm.bindings import internal as tb_internal from tensorrt_llm.llmapi.llm_args import CapacitySchedulerPolicy -from .llm_request import LlmRequest, LlmRequestState, get_context_requests +from .llm_request import LlmRequest, LlmRequestState RequestList = list[LlmRequest] @@ -79,6 +79,9 @@ def __init__( scheduler_policy: CapacitySchedulerPolicy = CapacitySchedulerPolicy. GUARANTEED_NO_EVICT, two_step_lookahead: bool = False, + no_schedule_until_state: LlmRequestState = LlmRequestState.CONTEXT_INIT, + no_schedule_after_state: LlmRequestState = LlmRequestState. 
+ GENERATION_COMPLETE, ): super(BindCapacityScheduler, self).__init__() self.kv_cache_manager = kv_cache_manager @@ -89,8 +92,8 @@ def __init__( capacity_scheduler_policy=scheduler_policy._to_pybind(), has_kv_cache_manager=kv_cache_manager is not None, two_step_lookahead=two_step_lookahead, - no_schedule_until_state=LlmRequestState.CONTEXT_INIT, - no_schedule_after_state=LlmRequestState.GENERATION_COMPLETE) + no_schedule_until_state=no_schedule_until_state, + no_schedule_after_state=no_schedule_after_state) def schedule_request( self, active_requests: RequestList @@ -175,6 +178,9 @@ def __init__( max_batch_size: int, max_num_tokens: int = None, ctx_chunk_config: Optional[Tuple[StrEnum, int]] = None, + no_schedule_until_state: LlmRequestState = LlmRequestState.CONTEXT_INIT, + no_schedule_after_state: LlmRequestState = LlmRequestState. + GENERATION_COMPLETE, ) -> None: super(BindMicroBatchScheduler, self).__init__() self.max_batch_size = max_batch_size @@ -186,7 +192,8 @@ def __init__( ctx_chunk_config[0]._to_pybind(), ctx_chunk_config[1]) self.impl = tb_internal.algorithms.MicroBatchScheduler( - ctx_chunk_config_cpp, max_num_tokens) + ctx_chunk_config_cpp, max_num_tokens, no_schedule_until_state, + no_schedule_after_state) def schedule( self, active_requests: RequestList, inflight_request_ids: set[int] @@ -216,28 +223,3 @@ def schedule_request(self, active_requests: RequestList, list(generation_requests), list(paused_requests), list(fitting_disagg_gen_init_requests), len(fitting_requests)) - - -class SmDisaggCtxScheduler(RequestScheduler): - - def __init__(self, capacity_scheduler: CapacityScheduler, - micro_batch_scheduler: MicroBatchScheduler): - super(SmDisaggCtxScheduler, self).__init__() - self.capacity_scheduler = capacity_scheduler - self.micro_batch_scheduler = micro_batch_scheduler - - def schedule_request(self, active_requests: RequestList, - inflight_request_ids: set[int]) -> SchedulerOutput: - fitting_requests, fitting_disagg_gen_init_requests, paused_requests = self.capacity_scheduler.schedule_request( - active_requests) - - fitting_requests = get_context_requests(fitting_requests) - - context_requests, generation_requests = self.micro_batch_scheduler.schedule( - fitting_requests, inflight_request_ids) - # Convert from binding type RequestVector to list[LlmRequest], - # so Python fields on LlmRequest won't be stripped away - return SchedulerOutput(list(context_requests), - list(generation_requests), list(paused_requests), - list(fitting_disagg_gen_init_requests), - len(fitting_requests)) From cd20a27a985e93c7c21fbe2528ec0fa33455a984 Mon Sep 17 00:00:00 2001 From: Qiang Xu Date: Fri, 14 Nov 2025 18:33:41 +0000 Subject: [PATCH 6/6] refactor sm disagg executor loops into main executor loops Signed-off-by: Qiang Xu --- tensorrt_llm/_torch/pyexecutor/py_executor.py | 528 ++++++++---------- 1 file changed, 233 insertions(+), 295 deletions(-) diff --git a/tensorrt_llm/_torch/pyexecutor/py_executor.py b/tensorrt_llm/_torch/pyexecutor/py_executor.py index 3f8e2641b72..874714ac836 100644 --- a/tensorrt_llm/_torch/pyexecutor/py_executor.py +++ b/tensorrt_llm/_torch/pyexecutor/py_executor.py @@ -1,5 +1,6 @@ import dataclasses import datetime +import enum import functools import os import pickle # nosec B403 @@ -113,6 +114,12 @@ class BatchStatePP(BatchState): scheduled_ctx_reqs: list[LlmRequest] = None +class ExecutorLoopPhase(enum.StrEnum): + IFB = "IFB" + SM_DISAGG_CTX = "SM_DISAGG_CTX" + SM_DISAGG_GEN = "SM_DISAGG_GEN" + + class PyExecutor: def __init__(self, @@ -165,8 +172,6 @@ def 
__init__(self, self.ctx_scheduler = ctx_scheduler self.ctx_model_engine = ctx_model_engine - self.profile_sm_disagg_ctx_range = bool( - os.environ.get(PROFILE_SM_DISAGG_CTX_RANGE_ENV_VAR_NAME, None)) # enqueue and _fetch_new_requests used data self.active = True @@ -216,10 +221,9 @@ def __init__(self, self.responses = {} self.result_wait_queues = {} - if self.ctx_model_engine is not None: - self.sm_disagg_lock = threading.Lock() - self.ctx_request_cv = threading.Condition(self.sm_disagg_lock) - self.gen_request_cv = threading.Condition(self.sm_disagg_lock) + self.executor_lock = threading.Lock() + self.ctx_request_cv = threading.Condition(self.executor_lock) + self.gen_request_cv = threading.Condition(self.executor_lock) # kv cache events self.kv_cache_manager = self.resource_manager.resource_managers.get( @@ -242,7 +246,6 @@ def __init__(self, 0) self.previous_batch: Optional[BatchState] = None self.has_previous_draft_tokens = False - self.num_scheduled_requests: int = 0 self.benchmark_req_queues_size = int( os.environ.get("TLLM_BENCHMARK_REQ_QUEUES_SIZE", 0)) self._disable_mpi = mpi_disabled() @@ -542,8 +545,7 @@ def should_stop_processing(self): def _profiler(self, model_engine: Optional[ModelEngine] = None, stream: Optional[torch.cuda.Stream] = None, - phase_name: Optional[str] = None, - enable_profiler: bool = True): + phase: Optional[ExecutorLoopPhase] = ExecutorLoopPhase.IFB): if model_engine is None: model_engine = self.model_engine @@ -563,6 +565,11 @@ def _profiler(self, torch_trace_path = os.environ.get(PROFILE_TRACE_ENV_VAR_NAME, None) profile_start_stop = os.environ.get(PROFILE_START_STOP_ENV_VAR_NAME, None) + profile_sm_disagg_ctx_range = bool( + os.environ.get(PROFILE_SM_DISAGG_CTX_RANGE_ENV_VAR_NAME, None)) + enable_profiler = phase == ExecutorLoopPhase.IFB or ( + (phase == ExecutorLoopPhase.SM_DISAGG_GEN) + ^ profile_sm_disagg_ctx_range) enable_torch_trace = bool(enable_profiler and torch_trace_path and profile_start_stop) if torch_trace_path and profile_start_stop is None: @@ -582,7 +589,7 @@ def _profiler(self, record_shapes=True, with_modules=True) - def profile_step(): + def profile_step(iter_stats: Optional[IterationStats] = None): nonlocal it, enabled, start_time, start_event_1, end_event_1, start_event_2, end_event_2, prev_device_step_time if it in self.profile_stop_iters and enable_profiler and not self.is_warmup: assert enabled, "Inconsistent CUDA profiling state" @@ -618,7 +625,7 @@ def profile_step(): formatted_timestamp = datetime.datetime.now().strftime( "%Y-%m-%d %H:%M:%S") logger.info( - (f"phase = {phase_name}, " if phase_name else "") + + f"phase = {phase}, " f"iter = {model_engine.iter_counter}, " f"global_rank = {self.global_rank}, " f"rank = {self.dist.rank}, " @@ -627,7 +634,7 @@ def profile_step(): f"host_step_time = {host_step_time}ms, " f"prev_device_step_time = {prev_device_step_time}, " f"timestamp = {formatted_timestamp}, " - f"num_scheduled_requests: {self.num_scheduled_requests}, " + f"num_scheduled_requests: {iter_stats.inflight_batching_stats.num_scheduled_requests if iter_stats else 'N/A'}, " f"states = {model_engine.iter_states}") it += 1 @@ -832,7 +839,7 @@ def _executor_loop_pp(self): iter_start_time = time.time() iter_stats = None while True: - profile_step() + profile_step(iter_stats) if self.enable_iter_perf_stats: iter_start_time = time.time() new_requests = self._fetch_and_activate_new_requests() @@ -866,8 +873,6 @@ def _executor_loop_pp(self): ) self._check_disagg_ctx_cache_transfer_status(1) - self.num_scheduled_requests = 
scheduled_batch.batch_size - logger.debug( f'has {len(self.active_requests)} active_request, ' f'scheduled {len(scheduled_batch.context_requests)} context requests and ' @@ -1127,13 +1132,66 @@ def _prepare_and_schedule_batch(self): ) self._check_disagg_ctx_cache_transfer_status(1) - self.num_scheduled_requests = scheduled_batch.batch_size logger.debug( f'has {len(self.active_requests)} active_request, ' f'scheduled {len(scheduled_batch.context_requests)} context requests and ' f'{len(scheduled_batch.generation_requests)} generation requests') return scheduled_batch, iter_stats + def _prepare_and_schedule_batch_sm_disagg_ctx(self): + new_requests = self._fetch_and_activate_new_requests() + + iter_stats = None + if self.enable_iter_perf_stats: + iter_stats = self._get_init_iter_stats( + len(new_requests), + self.executor_request_queue. + get_new_active_requests_queue_latency()) + + with self.executor_lock: + ctx_requests = get_context_requests(self.active_requests) + if self.is_shutdown and len(ctx_requests) == 0 \ + and self.executor_request_queue.get_waiting_queue_size() == 0: + self.ctx_request_cv.notify() + return None, None + + scheduled_batch, _, _ = self._schedule(scheduler=self.ctx_scheduler) + + logger.debug( + f'has {len(self.active_requests)} active_request, ' + f'scheduled {len(scheduled_batch.context_requests)} context requests' + ) + + if scheduled_batch.batch_size == 0 \ + and (len(ctx_requests) > 0 or self.executor_request_queue.get_waiting_queue_size() > 0): + self.gen_request_cv.wait() + + return scheduled_batch, iter_stats + + def _prepare_and_schedule_batch_sm_disagg_gen(self): + if self.should_stop_processing: + return None, None + + iter_stats = None + if self.enable_iter_perf_stats: + iter_stats = self._get_init_iter_stats( + num_new_active_requests=0, + new_active_requests_queue_latency_ms=0) + + with self.executor_lock: + gen_requests = get_generation_requests(self.active_requests) + scheduled_batch, _, _ = self._schedule(active_requests=gen_requests) + + logger.debug( + f'has {len(self.active_requests)} active_request, ' + f'scheduled {len(scheduled_batch.generation_requests)} generation requests' + ) + + if scheduled_batch.batch_size == 0: + self.ctx_request_cv.wait() + + return scheduled_batch, iter_stats + def _kv_connector_start_batch(self, scheduled_batch): if self.kv_connector_manager: self.kv_connector_manager.take_scheduled_requests_pending_load( @@ -1160,45 +1218,60 @@ def _kv_connector_wait_for_save(self): self.kv_connector_manager.worker.wait_for_save( torch.cuda.current_stream()) - def _executor_loop(self): + def _executor_loop(self, + stream: Optional[torch.cuda.Stream] = None, + phase: ExecutorLoopPhase = ExecutorLoopPhase.IFB): torch.cuda.set_device(self.device_id) # ensure the context is created, otherwise, some MPI calls will fail. 
CUASSERT(cudart.cudaSetDevice(self.device_id)) - with self._profiler() as profile_step: + + model_engine = self.ctx_model_engine if phase == ExecutorLoopPhase.SM_DISAGG_CTX else self.model_engine + + with self._profiler(model_engine, stream, phase) as profile_step: sample_state = None iter_start_time = time.time() iter_stats = None while True: - profile_step() + profile_step(iter_stats) if self.enable_iter_perf_stats: iter_start_time = time.time() - scheduled_batch, iter_stats = self._prepare_and_schedule_batch() + if phase == ExecutorLoopPhase.SM_DISAGG_CTX: + scheduled_batch, iter_stats = self._prepare_and_schedule_batch_sm_disagg_ctx( + ) + elif phase == ExecutorLoopPhase.SM_DISAGG_GEN: + scheduled_batch, iter_stats = self._prepare_and_schedule_batch_sm_disagg_gen( + ) + else: + scheduled_batch, iter_stats = self._prepare_and_schedule_batch( + ) + self._handle_control_request() if scheduled_batch is None: break - self._pause_requests(scheduled_batch.paused_requests) + with self.executor_lock: + self._pause_requests(scheduled_batch.paused_requests) - finished_requests = [] + finished_requests = [] - can_queue = self._can_queue(scheduled_batch) - if can_queue: - if self.kv_cache_transceiver: - # For generation requests which have completed KV cache transfer - self._prepare_disagg_gen_transmission_complete( - scheduled_batch) + can_queue = self._can_queue(scheduled_batch) + if can_queue: + if self.kv_cache_transceiver: + # For generation requests which have completed KV cache transfer + self._prepare_disagg_gen_transmission_complete( + scheduled_batch) - # Return the first token to the client - self._handle_first_token_response(scheduled_batch) - self.resource_manager.prepare_resources(scheduled_batch) + # Return the first token to the client + self._handle_first_token_response(scheduled_batch) + self.resource_manager.prepare_resources(scheduled_batch) - self._kv_connector_start_batch(scheduled_batch) + self._kv_connector_start_batch(scheduled_batch) - # if using a kv connector, we need to call can_queue again since scheduled_batch might have changed - if self.kv_connector_manager: - can_queue = self._can_queue(scheduled_batch) + # if using a kv connector, we need to call can_queue again since scheduled_batch might have changed + if self.kv_connector_manager: + can_queue = self._can_queue(scheduled_batch) if can_queue: # init_disagg_gen_requests must be before drafter loop, otherwise draft requests do not have initialized matchers. 
@@ -1225,49 +1298,61 @@ def _executor_loop(self): if hasattr(self.drafter, "guided_decoder"): self.guided_decoder.rollback_draft_tokens() - batch_outputs = self._forward_step(scheduled_batch) - if self.guided_decoder is not None: - self.guided_decoder.execute(batch_outputs['logits']) + with torch.cuda.stream(stream): + batch_outputs = self._forward_step( + scheduled_batch, model_engine=model_engine) + if self.guided_decoder is not None: + self.guided_decoder.execute(batch_outputs['logits']) - sample_state = self._sample_async(scheduled_batch, - batch_outputs) - if self.drafter is not None: - self.drafter.run_drafter_post(scheduled_batch, - self.resource_manager, - self.is_warmup) - - self._update_request_states(scheduled_batch) - self._update_requests(sample_state, self.resource_manager) - if self.block_reuse_enabled and not self.kv_cache_manager.is_vswa and self.kv_cache_transceiver: - for req in scheduled_batch.context_requests: - if req.is_context_only_request and ( - req.is_context_finished - or req.is_finished_due_to_length): - block_id = self.kv_cache_manager.store_blocks_for_reuse( - req, True) - self.ctx_in_transmission_requests[ - req.py_request_id] = ( - (req, block_id, - self.ctx_in_transmission_counter)) + sample_state = self._sample_async( + scheduled_batch, batch_outputs) + if self.drafter is not None: + self.drafter.run_drafter_post( + scheduled_batch, self.resource_manager, + self.is_warmup) + # To avoid long sync time in critical section below + if phase == ExecutorLoopPhase.SM_DISAGG_CTX or phase == ExecutorLoopPhase.SM_DISAGG_GEN: + sample_state.sampler_event.synchronize() - if self.kv_cache_transceiver: - ctx_transmission_reqs = self._send_disagg_ctx_cache( - scheduled_batch.context_requests) - # For context only req in transmission, we reset the state since sampler might have changed it - for req in ctx_transmission_reqs: - req.state = LlmRequestState.DISAGG_CONTEXT_TRANS_IN_PROGRESS - - self._handle_canceled_requests() - finished_requests = self._handle_responses() - attn_metadata = getattr(self.model_engine, 'attn_metadata', - None) - kv_cache_dtype_byte_size = getattr( - self.model_engine, 'kv_cache_dtype_byte_size', None) - self.resource_manager.update_resources( - scheduled_batch, attn_metadata, - kv_cache_dtype_byte_size) - if self.enable_kv_cache_events: - self._add_kv_cache_events() + with self.executor_lock: + self._update_request_states(scheduled_batch) + self._update_requests(sample_state, + self.resource_manager) + if self.block_reuse_enabled and not self.kv_cache_manager.is_vswa and self.kv_cache_transceiver: + for req in scheduled_batch.context_requests: + if req.is_context_only_request and ( + req.is_context_finished + or req.is_finished_due_to_length): + block_id = self.kv_cache_manager.store_blocks_for_reuse( + req, True) + self.ctx_in_transmission_requests[ + req.py_request_id] = ( + (req, block_id, + self.ctx_in_transmission_counter)) + + if self.kv_cache_transceiver: + ctx_transmission_reqs = self._send_disagg_ctx_cache( + scheduled_batch.context_requests) + # For context only req in transmission, we reset the state since sampler might have changed it + for req in ctx_transmission_reqs: + req.state = LlmRequestState.DISAGG_CONTEXT_TRANS_IN_PROGRESS + + self._handle_canceled_requests() + finished_requests = self._handle_responses() + attn_metadata = getattr(model_engine, 'attn_metadata', + None) + kv_cache_dtype_byte_size = getattr( + model_engine, 'kv_cache_dtype_byte_size', None) + self.resource_manager.update_resources( + scheduled_batch, 
attn_metadata, + kv_cache_dtype_byte_size) + if self.enable_kv_cache_events: + self._add_kv_cache_events() + + if phase == ExecutorLoopPhase.SM_DISAGG_CTX: + self.ctx_request_cv.notify() + elif phase == ExecutorLoopPhase.SM_DISAGG_GEN: + self.gen_request_cv.notify() if self.kv_cache_transceiver and self.ctx_in_transmission_requests: self._check_kv_transfer_timeout() @@ -1276,8 +1361,8 @@ def _executor_loop(self): self._kv_connector_terminate_requests() if self.enable_iter_perf_stats and sample_state is not None: - iter_stats.iter_counter = self.model_engine.iter_counter - iter_stats.inflight_batching_stats.num_ctx_tokens = self.model_engine.iter_states[ + iter_stats.iter_counter = model_engine.iter_counter + iter_stats.inflight_batching_stats.num_ctx_tokens = model_engine.iter_states[ 'num_ctx_tokens'] self._process_iter_stats( finished_requests, self.active_requests, @@ -1351,22 +1436,35 @@ def control_action(self): self.control_action_done.set() self.control_request_barrier.clear() - def _executor_loop_overlap(self): + def _executor_loop_overlap( + self, + stream: Optional[torch.cuda.Stream] = None, + phase: ExecutorLoopPhase = ExecutorLoopPhase.IFB): + assert phase != ExecutorLoopPhase.SM_DISAGG_CTX, ( + "SM disagg context phase only supports non-overlapping executor loop." + ) torch.cuda.set_device(self.device_id) # ensure the context is created, otherwise, some MPI calls will fail. CUASSERT(cudart.cudaSetDevice(self.device_id)) - with self._profiler() as profile_step: + + with self._profiler(self.model_engine, stream, phase) as profile_step: iter_start_time = time.time() iter_stats = None target_inputs = None previous_tensors_device = None can_forward = False if self.benchmark_req_queues_size > 0 and self.kv_cache_transceiver else True while True: - profile_step() + profile_step(iter_stats) if self.enable_iter_perf_stats: iter_start_time = time.time() - scheduled_batch, iter_stats = self._prepare_and_schedule_batch() + if phase == ExecutorLoopPhase.SM_DISAGG_GEN: + scheduled_batch, iter_stats = self._prepare_and_schedule_batch_sm_disagg_gen( + ) + else: + scheduled_batch, iter_stats = self._prepare_and_schedule_batch( + ) + self._handle_control_request() if scheduled_batch is None: @@ -1402,32 +1500,36 @@ def _executor_loop_overlap(self): else: can_forward = True - self._pause_requests(scheduled_batch.paused_requests) + with self.executor_lock: + self._pause_requests(scheduled_batch.paused_requests) - can_queue = self._can_queue(scheduled_batch) - if can_queue: - if self.kv_cache_transceiver: - # For generation requests which have completed KV cache transfer - self._prepare_disagg_gen_transmission_complete( - scheduled_batch) - self.resource_manager.prepare_resources(scheduled_batch) + can_queue = self._can_queue(scheduled_batch) + if can_queue: + if self.kv_cache_transceiver: + # For generation requests which have completed KV cache transfer + self._prepare_disagg_gen_transmission_complete( + scheduled_batch) + self.resource_manager.prepare_resources(scheduled_batch) - self._kv_connector_start_batch(scheduled_batch) + self._kv_connector_start_batch(scheduled_batch) - # if using a kv connector, we need to call can_queue again since scheduled_batch might have changed - if self.kv_connector_manager: - can_queue = self._can_queue(scheduled_batch) + # if using a kv connector, we need to call can_queue again since scheduled_batch might have changed + if self.kv_connector_manager: + can_queue = self._can_queue(scheduled_batch) if can_queue: - # The generation requests that are do not have 
batch_idx, + # The generation requests that are do not have batch_idx (for disaggregated serving), + # or those just finished context phase (for SM-level disaggregation), # needs to be in front of the batch due to the assumptions - # made in model_engine.py::_forward_step. This is only important - # for disaggregated serving. For non-disaggregated serving, + # made in model_engine.py::_forward_step. For non-disaggregated serving, # the generation requests always have batch_idx. scheduled_batch.generation_requests = sorted( # stable sort scheduled_batch.generation_requests, - key=lambda req: int(req.py_batch_idx is not None), + key=lambda req: + int(req.py_batch_idx is not None or phase == + ExecutorLoopPhase.SM_DISAGG_GEN and req. + max_num_generated_tokens > 0), ) if self.kv_cache_transceiver: @@ -1459,11 +1561,18 @@ def _executor_loop_overlap(self): else: previous_tensors_device = self.previous_batch and self.previous_batch.sample_state and self.previous_batch.sample_state.device - batch_outputs = self._forward_step(scheduled_batch, - previous_tensors_device) + with torch.cuda.stream(stream): + batch_outputs = self._forward_step( + scheduled_batch, previous_tensors_device) + # To avoid long sync time in critical section below + if phase == ExecutorLoopPhase.SM_DISAGG_GEN and self.previous_batch is not None: + self.previous_batch.sample_state.sampler_event.synchronize( + ) if self.previous_batch is not None: - self._update_requests(self.previous_batch.sample_state) + with self.executor_lock: + self._update_requests( + self.previous_batch.sample_state) if self.block_reuse_enabled and not self.kv_cache_manager.is_vswa and self.kv_cache_transceiver: for req in self.previous_batch.sample_state.scheduled_requests.context_requests: @@ -1486,18 +1595,23 @@ def _executor_loop_overlap(self): self.guided_decoder.add_batch(scheduled_batch) self.guided_decoder.execute(batch_outputs['logits']) - sample_state = self._sample_async(scheduled_batch, - batch_outputs) - assert sample_state is not None, "Sampling failed" + with torch.cuda.stream(stream): + sample_state = self._sample_async( + scheduled_batch, batch_outputs) + assert sample_state is not None, "Sampling failed" - self._update_request_states(scheduled_batch) + with self.executor_lock: + self._update_request_states(scheduled_batch) - ctx_transmission_reqs = self._send_disagg_ctx_cache( - scheduled_batch.context_requests - ) if self.kv_cache_transceiver else [] + ctx_transmission_reqs = self._send_disagg_ctx_cache( + scheduled_batch.context_requests + ) if self.kv_cache_transceiver else [] - if self.previous_batch is not None: - self._process_previous_batch() + if self.previous_batch is not None: + self._process_previous_batch() + + if phase == ExecutorLoopPhase.SM_DISAGG_GEN: + self.gen_request_cv.notify() if self.enable_iter_perf_stats: iter_stats.iter_counter = self.model_engine.iter_counter @@ -1516,188 +1630,6 @@ def _executor_loop_overlap(self): self._kv_connector_terminate_requests() - def _executor_loop_sm_disagg_ctx(self, stream): - torch.cuda.set_device(self.device_id) - # ensure the context is created, otherwise, some MPI calls will fail. 
- CUASSERT(cudart.cudaSetDevice(self.device_id)) - with self._profiler( - model_engine=self.ctx_model_engine, - stream=stream, - phase_name='context', - enable_profiler=self.profile_sm_disagg_ctx_range, - ) as profile_step: - sample_state = None - iter_start_time = time.time() - iter_stats = None - while True: - profile_step() - if self.enable_iter_perf_stats: - iter_start_time = time.time() - - new_requests = self._fetch_and_activate_new_requests() - - if self.enable_iter_perf_stats: - iter_stats = self._get_init_iter_stats( - len(new_requests), - self.executor_request_queue. - get_new_active_requests_queue_latency()) - - with self.sm_disagg_lock: - ctx_requests = get_context_requests(self.active_requests) - if self.is_shutdown and len(ctx_requests) == 0 \ - and self.executor_request_queue.get_waiting_queue_size() == 0: - self.ctx_request_cv.notify() - break - - scheduled_batch, _, _ = self._schedule( - scheduler=self.ctx_scheduler) - - logger.debug( - f'has {len(self.active_requests)} active_request, ' - f'scheduled {len(scheduled_batch.context_requests)} context requests' - ) - if scheduled_batch.batch_size == 0 \ - and (len(ctx_requests) > 0 or self.executor_request_queue.get_waiting_queue_size() > 0): - self.gen_request_cv.wait() - continue - - finished_requests = [] - - if scheduled_batch.batch_size > 0 or ( - self.enable_attention_dp and self.dist.tp_size > 1): - with self.sm_disagg_lock: - self.resource_manager.prepare_resources(scheduled_batch) - - with torch.cuda.stream(stream): - batch_outputs = self._forward_step( - scheduled_batch, model_engine=self.ctx_model_engine) - sample_state = self._sample_async( - scheduled_batch, batch_outputs) - # To avoid long sync time in critical section below - sample_state.sampler_event.synchronize() - - with self.sm_disagg_lock: - self._update_request_states(scheduled_batch) - self._update_requests(sample_state, - self.resource_manager) - self._handle_canceled_requests() - finished_requests = self._handle_responses() - - attn_metadata = getattr(self.ctx_model_engine, - 'attn_metadata', None) - kv_cache_dtype_byte_size = getattr( - self.ctx_model_engine, 'kv_cache_dtype_byte_size', - None) - self.resource_manager.update_resources( - scheduled_batch, attn_metadata, - kv_cache_dtype_byte_size) - if self.enable_kv_cache_events: - self._add_kv_cache_events() - - self.ctx_request_cv.notify() - - if self.enable_iter_perf_stats and sample_state is not None: - iter_stats.iter_counter = self.ctx_model_engine.iter_counter - iter_stats.inflight_batching_stats.num_ctx_tokens = self.ctx_model_engine.iter_states[ - 'num_ctx_tokens'] - self._process_iter_stats( - finished_requests, self.active_requests, - BatchState(sample_state=sample_state, - iter_stats=iter_stats, - iter_start_time=iter_start_time)) - - def _executor_loop_sm_disagg_gen_overlap(self, stream): - torch.cuda.set_device(self.device_id) - # ensure the context is created, otherwise, some MPI calls will fail. 
- CUASSERT(cudart.cudaSetDevice(self.device_id)) - with self._profiler( - stream=stream, - phase_name='generation', - enable_profiler=not self.profile_sm_disagg_ctx_range, - ) as profile_step: - iter_start_time = time.time() - iter_stats = None - while True: - profile_step() - if self.enable_iter_perf_stats: - iter_start_time = time.time() - - if self.should_stop_processing: - break - - if self.enable_iter_perf_stats: - iter_stats = self._get_init_iter_stats( - num_new_active_requests=0, - new_active_requests_queue_latency_ms=0) - - with self.sm_disagg_lock: - self._pad_attention_dp_dummy_request() - - gen_requests = get_generation_requests(self.active_requests) - scheduled_batch, _, _ = self._schedule( - active_requests=gen_requests) - - self.num_scheduled_requests = scheduled_batch.batch_size - logger.debug( - f'has {len(self.active_requests)} active_request, ' - f'scheduled {len(scheduled_batch.generation_requests)} generation requests' - ) - - if scheduled_batch.batch_size == 0: - self.ctx_request_cv.wait() - continue - - self._pause_requests(scheduled_batch.paused_requests) - - if scheduled_batch.batch_size > 0: - with self.sm_disagg_lock: - self.resource_manager.prepare_resources(scheduled_batch) - - # The generation requests that just finished context phase - # needs to be in front of the batch due to the assumptions - # made in model_engine.py::_forward_step. - scheduled_batch.generation_requests = sorted( # stable sort - scheduled_batch.generation_requests, - key=lambda req: int(req.max_num_generated_tokens > 0), - ) - - previous_tensors_device = self.previous_batch and self.previous_batch.sample_state \ - and self.previous_batch.sample_state.device - - with torch.cuda.stream(stream): - batch_outputs = self._forward_step( - scheduled_batch, previous_tensors_device) - # To avoid long sync time in critical section below - if self.previous_batch is not None: - self.previous_batch.sample_state.sampler_event.synchronize( - ) - - with self.sm_disagg_lock: - if self.previous_batch is not None: - self._update_requests( - self.previous_batch.sample_state) - - with torch.cuda.stream(stream): - sample_state = self._sample_async( - scheduled_batch, batch_outputs) - assert sample_state is not None, "Sampling failed" - - with self.sm_disagg_lock: - self._update_request_states(scheduled_batch) - - if self.previous_batch is not None: - self._process_previous_batch() - - self.gen_request_cv.notify() - - if self.enable_iter_perf_stats: - iter_stats.iter_counter = self.model_engine.iter_counter - - self.previous_batch = BatchState( - sample_state=sample_state, - iter_start_time=iter_start_time, - iter_stats=iter_stats) - def _executor_loop_sm_disagg(self): (stream_ctx, stream_gen), (res_ctx, res_gen) = green_ctx_split_percent( self.sm_disagg_ctx_sm_percent, self.device_id) @@ -1705,12 +1637,18 @@ def _executor_loop_sm_disagg(self): f"Green contexts allocated {res_ctx.sm.smCount} SMs for context phase and {res_gen.sm.smCount} SMs for generation phase." 
) - thread_ctx = threading.Thread(target=self._executor_loop_sm_disagg_ctx, - args=(stream_ctx, ), + # Context phase only supports non-overlapping executor loop + thread_ctx = threading.Thread(target=self._executor_loop, + args=(stream_ctx, + ExecutorLoopPhase.SM_DISAGG_CTX), daemon=True) thread_ctx.start() - self._executor_loop_sm_disagg_gen_overlap(stream_gen) + if self.disable_overlap_scheduler: + self._executor_loop(stream_gen, ExecutorLoopPhase.SM_DISAGG_GEN) + else: + self._executor_loop_overlap(stream_gen, + ExecutorLoopPhase.SM_DISAGG_GEN) thread_ctx.join()