diff --git a/recipe/one_step_off_policy/ckpt_engine_worker.py b/recipe/one_step_off_policy/ckpt_engine_worker.py
new file mode 100644
index 00000000000..27af5e35529
--- /dev/null
+++ b/recipe/one_step_off_policy/ckpt_engine_worker.py
@@ -0,0 +1,140 @@
+# Copyright 2025 Bytedance Ltd. and/or its affiliates
+# Copyright 2025 Meituan Ltd. and/or its affiliates
+# Copyright 2025 Huawei Ltd. and/or its affiliates
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+import os
+import time
+
+import httpx
+from checkpoint_engine.ps import ParameterServer, request_inference_to_update
+
+from verl.single_controller.base import Worker
+from verl.single_controller.base.decorator import Dispatch, register
+from verl.utils.device import (
+    get_device_name,
+)
+
+logger = logging.getLogger(__file__)
+logger.setLevel(os.getenv("VERL_LOGGING_LEVEL", "WARN"))
+
+device_name = get_device_name()
+
+
+class CkptEngineWorker(Worker):
+    def __init__(self, rank_offset, ps_world_size, inference_parallel_size, rollout_name):
+        super().__init__()
+        rank = self.rank + rank_offset
+        self.ps_rank = rank
+        self.ps_rank_offset = rank_offset
+        self.ps_world_size = ps_world_size
+        self.inference_parallel_size = inference_parallel_size
+        self.rollout_name = rollout_name
+        self.ps = ParameterServer(rank=rank, world_size=ps_world_size)
+        self.index = 0
+
+    def _init_process_group(self):
+        os.environ["HCCL_NPU_SOCKET_PORT_RANGE"] = "61020"
+        self.ps.init_process_group(device_index=0, master_port=60010)
+        del os.environ["HCCL_NPU_SOCKET_PORT_RANGE"]
+
+    def check_vllm_ready(self, uds: str | None = None):
+        # Only the first rank of each inference group polls the server.
+        if self.ps_rank % self.inference_parallel_size != 0:
+            return
+        retry_num = 0
+        transport = None
+        if uds is not None:
+            transport = httpx.HTTPTransport(uds=uds)
+        with httpx.Client(transport=transport) as client:
+            while True:
+                try:
+                    response = client.get(f"{self.endpoint}/health", timeout=10)
+                    response.raise_for_status()
+                    break
+                except (httpx.ConnectError, httpx.HTTPStatusError) as e:
+                    retry_num += 1
+                    logger.warning(f"failed to check vllm readiness (retry {retry_num}): {e}")
+                    time.sleep(5)
+
+    def check_sglang_ready(self, uds: str | None = None):
+        # Only the first rank of each inference group polls the server.
+        if self.ps_rank % self.inference_parallel_size != 0:
+            return
+        retry_num = 0
+        transport = None
+        if uds is not None:
+            transport = httpx.HTTPTransport(uds=uds)
+        with httpx.Client(transport=transport) as client:
+            while True:
+                try:
+                    response = client.get(f"{self.endpoint}/ping", timeout=10)
+                    response.raise_for_status()
+                    break
+                except (httpx.ConnectError, httpx.HTTPStatusError) as e:
+                    if retry_num % 10 == 0:
+                        logger.warning(f"failed to check sglang readiness (retry {retry_num}): {e}")
+                    retry_num += 1
+                    time.sleep(0.1)
+
+    @register(dispatch_mode=Dispatch.ONE_TO_ALL, blocking=False)
+    def set_server_addresses(self, server_addresses: list[str]):
+        # TODO: support multiple API servers
+        self.endpoint = f"http://{server_addresses[0]}"
+        if self.rollout_name == "sglang":
+            self.check_sglang_ready()
+        elif self.rollout_name == "vllm":
+            self.check_vllm_ready()
+
+    @register(dispatch_mode=Dispatch.ONE_TO_ALL, blocking=False)
+    def sync_rollout_weights_by_ckpt_engine(self):
+        rank = self.rank
+        src = rank // self.inference_parallel_size * self.inference_parallel_size
+
+        def vllm_req_func(socket_paths: list[tuple[str, str]]) -> None:
+            if rank == src:
+                request_inference_to_update(
+                    url=f"{self.endpoint}/collective_rpc",
+                    socket_paths=dict(socket_paths),
+                )
+
+        def sglang_req_func(socket_paths: list[tuple[str, str]]) -> None:
+            if rank == src:
+                with httpx.Client(transport=httpx.HTTPTransport()) as client:
+                    resp = client.post(
+                        f"{self.endpoint}/update_weights_from_ipc",
+                        json={
+                            "zmq_handles": dict(socket_paths),
+                            "flush_cache": True,
+                            "weight_version": None,
+                        },
+                        timeout=300.0,
+                    )
+                    resp.raise_for_status()
+
+        if self.rollout_name == "sglang":
+            req_func = sglang_req_func
+        elif self.rollout_name == "vllm":
+            req_func = vllm_req_func
+        else:
+            raise ValueError(f"unsupported rollout backend: {self.rollout_name}")
+
+        # Three-phase sync: register the checkpoint, exchange shard metadata,
+        # then receive the weights and trigger the inference-server update.
+        self._init_process_group()
+        checkpoint_name = f"sync_{self.index}"
+        self.ps.register_checkpoint(checkpoint_name=checkpoint_name)
+        self.ps.gather_metas(checkpoint_name)
+        self.ps.update(checkpoint_name, req_func, ranks=list(range(self.ps_rank_offset, self.ps_world_size)))
+        self.index += 1
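Both readiness checks gate on the same arithmetic: a rank polls the server only if it leads its tensor-parallel inference group. A minimal stand-alone sketch of that selection (the helper name is hypothetical, not part of the patch):

```python
# Hypothetical helper mirroring the guard in check_vllm_ready/check_sglang_ready:
# rank r leads its inference group iff r is a multiple of the group size.
def is_group_leader(ps_rank: int, inference_parallel_size: int) -> bool:
    return ps_rank % inference_parallel_size == 0

# With tensor_model_parallel_size=4 and 8 ckpt-engine ranks, only 0 and 4 poll.
assert [r for r in range(8) if is_group_leader(r, 4)] == [0, 4]
```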
diff --git a/recipe/one_step_off_policy/config/one_step_off_ppo_trainer.yaml b/recipe/one_step_off_policy/config/one_step_off_ppo_trainer.yaml
index 4c4deb485e1..301624485b5 100644
--- a/recipe/one_step_off_policy/config/one_step_off_ppo_trainer.yaml
+++ b/recipe/one_step_off_policy/config/one_step_off_ppo_trainer.yaml
@@ -20,9 +20,12 @@ actor_rollout_ref:
     free_cache_engine: False
     # Must be enabled! Otherwise, log_probs cannot be calculated.
     calculate_log_probs: True
+    engine_kwargs:
+      vllm:
+        worker_extension_cls: checkpoint_engine.worker.VllmColocateWorkerExtension
 # Only then will the use of log probs be correct.
 # And it can be used in conjunction with other rollout_correction algorithms.
 
 algorithm:
   rollout_correction:
-    bypass_mode: True
\ No newline at end of file
+    bypass_mode: True
diff --git a/recipe/one_step_off_policy/fsdp_workers.py b/recipe/one_step_off_policy/fsdp_workers.py
index 0aa9dbbe004..f0e9d21a71c 100644
--- a/recipe/one_step_off_policy/fsdp_workers.py
+++ b/recipe/one_step_off_policy/fsdp_workers.py
@@ -18,11 +18,11 @@
 
 import torch
 import torch.distributed
+from checkpoint_engine.ps import ParameterServer
 from omegaconf import DictConfig
 from ray.util.collective import collective
 from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
 
-from recipe.one_step_off_policy.distributed_util import vllm_stateless_init_process_group
 from verl.single_controller.base.decorator import Dispatch, register
 from verl.utils.device import (
     get_device_name,
@@ -53,17 +53,6 @@ class DetachSync(AsyncActorRolloutRefWorker):
     def _get_actor_params(self):
         pass
 
-    @register(dispatch_mode=Dispatch.ONE_TO_ALL, blocking=False)
-    def create_weight_sync_group(self, master_address, master_port, rank_offset, world_size):
-        rank = torch.distributed.get_rank() + rank_offset
-        self._weight_sync_group = vllm_stateless_init_process_group(
-            master_address,
-            master_port,
-            rank,
-            world_size,
-            get_torch_device().current_device(),
-        )
-
     @register(dispatch_mode=Dispatch.ONE_TO_ALL, blocking=False)
     def sync_rollout_weights(self):
         assert (self._is_actor or self._is_rollout) and not self.config.hybrid_engine
@@ -127,6 +116,59 @@ async def update_weights(self, inference_engine, params):
 
 
 class DetachActorWorker(DetachSync):
+    def __init__(self, config: DictConfig, role: str, **kwargs):
+        ActorRolloutRefWorker.__init__(self, config, role)
+
+        if role == "actor":
+            self.ps_rank_offset = kwargs.get("rank_offset", self.rank)
+            self.ps_world_size = kwargs.get("ps_world_size", self.world_size)
+            self.ps = ParameterServer(rank=self.rank, world_size=self.ps_world_size)
+            self.index = 0
+
+    def init_process_group(self):
+        os.environ["HCCL_NPU_SOCKET_PORT_RANGE"] = "61020"
+        self.ps.init_process_group(device_index=0, master_port=60010)
+        del os.environ["HCCL_NPU_SOCKET_PORT_RANGE"]
+
+    def split_tensors(self) -> dict[str, torch.Tensor]:
+        assert self._is_actor and not self.config.hybrid_engine
+        assert hasattr(self, "_weights_info") and self._weights_info is not None
+
+        if self._is_actor and self._is_offload_param:
+            load_fsdp_model_to_gpu(self.actor_module_fsdp)
+        params = self._get_actor_params()
+
+        named_tensors = {}
+
+        world_size = self.world_size
+        rank = self.rank
+
+        # Ceil division so every weight lands on exactly one rank.
+        weights_per_rank = (len(self._weights_info) + world_size - 1) // world_size
+        for index, (key, _, _) in enumerate(self._weights_info):
+            assert key in params
+            # full_tensor() is a collective, so every rank must call it for
+            # every key even though each rank only keeps its own block.
+            tensor = params[key].full_tensor()
+            if rank * weights_per_rank <= index < (rank + 1) * weights_per_rank:
+                named_tensors[key] = tensor.to("cpu", non_blocking=True)
+
+        get_torch_device().synchronize()
+
+        return named_tensors
+
+    @register(dispatch_mode=Dispatch.ONE_TO_ALL, blocking=False)
+    def sync_rollout_weights_by_ckpt_engine(self):
+        def req_func(socket_paths: list[tuple[str, str]]):
+            # Actor ranks only publish weights; the ckpt-engine workers drive
+            # the inference-server update, so this callback is a no-op here.
+            return
+
+        self.init_process_group()
+        named_tensors = self.split_tensors()
+        checkpoint_name = f"sync_{self.index}"
+
+        self.ps.register_checkpoint(checkpoint_name=checkpoint_name, named_tensors=named_tensors)
+        self.ps.gather_metas(checkpoint_name)
+        self.ps.update(checkpoint_name, req_func, ranks=list(range(self.ps_rank_offset, self.ps_world_size)))
+
+        self.index += 1
+
     def _get_actor_params(self):
         assert self._is_actor
         params = self.actor_module_fsdp.state_dict()
@@ -159,8 +201,7 @@ def get_actor_weights_info(self):
 
 
 class DetachAsyncRolloutWorker(DetachSync):
-    def __init__(self, config: DictConfig, role: str):
-        print(f"[DetachAsyncRolloutWorker] {DetachAsyncRolloutWorker.__mro__}")
+    def __init__(self, config: DictConfig, role: str, **kwargs):
        ActorRolloutRefWorker.__init__(self, config, role)
 
     @register(dispatch_mode=Dispatch.ONE_TO_ALL)
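The `split_tensors` methods (FSDP above, Megatron below) both assign weights to ranks in contiguous ceil-divided blocks. A small sketch of that arithmetic, with a hypothetical helper name not found in the patch:

```python
# Hypothetical helper mirroring the slicing in split_tensors(): rank r owns the
# contiguous block [r * per_rank, (r + 1) * per_rank) of the weight list.
def shard_indices(num_weights: int, world_size: int, rank: int) -> range:
    per_rank = (num_weights + world_size - 1) // world_size  # ceil division
    return range(rank * per_rank, min((rank + 1) * per_rank, num_weights))

# 10 weights over 3 ranks: the ranks own 4 / 4 / 2 entries with no overlap.
assert [len(shard_indices(10, 3, r)) for r in range(3)] == [4, 4, 2]
```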
diff --git a/recipe/one_step_off_policy/main_ppo.py b/recipe/one_step_off_policy/main_ppo.py
index c24b4d01774..b405273d466 100644
--- a/recipe/one_step_off_policy/main_ppo.py
+++ b/recipe/one_step_off_policy/main_ppo.py
@@ -32,6 +32,8 @@
 from verl.utils.config import validate_config
 from verl.utils.device import auto_set_ascend_device_name
 
+from .ckpt_engine_worker import CkptEngineWorker
+
 
 def create_resource_pool_manager(config, roles: list) -> ResourcePoolManager:
     """
@@ -69,6 +71,14 @@ def create_resource_pool_manager(config, roles: list) -> ResourcePoolManager:
         resource_pool_spec["rollout_pool"] = rollout_pool
         mapping[Role.Rollout] = "rollout_pool"
 
+    if Role.CkptEngine in roles:
+        assert config.rollout.n_gpus_per_node > 0, "ckpt_engine: config.rollout.n_gpus_per_node must be greater than 0"
+        assert config.rollout.nnodes > 0, "ckpt_engine: config.rollout.nnodes must be greater than 0"
+        # Same shape as the rollout pool: ckpt-engine workers colocate with the rollout workers.
+        ckpt_engine_pool = [config.rollout.n_gpus_per_node] * config.rollout.nnodes
+        resource_pool_spec["ckpt_engine_pool"] = ckpt_engine_pool
+        mapping[Role.CkptEngine] = "ckpt_engine_pool"
+
     return ResourcePoolManager(resource_pool_spec=resource_pool_spec, mapping=mapping)
 
@@ -111,6 +121,7 @@ def create_role_worker_mapping(config):
         Role.Actor: ray.remote(DetachActorWorker),
         Role.Rollout: ray.remote(DetachAsyncRolloutWorker),
         Role.Critic: ray.remote(CriticWorker),
+        Role.CkptEngine: ray.remote(CkptEngineWorker),
     }
 
     if config.reward_model.enable:
@@ -140,6 +151,9 @@ def run(self, config):
 
         from verl.utils.fs import copy_to_local
 
+        # Drop any preset ASCEND_RT_VISIBLE_DEVICES (see the verl/utils/device.py change below).
+        if os.environ.get("ASCEND_RT_VISIBLE_DEVICES", None) is not None:
+            del os.environ["ASCEND_RT_VISIBLE_DEVICES"]
+
         print(f"TaskRunner hostname: {socket.gethostname()}, PID: {os.getpid()}")
         pprint(OmegaConf.to_container(config, resolve=True))
diff --git a/recipe/one_step_off_policy/megatron_workers.py b/recipe/one_step_off_policy/megatron_workers.py
index c2a2407939e..8a18fd9530f 100644
--- a/recipe/one_step_off_policy/megatron_workers.py
+++ b/recipe/one_step_off_policy/megatron_workers.py
@@ -18,6 +18,7 @@
 
 import torch
 import torch.distributed
+from checkpoint_engine.ps import ParameterServer
 from omegaconf import DictConfig
 from ray.util.collective import collective
 
@@ -120,6 +121,58 @@ async def update_weights(self, inference_engine, params):
 
 
 class DetachActorWorker(DetachSync):
+    def __init__(self, config: DictConfig, role: str, **kwargs):
+        ActorRolloutRefWorker.__init__(self, config, role)
+
+        if role == "actor":
+            self.ps_rank_offset = kwargs.get("rank_offset", self.rank)
+            self.ps_world_size = kwargs.get("ps_world_size", self.world_size)
+            self.ps = ParameterServer(rank=self.rank, world_size=self.ps_world_size)
+            self.index = 0
+
+    def init_process_group(self):
+        os.environ["HCCL_NPU_SOCKET_PORT_RANGE"] = "61020"
+        self.ps.init_process_group(device_index=0, master_port=60010)
+        del os.environ["HCCL_NPU_SOCKET_PORT_RANGE"]
+
+    def split_tensors(self) -> dict[str, torch.Tensor]:
+        assert self._is_actor and not self.config.hybrid_engine
+        assert hasattr(self, "_weights_info") and self._weights_info is not None
+
+        params_generator = self._get_actor_params_generator() if self._is_actor else None
+
+        if self._is_actor and self._is_offload_param:
+            load_megatron_model_to_gpu(self.actor_module)
+
+        named_tensors = {}
+
+        world_size = self.world_size
+        rank = self.rank
+
+        # Ceil division so every weight lands on exactly one rank.
+        weights_per_rank = (len(self._weights_info) + world_size - 1) // world_size
+        for index, (key, tensor) in enumerate(params_generator):
+            if rank * weights_per_rank <= index < (rank + 1) * weights_per_rank:
+                named_tensors[key] = tensor.to("cpu", non_blocking=True)
+
+        get_torch_device().synchronize()
+
+        return named_tensors
+
+    @register(dispatch_mode=Dispatch.ONE_TO_ALL, blocking=False)
+    def sync_rollout_weights_by_ckpt_engine(self):
+        def req_func(socket_paths: list[tuple[str, str]]):
+            # Actor ranks only publish weights; the ckpt-engine workers drive
+            # the inference-server update, so this callback is a no-op here.
+            return
+
+        self.init_process_group()
+        named_tensors = self.split_tensors()
+        checkpoint_name = f"sync_{self.index}"
+
+        self.ps.register_checkpoint(checkpoint_name=checkpoint_name, named_tensors=named_tensors)
+        self.ps.gather_metas(checkpoint_name)
+        self.ps.update(checkpoint_name, req_func, ranks=list(range(self.ps_rank_offset, self.ps_world_size)))
+
+        self.index += 1
+
     @register(dispatch_mode=Dispatch.ONE_TO_ALL)
     def _get_actor_params_generator(self):
         assert self._is_actor
@@ -160,7 +213,7 @@ def get_actor_weights_info(self):
 
 
 class DetachAsyncRolloutWorker(DetachSync):
-    def __init__(self, config: DictConfig, role: str):
+    def __init__(self, config: DictConfig, role: str, **kwargs):
         print(f"[DetachAsyncRolloutWorker] {DetachAsyncRolloutWorker.__mro__}")
         ActorRolloutRefWorker.__init__(self, config, role)
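The parameter-server rank layout both worker files rely on is fixed in `ray_trainer.py` below: actor ranks come first, ckpt-engine ranks follow, one per rollout GPU. A worked example using the 2+6 split from the shell script further down:

```python
# Worked example of the rank layout (values match the 2+6 shell script below).
trainer_gpus = 6   # trainer.n_gpus_per_node * trainer.nnodes
rollout_gpus = 2   # rollout.n_gpus_per_node * rollout.nnodes

rank_offset = trainer_gpus                    # first ckpt-engine ps_rank
ps_world_size = trainer_gpus + rollout_gpus   # actors + ckpt-engine workers

# ps.update(...) targets exactly the ckpt-engine ranks:
assert list(range(rank_offset, ps_world_size)) == [6, 7]
```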
diff --git a/recipe/one_step_off_policy/ray_trainer.py b/recipe/one_step_off_policy/ray_trainer.py
index c3890f61bb9..ab290f48d38 100644
--- a/recipe/one_step_off_policy/ray_trainer.py
+++ b/recipe/one_step_off_policy/ray_trainer.py
@@ -127,6 +127,9 @@ def __init__(
         if config.algorithm.use_kl_in_reward:
             self.kl_ctrl_in_reward = core_algos.get_kl_controller(config.algorithm.kl_ctrl)
 
+        # Actor ranks occupy [0, rank_offset); ckpt-engine ranks occupy [rank_offset, ps_world_size).
+        self.rank_offset = config.trainer.n_gpus_per_node * config.trainer.nnodes
+        self.ps_world_size = self.rank_offset + config.rollout.n_gpus_per_node * config.rollout.nnodes
+
         self._create_dataloader(train_dataset, val_dataset, collate_fn, train_sampler)
 
     def _validate(self):
@@ -149,7 +152,8 @@ def init_workers(self):
         self._init_async_rollout_manager()
 
     def _init_resource_pools(self):
-        self.resource_pool_manager.create_resource_pool()
+        # One ckpt-engine worker (0.2 NPU) shares each device with one rollout worker (0.8 NPU).
+        additional = {"ckpt_engine_pool": {"CPU": 1, "NPU": 0.2}, "rollout_pool": {"CPU": 1, "NPU": 0.8}}
+        self.resource_pool_manager.create_resource_pool(additional=additional)
         self.resource_pool_to_cls = {pool: {} for pool in self.resource_pool_manager.resource_pool_dict.values()}
 
@@ -158,6 +162,20 @@ def _create_worker_classes(self):
         self._create_critic_class()
         self._create_reference_policy_class()
         self._create_reward_model_class()
+        self._create_ckpt_engine_class()
+
+    def _create_ckpt_engine_class(self):
+        resource_pool = self.resource_pool_manager.get_resource_pool(Role.CkptEngine)
+        ckpt_engine_cls = RayClassWithInitArgs(
+            cls=self.role_worker_mapping[Role.CkptEngine],
+            rank_offset=self.rank_offset,
+            ps_world_size=self.ps_world_size,
+            inference_parallel_size=self.config.actor_rollout_ref.rollout.tensor_model_parallel_size,
+            rollout_name=self.config.actor_rollout_ref.rollout.name,
+        )
+        self.resource_pool_to_cls[resource_pool][str(Role.CkptEngine)] = ckpt_engine_cls
@@ -166,6 +184,8 @@ def _create_actor_rollout_classes(self):
         for role in [Role.Actor, Role.Rollout]:
             resource_pool = self.resource_pool_manager.get_resource_pool(role)
             role_cls = RayClassWithInitArgs(
                 cls=self.role_worker_mapping[role],
                 config=self.config.actor_rollout_ref,
                 role=str(role),
+                rank_offset=self.rank_offset,
+                ps_world_size=self.ps_world_size,
             )
             self.resource_pool_to_cls[resource_pool][str(role)] = role_cls
 
@@ -249,26 +269,10 @@ def _init_models(self):
         self.rollout_wg = self.all_wg[str(Role.Rollout)]
         self.actor_wg.init_model()
         self.rollout_wg.init_model()
+        self.ckpt_engine_wg = self.all_wg[str(Role.CkptEngine)]
         self.actor_rollout_wg = self.actor_wg
         weights_info = self.actor_wg.get_actor_weights_info()[0]
         self.rollout_wg.set_actor_weights_info(weights_info)
-        self._create_weight_sync_group()
-
-    def _create_weight_sync_group(self):
-        # TODO: NPU support
-        from verl.utils.device import get_nccl_backend
-
-        actor_rollout_workers = self.actor_wg.workers + self.rollout_wg.workers
-        n_workers = len(actor_rollout_workers)
-
-        # Create Ray collective group for fallback communication
-        collective.create_collective_group(
-            actor_rollout_workers,
-            n_workers,
-            list(range(0, n_workers)),
-            backend=get_nccl_backend(),
-            group_name="actor_rollout",
-        )
 
     def _init_async_rollout_manager(self):
         # create async rollout manager and request scheduler
@@ -286,9 +290,11 @@ def _init_async_rollout_manager(self):
             config=self.config, worker_group=self.rollout_wg, rm_resource_pool=rm_resource_pool
         )
 
+        ray.get(self.ckpt_engine_wg.set_server_addresses(self.async_rollout_manager.server_addresses))
+
     def sync_rollout_weights(self):
-        self.actor_wg.sync_rollout_weights()
-        ray.get(self.rollout_wg.sync_rollout_weights())
+        self.actor_wg.sync_rollout_weights_by_ckpt_engine()
+        ray.get(self.ckpt_engine_wg.sync_rollout_weights_by_ckpt_engine())
 
     def _create_continuous_iterator(self):
         """
diff --git a/recipe/one_step_off_policy/shell/grpo_0.6b_gsm8k_fsdp2_2_6_ckpt_engine.sh b/recipe/one_step_off_policy/shell/grpo_0.6b_gsm8k_fsdp2_2_6_ckpt_engine.sh
new file mode 100644
index 00000000000..a72b16aa5cd
--- /dev/null
+++ b/recipe/one_step_off_policy/shell/grpo_0.6b_gsm8k_fsdp2_2_6_ckpt_engine.sh
@@ -0,0 +1,68 @@
+set -x
+
+project_name='GRPO'
+exp_name='GRPO-Qwen3-0.6b-gsm8k-fsdp2-one-step-off-2-6'
+
+# Paths
+PREFIX=${PREFIX:-"/ssd_2/test"}
+MODEL_PATH=${MODEL_PATH:-"${PREFIX}/qwen3-0.6b"}
+CKPTS_DIR=${CKPTS_DIR:-"${PREFIX}/ckpts/${project_name}/${exp_name}"}
+TRAIN_FILE=${TRAIN_FILE:-"${PREFIX}/gsm8k/train.parquet"}
+TEST_FILE=${TEST_FILE:-"${PREFIX}/gsm8k/test.parquet"}
+
+NNODES=${NNODES:-1}
+NGPUS_PER_NODE=${NGPUS_PER_NODE:-8}
+
+n_gpus_rollout=2
+n_gpus_training=$((NGPUS_PER_NODE - n_gpus_rollout))
+
+export HYDRA_FULL_ERROR=1
+
+python3 -m recipe.one_step_off_policy.main_ppo \
+    algorithm.adv_estimator=grpo \
+    data.train_files="${TRAIN_FILE}" \
+    data.val_files="${TEST_FILE}" \
+    data.train_batch_size=1152 \
+    data.max_prompt_length=512 \
+    data.max_response_length=1024 \
+    data.filter_overlong_prompts=True \
+    data.truncation='error' \
+    actor_rollout_ref.actor.strategy=fsdp2 \
+    critic.strategy=fsdp2 \
+    actor_rollout_ref.model.path="${MODEL_PATH}" \
+    actor_rollout_ref.model.trust_remote_code=True \
+    actor_rollout_ref.actor.optim.lr=1e-6 \
+    actor_rollout_ref.hybrid_engine=False \
+    actor_rollout_ref.model.use_remove_padding=True \
+    actor_rollout_ref.actor.ppo_mini_batch_size=192 \
+    actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=32 \
+    actor_rollout_ref.actor.use_kl_loss=True \
+    actor_rollout_ref.actor.kl_loss_coef=0.001 \
+    actor_rollout_ref.actor.kl_loss_type=low_var_kl \
+    actor_rollout_ref.actor.entropy_coeff=0 \
+    actor_rollout_ref.model.enable_gradient_checkpointing=True \
+    actor_rollout_ref.actor.fsdp_config.param_offload=False \
+    actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \
+    actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=32 \
+    actor_rollout_ref.rollout.tensor_model_parallel_size=2 \
+    actor_rollout_ref.rollout.name=vllm \
+    actor_rollout_ref.rollout.gpu_memory_utilization=0.6 \
+    actor_rollout_ref.rollout.n=5 \
+    actor_rollout_ref.rollout.load_format=safetensors \
+    actor_rollout_ref.rollout.layered_summon=True \
+    actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=32 \
+    actor_rollout_ref.ref.fsdp_config.param_offload=True \
+    algorithm.use_kl_in_reward=False \
+    trainer.critic_warmup=0 \
+    trainer.val_before_train=True \
+    trainer.logger=['console','tensorboard'] \
+    trainer.project_name="${project_name}" \
+    trainer.experiment_name="${exp_name}" \
+    trainer.device=npu \
+    trainer.save_freq=-1 \
+    trainer.test_freq=5 \
+    trainer.total_epochs=2 \
+    trainer.nnodes="${NNODES}" \
+    trainer.n_gpus_per_node="${n_gpus_training}" \
+    rollout.nnodes="${NNODES}" \
+    rollout.n_gpus_per_node="${n_gpus_rollout}"
diff --git a/verl/single_controller/ray/base.py b/verl/single_controller/ray/base.py
index f1bdb553d5f..3aa45533f34 100644
--- a/verl/single_controller/ray/base.py
+++ b/verl/single_controller/ray/base.py
@@ -102,6 +102,7 @@ def __init__(
         max_colocate_count: int = 10,
         detached=False,
         accelerator_type: Optional[str] = None,
+        custom_bundle: Optional[dict] = None,
     ) -> None:
         super().__init__(process_on_nodes, max_colocate_count)
         self.use_gpu = use_gpu
@@ -110,6 +111,7 @@ def __init__(
         self.pgs = None
         self.detached = detached
         self.accelerator_type = accelerator_type
+        self.custom_bundle = custom_bundle
 
     def get_placement_groups(self, strategy="STRICT_PACK", name=None, device_name="cuda"):
         if self.pgs is not None:
@@ -124,11 +126,14 @@ def get_placement_groups(self, strategy="STRICT_PACK", name=None, device_name="c
         elif device_name == "cuda":
             device_name = "GPU"
 
-        bundle = {"CPU": self.max_colocate_count}
-        if self.use_gpu:
-            bundle[device_name] = 1
-            if self.accelerator_type is not None:
-                bundle[self.accelerator_type] = 1e-4
+        if self.custom_bundle is not None:
+            bundle = self.custom_bundle
+        else:
+            bundle = {"CPU": self.max_colocate_count}
+            if self.use_gpu:
+                bundle[device_name] = 1
+                if self.accelerator_type is not None:
+                    bundle[self.accelerator_type] = 1e-4
 
         pg_scheme = [[bundle.copy() for _ in range(process_count)] for process_count in self._store]
 
         lifetime = "detached" if self.detached else None
@@ -528,8 +533,8 @@ def _create_worker(self, rank, pg_idx, pg, local_rank, resource_pool, ray_cls_wi
         world_size = resource_pool.world_size
         use_gpu = resource_pool.use_gpu
         local_world_size = resource_pool.store[0]
-        num_gpus = 1 / resource_pool.max_colocate_count
-
+        custom_bundle = resource_pool.custom_bundle
+        # A custom bundle pins the device fraction directly (currently keyed by "NPU").
+        num_gpus = custom_bundle.get("NPU", 1.0) if custom_bundle is not None else 1 / resource_pool.max_colocate_count
         # we pass in environment variable at option so that Worker can use environment variable to set
         env_vars = {
             "WORLD_SIZE": str(world_size),
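The fractional bundles are what let a ckpt-engine worker sit on the same NPU as the rollout worker it serves: Ray schedules by resource fraction, and the two pools sum to one device. A small check of the arithmetic (bundle values copied from `_init_resource_pools` above):

```python
import math

# Bundle values from recipe/one_step_off_policy/ray_trainer.py above.
ckpt_engine_bundle = {"CPU": 1, "NPU": 0.2}
rollout_bundle = {"CPU": 1, "NPU": 0.8}

# One ckpt-engine worker and one rollout worker together fill each NPU.
total = ckpt_engine_bundle["NPU"] + rollout_bundle["NPU"]
assert math.isclose(total, 1.0)

# _create_worker() hands the fraction straight to Ray as num_gpus.
assert ckpt_engine_bundle.get("NPU", 1.0) == 0.2
```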
diff --git a/verl/trainer/ppo/ray_trainer.py b/verl/trainer/ppo/ray_trainer.py
index 4558e750cc3..b015dc57925 100644
--- a/verl/trainer/ppo/ray_trainer.py
+++ b/verl/trainer/ppo/ray_trainer.py
@@ -76,7 +76,7 @@ class ResourcePoolManager:
     mapping: dict[Role, str]
     resource_pool_dict: dict[str, RayResourcePool] = field(default_factory=dict)
 
-    def create_resource_pool(self):
+    def create_resource_pool(self, additional: Optional[dict[str, dict]] = None):
         """Create Ray resource pools for distributed training.
 
         Initializes resource pools based on the resource pool specification,
@@ -89,11 +89,17 @@ def create_resource_pool(self):
             # For FSDP backend, using max_colocate_count=3: actor_critic_ref, rollout, reward model (optional)
             # For Megatron backend, we recommend using max_colocate_count>1
             # that can utilize different WorkerGroup for differnt models
+            bundle = additional.get(resource_pool_name, None) if additional is not None else None
             resource_pool = RayResourcePool(
-                process_on_nodes=process_on_nodes, use_gpu=True, max_colocate_count=3, name_prefix=resource_pool_name
+                process_on_nodes=process_on_nodes,
+                use_gpu=True,
+                max_colocate_count=3,
+                name_prefix=resource_pool_name,
+                custom_bundle=bundle,
             )
             self.resource_pool_dict[resource_pool_name] = resource_pool
 
+        # TODO: revisit this check for the fractional custom bundles used by the ckpt-engine pool
         self._check_resource_available()
 
     def get_resource_pool(self, role: Role) -> RayResourcePool:
diff --git a/verl/trainer/ppo/utils.py b/verl/trainer/ppo/utils.py
index c53e617e95f..452ffff75f5 100644
--- a/verl/trainer/ppo/utils.py
+++ b/verl/trainer/ppo/utils.py
@@ -36,6 +36,7 @@ class Role(Enum):
     RewardModel = 5
     ActorRolloutRef = 6
     Env = 7
+    CkptEngine = 8
 
     def __str__(self):
         return self._get_role_string()
@@ -49,6 +50,7 @@ def _get_role_string(self):
             Role.RefPolicy: "ref",
             Role.RewardModel: "rm",
             Role.ActorRolloutRef: "actor_rollout_ref",
+            Role.CkptEngine: "ckpt_engine",
         }
         return role_mapping.get(self, self.name.lower())
 
@@ -62,6 +64,7 @@ def from_string(cls, name: str):
             "ref": cls.RefPolicy,
             "rm": cls.RewardModel,
             "actor_rollout_ref": cls.ActorRolloutRef,
+            "ckpt_engine": cls.CkptEngine,
         }
         role = string_mapping.get(name.lower())
         if role is None:
diff --git a/verl/utils/device.py b/verl/utils/device.py
index 2a21a31b668..5510c26ece5 100644
--- a/verl/utils/device.py
+++ b/verl/utils/device.py
@@ -78,7 +78,9 @@ def get_nccl_backend() -> str:
     Returns:
         nccl backend type string.
     """
-    if is_npu_available:
+    # is_npu_available can be stale (False) when ASCEND_RT_VISIBLE_DEVICES is
+    # empty at import time, so query NPU availability at call time instead.
+    if is_torch_npu_available():
         return "hccl"
     else:
         # default to nccl
diff --git a/verl/workers/rollout/vllm_rollout/vllm_async_server.py b/verl/workers/rollout/vllm_rollout/vllm_async_server.py
index 00a8f75ef9d..bdf89432647 100644
--- a/verl/workers/rollout/vllm_rollout/vllm_async_server.py
+++ b/verl/workers/rollout/vllm_rollout/vllm_async_server.py
@@ -697,6 +697,7 @@ async def launch_servers(self):
                 soft=False,
             ),
             name=name,
+            # VLLM_SERVER_DEV_MODE exposes the /collective_rpc endpoint used by checkpoint-engine
+            runtime_env={"env_vars": {"VLLM_SERVER_DEV_MODE": "1"}},
         ).remote(
             config=self.config,
             model_config=self.model_config,
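Taken together, the patch swaps the old NCCL/HCCL broadcast path for a checkpoint-engine handshake. A condensed sketch of one weight sync (method names come from the diffs above; the step comments are interpretation, and the function is illustrative rather than part of the patch):

```python
import ray

def sync_rollout_weights(trainer):
    # 1. Actor ranks shard their weights (split_tensors), register them with
    #    their local ParameterServer, and exchange shard metadata; dispatched
    #    non-blocking so both sides of the handshake can start together.
    trainer.actor_wg.sync_rollout_weights_by_ckpt_engine()
    # 2. Ckpt-engine ranks join the same process group, receive the shards,
    #    and each group leader asks its inference server to load them (vLLM
    #    /collective_rpc or SGLang /update_weights_from_ipc); block until done.
    ray.get(trainer.ckpt_engine_wg.sync_rollout_weights_by_ckpt_engine())
```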