diff --git a/verl/workers/rollout/trtllm_rollout/trtllm_async_server.py b/verl/workers/rollout/trtllm_rollout/trtllm_async_server.py index 5b835dc2093..759ccd61151 100644 --- a/verl/workers/rollout/trtllm_rollout/trtllm_async_server.py +++ b/verl/workers/rollout/trtllm_rollout/trtllm_async_server.py @@ -117,6 +117,7 @@ def get_server_address(self): async def launch_server(self): from tensorrt_llm import AsyncLLM from tensorrt_llm.llmapi import CapacitySchedulerPolicy, CudaGraphConfig, KvCacheConfig, SchedulerConfig + from tensorrt_llm.llmapi.llm_args import ExecutorMemoryType, SleepConfig from tensorrt_llm.serve import OpenAIServer assert self.config.pipeline_model_parallel_size == 1, "pipeline_model_parallel_size > 1 is not supported yet" @@ -164,7 +165,14 @@ async def launch_server(self): "placement_groups": self.pgs, "placement_bundle_indices": self.bundle_indices, "per_worker_gpu_share": per_worker_gpu_share, - "enable_sleep": self.config.enable_sleep_mode, + "sleep_config": SleepConfig( + restore_modes={ + ExecutorMemoryType.MODEL_WEIGHTS_MAIN: "NONE", + ExecutorMemoryType.KV_CACHE: "NONE", + } + ) + if self.config.enable_sleep_mode + else None, "allreduce_strategy": "NCCL", "sampler_type": "TRTLLMSampler", **engine_kwargs, @@ -348,8 +356,8 @@ def get_pgs_and_bundle_indices(self) -> tuple[list[PlacementGroup], list[list[in local_bundle_index = self.world_size * self.replica_rank while local_bundle_index >= self.resource_pool.pgs[start_pg_index].bundle_count: - start_pg_index += 1 local_bundle_index -= self.resource_pool.pgs[start_pg_index].bundle_count + start_pg_index += 1 assert ( start_pg_index < len(self.resource_pool.pgs) and local_bundle_index < self.resource_pool.pgs[start_pg_index].bundle_count @@ -386,7 +394,6 @@ def get_pgs_and_bundle_indices(self) -> tuple[list[PlacementGroup], list[list[in return pgs, bundle_indices async def launch_servers(self): - assert self.nnodes == 1, "TRTLLMReplica doesn't support multiple nodes for single replica yet." assert self.resource_pool.pgs is not None, "placement groups are not initialized" pgs, bundle_indices = self.get_pgs_and_bundle_indices() diff --git a/verl/workers/rollout/trtllm_rollout/trtllm_rollout.py b/verl/workers/rollout/trtllm_rollout/trtllm_rollout.py index 606cb4b019f..a0c4ff29445 100644 --- a/verl/workers/rollout/trtllm_rollout/trtllm_rollout.py +++ b/verl/workers/rollout/trtllm_rollout/trtllm_rollout.py @@ -260,6 +260,7 @@ async def update_weights(self, weights: dict[str, str]): class ServerAdapter(BaseRollout): + # TODO: change to non hard-coded _WEIGHTS_TAGS = [ "sampler", "drafter", @@ -268,7 +269,9 @@ class ServerAdapter(BaseRollout): "model_extra", "executor_extra", "model", + "model_weights", "draft_model", + "draft_model_weights", ] @staticmethod