Skip to content
Open
7 changes: 7 additions & 0 deletions deepspeed/runtime/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -1114,6 +1114,9 @@ def zeropp_loco_param(self):
def zero_log_trace_cache_warnings(self):
    """Return the ``log_trace_cache_warnings`` value from the ZeRO config.

    The value is read from ``self._config.zero_config`` on every call.
    """
    return self._config.zero_config.log_trace_cache_warnings

def zero_allgather_single_param(self):
    """Return the ``allgather_single_param`` value from the ZeRO config.

    The value is read from ``self._config.zero_config`` on every call.
    """
    return self._config.zero_config.allgather_single_param

def is_sanity_checks_enabled(self):
    """Return the ``enable_sanity_checks`` value from the ZeRO config.

    The value is read from ``self._config.zero_config`` on every call.
    """
    return self._config.zero_config.enable_sanity_checks

Expand Down Expand Up @@ -1888,6 +1891,10 @@ def _configure_zero_optimizer(self, optimizer):
if mics_shard_size > 0:
return self._return_mics_optimizer(optimizer, timers)

if self.zero_allgather_single_param():
log_dist(f"If zero_allgather_single_param is True, set prefetch_bucket_size to 1", ranks=[0])
self._config.zero_config.prefetch_bucket_size = 1

log_dist(f'Creating {model_dtype} ZeRO stage {zero_stage} optimizer', ranks=[0])
from deepspeed.runtime.zero.stage3 import DeepSpeedZeroOptimizer_Stage3
from deepspeed.runtime.superoffload.superoffload_stage3 import SuperOffloadOptimizer_Stage3
Expand Down
8 changes: 8 additions & 0 deletions deepspeed/runtime/zero/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
"stage3_module_granularity_threshold": 0,
"allgather_partitions": [true|false],
"use_multi_rank_bucket_allreduce": [true|false],
"stage3_allgather_single_param": [true|false],
"allgather_bucket_size": 500000000,
"reduce_scatter": [true|false],
"contiguous_gradients" : [true|false]
Expand Down Expand Up @@ -268,6 +269,13 @@ class DeepSpeedZeroConfig(DeepSpeedConfigModel):
the overhead of concatenation and slicing on the host.
"""

allgather_single_param: bool = Field(default=False, alias="stage3_allgather_single_param")
"""
Enables allgather on individual parameters instead of parameter lists in stage3.
Reduces peak memory usage and improves performance in high memory pressure scenarios
by minimizing temporary buffers required for parameter gathering.
"""

stage3_gather_fp16_weights_on_model_save: bool = Field(False,
json_schema_extra={
"deprecated": True,
Expand Down
Loading