-
Notifications
You must be signed in to change notification settings - Fork 3.8k
[recipe, megatron, fsdp] fix: checkpoint-engine fix trainer param offload in fully-async mode #4655
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -135,8 +135,11 @@ def sync_rollout_weights_by_checkpoint(self, sync_group_name="actor_rollout"): | |||||||||||||||||||
| assert (self._is_actor or self._is_rollout) and not self.config.hybrid_engine | ||||||||||||||||||||
| assert hasattr(self, "_weights_info") and self._weights_info is not None | ||||||||||||||||||||
|
|
||||||||||||||||||||
| # Load model to GPU | ||||||||||||||||||||
| load_start_time = time.time() | ||||||||||||||||||||
| if self._is_actor and self._is_offload_param: | ||||||||||||||||||||
| load_fsdp_model_to_gpu(self.actor_module_fsdp) | ||||||||||||||||||||
| load_duration = time.time() - load_start_time | ||||||||||||||||||||
|
|
||||||||||||||||||||
| from ray.util.collective import collective | ||||||||||||||||||||
|
|
||||||||||||||||||||
|
|
@@ -172,14 +175,24 @@ def sync_rollout_weights_by_checkpoint(self, sync_group_name="actor_rollout"): | |||||||||||||||||||
| update_end_time = time.time() | ||||||||||||||||||||
| update_duration = update_end_time - update_start_time | ||||||||||||||||||||
|
|
||||||||||||||||||||
| collective.barrier(group_name=sync_group_name) | ||||||||||||||||||||
| offload_start_time = time.time() | ||||||||||||||||||||
| if self._is_actor and self._is_offload_param: | ||||||||||||||||||||
| offload_fsdp_model_to_cpu(self.actor_module_fsdp) | ||||||||||||||||||||
| offload_duration = time.time() - offload_start_time | ||||||||||||||||||||
|
Comment on lines
+178
to
+181
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Similar to
Suggested change
|
||||||||||||||||||||
|
|
||||||||||||||||||||
| print( | ||||||||||||||||||||
| f"sync_rollout_weights_by_checkpoint finish!, rank:{torch.distributed.get_rank()}," | ||||||||||||||||||||
| f" is_actor:{self._is_actor}, is_rollout:{self._is_rollout}," | ||||||||||||||||||||
| f" total cost:{update_end_time - cache_start_time} seconds, while cache cost {cache_duration} seconds, " | ||||||||||||||||||||
| f" register cost {register_duration} seconds, update cost {update_duration} seconds" | ||||||||||||||||||||
| ) | ||||||||||||||||||||
|
|
||||||||||||||||||||
| if self._is_actor and self._is_offload_param: | ||||||||||||||||||||
| print( | ||||||||||||||||||||
| f"sync_rollout_weights_by_checkpoint load model to gpu cost {load_duration} seconds," | ||||||||||||||||||||
| f" offload model to cpu cost {offload_duration} seconds" | ||||||||||||||||||||
| ) | ||||||||||||||||||||
|
|
||||||||||||||||||||
|
|
||||||||||||||||||||
| class DetachActorWorker(DetachNcclSync): | ||||||||||||||||||||
| def _get_actor_params(self): | ||||||||||||||||||||
|
|
||||||||||||||||||||
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -140,6 +140,12 @@ def sync_rollout_weights_by_checkpoint(self, sync_group_name="actor_rollout"): | |||||||||||||||||||
| assert (self._is_actor or self._is_rollout) and not self.config.hybrid_engine | ||||||||||||||||||||
| assert hasattr(self, "_weights_info") and self._weights_info is not None | ||||||||||||||||||||
|
|
||||||||||||||||||||
| # Load model to GPU | ||||||||||||||||||||
| load_start_time = time.time() | ||||||||||||||||||||
| if self._is_actor and self._is_offload_param: | ||||||||||||||||||||
| load_megatron_model_to_gpu(self.actor_module) | ||||||||||||||||||||
| load_duration = time.time() - load_start_time | ||||||||||||||||||||
|
Comment on lines
+144
to
+147
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. The measurement for
Suggested change
|
||||||||||||||||||||
|
|
||||||||||||||||||||
| from ray.util.collective import collective | ||||||||||||||||||||
|
|
||||||||||||||||||||
| # Cache actor weights to CPU and measure the time taken | ||||||||||||||||||||
|
|
@@ -174,13 +180,24 @@ def sync_rollout_weights_by_checkpoint(self, sync_group_name="actor_rollout"): | |||||||||||||||||||
| update_end_time = time.time() | ||||||||||||||||||||
| update_duration = update_end_time - update_start_time | ||||||||||||||||||||
|
|
||||||||||||||||||||
| offload_start_time = time.time() | ||||||||||||||||||||
| if self._is_actor and self._is_offload_param: | ||||||||||||||||||||
| offload_megatron_model_to_cpu(self.actor_module) | ||||||||||||||||||||
| offload_duration = time.time() - offload_start_time | ||||||||||||||||||||
|
Comment on lines
+183
to
+186
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. The measurement for
Suggested change
|
||||||||||||||||||||
|
|
||||||||||||||||||||
| print( | ||||||||||||||||||||
| f"sync_rollout_weights_by_checkpoint finish!, rank:{torch.distributed.get_rank()}," | ||||||||||||||||||||
| f" is_actor:{self._is_actor}, is_rollout:{self._is_rollout}," | ||||||||||||||||||||
| f" total cost:{update_end_time - cache_start_time} seconds, while cache cost {cache_duration} seconds, " | ||||||||||||||||||||
| f" register cost {register_duration} seconds, update cost {update_duration} seconds" | ||||||||||||||||||||
| ) | ||||||||||||||||||||
|
|
||||||||||||||||||||
| if self._is_actor and self._is_offload_param: | ||||||||||||||||||||
| print( | ||||||||||||||||||||
| f"sync_rollout_weights_by_checkpoint load model to gpu cost {load_duration} seconds," | ||||||||||||||||||||
| f" offload model to cpu cost {offload_duration} seconds" | ||||||||||||||||||||
| ) | ||||||||||||||||||||
|
|
||||||||||||||||||||
|
|
||||||||||||||||||||
| class DetachActorWorker(DetachNcclSync): | ||||||||||||||||||||
| def _get_actor_params_generator(self): | ||||||||||||||||||||
|
|
||||||||||||||||||||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The current implementation for measuring `load_duration` is imprecise. When `self._is_actor and self._is_offload_param` is false, it measures the time for the conditional check, resulting in a small non-zero value instead of zero. When true, it includes the overhead of the check. To measure the duration accurately, the timing logic should be contained entirely within the conditional block, and `load_duration` should be initialized to 0.0.