diff --git a/recipe/fully_async_policy/fsdp_workers.py b/recipe/fully_async_policy/fsdp_workers.py
index 4d172fb657e..7e1296287d9 100644
--- a/recipe/fully_async_policy/fsdp_workers.py
+++ b/recipe/fully_async_policy/fsdp_workers.py
@@ -135,8 +135,11 @@ def sync_rollout_weights_by_checkpoint(self, sync_group_name="actor_rollout"):
         assert (self._is_actor or self._is_rollout) and not self.config.hybrid_engine
         assert hasattr(self, "_weights_info") and self._weights_info is not None
 
+        # Load model to GPU
+        load_start_time = time.time()
         if self._is_actor and self._is_offload_param:
             load_fsdp_model_to_gpu(self.actor_module_fsdp)
+        load_duration = time.time() - load_start_time
 
         from ray.util.collective import collective
 
@@ -172,7 +175,11 @@ def sync_rollout_weights_by_checkpoint(self, sync_group_name="actor_rollout"):
         update_end_time = time.time()
         update_duration = update_end_time - update_start_time
 
-        collective.barrier(group_name=sync_group_name)
+        offload_start_time = time.time()
+        if self._is_actor and self._is_offload_param:
+            offload_fsdp_model_to_cpu(self.actor_module_fsdp)
+        offload_duration = time.time() - offload_start_time
+
         print(
             f"sync_rollout_weights_by_checkpoint finish!, rank:{torch.distributed.get_rank()},"
             f" is_actor:{self._is_actor}, is_rollout:{self._is_rollout},"
@@ -180,6 +187,12 @@
             f" register cost {register_duration} seconds, update cost {update_duration} seconds"
         )
 
+        if self._is_actor and self._is_offload_param:
+            print(
+                f"sync_rollout_weights_by_checkpoint load model to gpu cost {load_duration} seconds,"
+                f" offload model to cpu cost {offload_duration} seconds"
+            )
+
 
 class DetachActorWorker(DetachNcclSync):
     def _get_actor_params(self):
diff --git a/recipe/fully_async_policy/megatron_worker.py b/recipe/fully_async_policy/megatron_worker.py
index f9f2c932a4f..045aa77f750 100644
--- a/recipe/fully_async_policy/megatron_worker.py
+++ b/recipe/fully_async_policy/megatron_worker.py
@@ -140,6 +140,12 @@ def sync_rollout_weights_by_checkpoint(self, sync_group_name="actor_rollout"):
         assert (self._is_actor or self._is_rollout) and not self.config.hybrid_engine
         assert hasattr(self, "_weights_info") and self._weights_info is not None
 
+        # Load model to GPU
+        load_start_time = time.time()
+        if self._is_actor and self._is_offload_param:
+            load_megatron_model_to_gpu(self.actor_module)
+        load_duration = time.time() - load_start_time
+
         from ray.util.collective import collective
 
         # Cache actor weights to CPU and measure the time taken
@@ -174,6 +180,11 @@
         update_end_time = time.time()
         update_duration = update_end_time - update_start_time
 
+        offload_start_time = time.time()
+        if self._is_actor and self._is_offload_param:
+            offload_megatron_model_to_cpu(self.actor_module)
+        offload_duration = time.time() - offload_start_time
+
         print(
             f"sync_rollout_weights_by_checkpoint finish!, rank:{torch.distributed.get_rank()},"
             f" is_actor:{self._is_actor}, is_rollout:{self._is_rollout},"
@@ -181,6 +192,12 @@
             f" register cost {register_duration} seconds, update cost {update_duration} seconds"
         )
 
+        if self._is_actor and self._is_offload_param:
+            print(
+                f"sync_rollout_weights_by_checkpoint load model to gpu cost {load_duration} seconds,"
+                f" offload model to cpu cost {offload_duration} seconds"
+            )
+
 
 class DetachActorWorker(DetachNcclSync):
     def _get_actor_params_generator(self):
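For readers skimming the patch: both workers now wrap the checkpoint-based weight sync in the same three-phase pattern: load parameters to GPU (timed), sync, then offload back to CPU (timed), with a per-rank report when offloading is enabled. Below is a minimal, self-contained sketch of that instrumentation pattern. `load_model_to_gpu`, `offload_model_to_cpu`, `sync_weights`, and `sync_rollout_weights_timed` are hypothetical stand-ins for the worker-specific helpers (`load_fsdp_model_to_gpu`/`offload_fsdp_model_to_cpu` and `load_megatron_model_to_gpu`/`offload_megatron_model_to_cpu`), not the actual implementations.

```python
import time


def load_model_to_gpu(module):
    # Hypothetical stand-in for load_fsdp_model_to_gpu / load_megatron_model_to_gpu.
    time.sleep(0.01)  # simulate the host-to-device parameter copy


def offload_model_to_cpu(module):
    # Hypothetical stand-in for offload_fsdp_model_to_cpu / offload_megatron_model_to_cpu.
    time.sleep(0.01)  # simulate the device-to-host parameter copy


def sync_weights(module):
    # Hypothetical stand-in for the checkpoint-based weight broadcast itself.
    time.sleep(0.01)


def sync_rollout_weights_timed(module, is_actor=True, is_offload_param=True):
    # Phase 1: bring offloaded parameters back onto the GPU, timed.
    load_start_time = time.time()
    if is_actor and is_offload_param:
        load_model_to_gpu(module)
    load_duration = time.time() - load_start_time

    # Phase 2: perform the weight sync while parameters are resident on GPU.
    sync_weights(module)

    # Phase 3: offload parameters back to host memory, timed separately.
    offload_start_time = time.time()
    if is_actor and is_offload_param:
        offload_model_to_cpu(module)
    offload_duration = time.time() - offload_start_time

    if is_actor and is_offload_param:
        print(
            f"load model to gpu cost {load_duration} seconds,"
            f" offload model to cpu cost {offload_duration} seconds"
        )


sync_rollout_weights_timed(module=None)
```

The sketch only mirrors the timing structure; it does not reproduce the collective-communication details of the patch, such as the removal of the trailing `collective.barrier(group_name=sync_group_name)` in the FSDP worker.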