Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 10 additions & 3 deletions verl/workers/rollout/trtllm_rollout/trtllm_async_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,7 @@ def get_server_address(self):
async def launch_server(self):
from tensorrt_llm import AsyncLLM
from tensorrt_llm.llmapi import CapacitySchedulerPolicy, CudaGraphConfig, KvCacheConfig, SchedulerConfig
from tensorrt_llm.llmapi.llm_args import ExecutorMemoryType, SleepConfig
from tensorrt_llm.serve import OpenAIServer

assert self.config.pipeline_model_parallel_size == 1, "pipeline_model_parallel_size > 1 is not supported yet"
Expand Down Expand Up @@ -164,7 +165,14 @@ async def launch_server(self):
"placement_groups": self.pgs,
"placement_bundle_indices": self.bundle_indices,
"per_worker_gpu_share": per_worker_gpu_share,
"enable_sleep": self.config.enable_sleep_mode,
"sleep_config": SleepConfig(
restore_modes={
ExecutorMemoryType.MODEL_WEIGHTS_MAIN: "NONE",
ExecutorMemoryType.KV_CACHE: "NONE",
}
)
if self.config.enable_sleep_mode
else None,
"allreduce_strategy": "NCCL",
"sampler_type": "TRTLLMSampler",
**engine_kwargs,
Expand Down Expand Up @@ -348,8 +356,8 @@ def get_pgs_and_bundle_indices(self) -> tuple[list[PlacementGroup], list[list[in
local_bundle_index = self.world_size * self.replica_rank

while local_bundle_index >= self.resource_pool.pgs[start_pg_index].bundle_count:
start_pg_index += 1
local_bundle_index -= self.resource_pool.pgs[start_pg_index].bundle_count
start_pg_index += 1
assert (
start_pg_index < len(self.resource_pool.pgs)
and local_bundle_index < self.resource_pool.pgs[start_pg_index].bundle_count
Expand Down Expand Up @@ -386,7 +394,6 @@ def get_pgs_and_bundle_indices(self) -> tuple[list[PlacementGroup], list[list[in
return pgs, bundle_indices

async def launch_servers(self):
assert self.nnodes == 1, "TRTLLMReplica doesn't support multiple nodes for single replica yet."
assert self.resource_pool.pgs is not None, "placement groups are not initialized"

pgs, bundle_indices = self.get_pgs_and_bundle_indices()
Expand Down
3 changes: 3 additions & 0 deletions verl/workers/rollout/trtllm_rollout/trtllm_rollout.py
Original file line number Diff line number Diff line change
Expand Up @@ -260,6 +260,7 @@ async def update_weights(self, weights: dict[str, str]):


class ServerAdapter(BaseRollout):
# TODO: stop hard-coding this list of weight tags.
_WEIGHTS_TAGS = [
"sampler",
"drafter",
Expand All @@ -268,7 +269,9 @@ class ServerAdapter(BaseRollout):
"model_extra",
"executor_extra",
"model",
"model_weights",
"draft_model",
"draft_model_weights",
]

@staticmethod
Expand Down
Loading