Skip to content

Commit 7d9a2f2

Browse files
Improve leaf module interface (enable via config, relax matching criteria, add document, etc.) (#7604)
This PR improves the usability of the leaf module feature. Here are the changes: - Allow enabling the leaf module via both the DeepSpeed config and APIs. - Relax matching criteria to support class-based matching. - Support multiple ways of specifying the target module: class, class name (with or without package name), module name, or suffix. - Add documentation to the training guide, including config snippets and explanations of default behavior. - Add default classes (e.g., Mixtral, Qwen2/Qwen3) that automatically enable the leaf module feature. (Welcoming requests to add more classes) --------- Signed-off-by: Masahiro Tanaka <[email protected]> Co-authored-by: Olatunji Ruwase <[email protected]>
1 parent 82a9db7 commit 7d9a2f2

File tree

7 files changed

+492
-22
lines changed

7 files changed

+492
-22
lines changed

deepspeed/runtime/engine.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,7 @@
7676
from deepspeed.runtime import lr_schedules
7777
from deepspeed.utils import groups
7878
from deepspeed.utils import logger, log_dist, log_dist_once, instrument_w_nvtx
79+
from deepspeed.utils.z3_leaf_module import apply_zero_leaf_module_config
7980
from deepspeed.utils.timer import NoopTimer, ThroughputTimer, SynchronizedWallClockTimer, \
8081
FORWARD_MICRO_TIMER, BACKWARD_MICRO_TIMER, BACKWARD_INNER_MICRO_TIMER, BACKWARD_REDUCE_MICRO_TIMER, \
8182
STEP_MICRO_TIMER, \
@@ -1293,6 +1294,7 @@ def _set_client_model(self, model):
12931294

12941295
def _configure_distributed_model(self, model):
12951296
self._set_client_model(model)
1297+
apply_zero_leaf_module_config(self.module, getattr(self._config.zero_config, "leaf_module", None))
12961298
is_zero_init_model = self.zero_optimization_partition_weights() and any(
12971299
[hasattr(param, "ds_id") for param in self.module.parameters()])
12981300

deepspeed/runtime/zero/config.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
from deepspeed.utils import logger
1212
from .offload_config import DeepSpeedZeroOffloadParamConfig, DeepSpeedZeroOffloadOptimizerConfig, OffloadDeviceEnum
1313
from deepspeed.runtime.zenflow.zenflow_config import ZenFlowConfig
14+
from .leaf_module_config import DeepSpeedZeroLeafModuleConfig
1415

1516
# ZeRO optimization. By default, this optimization is not enabled.
1617
# Users have to configure the desired optimization (0 means disabled) in params.json as below example:
@@ -356,6 +357,11 @@ class DeepSpeedZeroConfig(DeepSpeedConfigModel):
356357
Enable internal sanity checks, which could be useful for debugging
357358
"""
358359

360+
leaf_module: DeepSpeedZeroLeafModuleConfig = Field(default_factory=DeepSpeedZeroLeafModuleConfig)
361+
"""
362+
Configuration for modules that should be treated as ZeRO3 leaf modules.
363+
"""
364+
359365
# Validators
360366
@model_validator(mode="after")
361367
def overlap_comm_valid(self):
Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
# Copyright (c) Microsoft Corporation.
2+
# SPDX-License-Identifier: Apache-2.0
3+
4+
# DeepSpeed Team
5+
6+
from typing import List
7+
from pydantic import Field, model_validator
8+
9+
from deepspeed.runtime.config_utils import DeepSpeedConfigModel
10+
11+
DEFAULT_LEAF_MODULE_CLASSES: List[str] = [
12+
"transformers.models.mixtral.modeling_mixtral.MixtralSparseMoeBlock",
13+
"transformers.models.qwen2_moe.modeling_qwen2_moe.Qwen2MoeSparseMoeBlock",
14+
"transformers.models.qwen3_moe.modeling_qwen3_moe.Qwen3MoeSparseMoeBlock",
15+
]
16+
DEFAULT_LEAF_MODULE_NAMES: List[str] = []
17+
DEFAULT_LEAF_MODULE_NAME_SUFFIXES: List[str] = []
18+
19+
20+
class DeepSpeedZeroLeafModuleConfig(DeepSpeedConfigModel):
21+
"""Configuration for ZeRO leaf modules that should bypass hook installation."""
22+
23+
classes: List[str] = Field(default_factory=lambda: list(DEFAULT_LEAF_MODULE_CLASSES))
24+
names: List[str] = Field(default_factory=lambda: list(DEFAULT_LEAF_MODULE_NAMES))
25+
name_suffixes: List[str] = Field(default_factory=lambda: list(DEFAULT_LEAF_MODULE_NAME_SUFFIXES))
26+
27+
@model_validator(mode="before")
28+
def _coerce_container_types(cls, values):
29+
if values is None:
30+
return {}
31+
if isinstance(values, dict):
32+
coerced = dict(values)
33+
for key in ("classes", "names", "name_suffixes"):
34+
if key in coerced and isinstance(coerced[key], str):
35+
coerced[key] = [coerced[key]]
36+
return coerced
37+
raise TypeError("leaf_module configuration must be a mapping of fields to values")
38+
39+
@model_validator(mode="after")
40+
def _validate_entries(self):
41+
normalized_classes = [str(cls) for cls in self.classes]
42+
normalized_names = [str(name) for name in self.names]
43+
normalized_suffixes = [str(suffix) for suffix in self.name_suffixes]
44+
45+
deduped_classes = list(dict.fromkeys(normalized_classes))
46+
deduped_names = list(dict.fromkeys(normalized_names))
47+
deduped_suffixes = list(dict.fromkeys(normalized_suffixes))
48+
49+
object.__setattr__(self, "classes", deduped_classes)
50+
object.__setattr__(self, "names", deduped_names)
51+
object.__setattr__(self, "name_suffixes", deduped_suffixes)
52+
return self

deepspeed/utils/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
from .tensor_fragment import safe_get_local_fp32_param, safe_get_local_grad, safe_get_local_optimizer_state
1818
from .tensor_fragment import safe_set_local_fp32_param, safe_set_local_grad, safe_set_local_optimizer_state
1919
from .tensor_fragment import safe_update_full_grad_vectorized
20-
from .z3_leaf_module import set_z3_leaf_modules, unset_z3_leaf_modules, get_z3_leaf_modules, z3_leaf_module, z3_leaf_parameter, set_z3_leaf_module
20+
from .z3_leaf_module import set_z3_leaf_modules, unset_z3_leaf_modules, get_z3_leaf_modules, z3_leaf_module, z3_leaf_parameter, set_z3_leaf_module, set_z3_leaf_modules_by_name, set_z3_leaf_modules_by_suffix
2121
from .mixed_precision_linkage import link_hp_params, lazy_init_hp_params_optimizer_state
2222
from deepspeed.runtime.dataloader import RepeatingLoader
2323
from .numa import get_numactl_cmd

deepspeed/utils/z3_leaf_module.py

Lines changed: 172 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,12 @@
44
# DeepSpeed Team
55

66
import torch
7-
from typing import List, Type, Union
7+
from typing import List, Tuple, Type, Union, Optional, TYPE_CHECKING
8+
9+
from .logging import logger
10+
11+
if TYPE_CHECKING:
12+
from deepspeed.runtime.zero.leaf_module_config import DeepSpeedZeroLeafModuleConfig
813

914

1015
def z3_leaf_module(model: torch.nn.Module) -> bool:
@@ -44,50 +49,201 @@ def set_z3_leaf_module(model: torch.nn.Module, flag: bool):
4449
model._z3_leaf = flag
4550

4651

47-
def _do_set_z3_leaf_modules(model: torch.nn.Module, leaf_module_classes: Union[List[Type], List[str]],
48-
flag: bool) -> List[torch.nn.Module]:
49-
assert all(isinstance(module_class, (type, str) ) for module_class in leaf_module_classes), \
52+
def _fully_qualified_class_name(module: torch.nn.Module) -> str:
53+
cls = module.__class__
54+
return f"{cls.__module__}.{cls.__qualname__}"
55+
56+
57+
def _do_set_z3_leaf_modules(model: torch.nn.Module,
58+
leaf_module_classes: Union[List[Type], List[str]],
59+
flag: bool,
60+
raise_if_not_found: bool = True) -> List[torch.nn.Module]:
61+
assert all(isinstance(module_class, (type, str)) for module_class in leaf_module_classes), \
5062
f'leaf_module_classes must be a list of types or names, got {leaf_module_classes}'
5163

52-
leaf_modules = []
64+
leaf_modules: List[torch.nn.Module] = []
5365

54-
def _set_z3_leaf_flag(model: torch.nn.Module):
66+
def _set_z3_leaf_flag(module_instance: torch.nn.Module):
5567
nonlocal leaf_modules
5668
for module in leaf_module_classes:
57-
if (isinstance(module, type) and model.__class__ == module) or \
58-
(isinstance(module, str) and model.__class__.__name__ == module):
59-
model._z3_leaf = flag
60-
leaf_modules.append(model)
69+
if isinstance(module, type) and isinstance(module_instance, module):
70+
module_instance._z3_leaf = flag
71+
leaf_modules.append(module_instance)
72+
break
73+
74+
if isinstance(module, str):
75+
if (module_instance.__class__.__name__ == module
76+
or _fully_qualified_class_name(module_instance) == module):
77+
module_instance._z3_leaf = flag
78+
leaf_modules.append(module_instance)
79+
break
6180

6281
model.apply(_set_z3_leaf_flag)
6382

64-
if len(leaf_modules) == 0:
83+
if len(leaf_modules) == 0 and raise_if_not_found:
6584
raise ValueError(f'No modules of type {leaf_module_classes} found in model {model}')
6685

6786
return leaf_modules
6887

6988

70-
def set_z3_leaf_modules(model: torch.nn.Module, leaf_module_classes: Union[List[Type],
71-
List[str]]) -> List[torch.nn.Module]:
89+
def set_z3_leaf_modules_by_name(model: torch.nn.Module,
90+
module_names: List[str],
91+
flag: bool = True,
92+
raise_if_not_found: bool = True) -> Tuple[List[torch.nn.Module], List[str]]:
93+
"""Sets a leaf flag for modules referenced by their names in ``model.named_modules()``.
94+
Args:
95+
model (torch.nn.Module): The model containing the modules to update.
96+
module_names (List[str]): Module names as returned by ``named_modules()``.
97+
flag (bool): Desired flag state.
98+
raise_if_not_found (bool): Whether to raise when no module matches a provided name.
99+
Returns:
100+
Tuple[List[torch.nn.Module], List[str]]: Matched modules and missing module names.
101+
"""
102+
modules_by_name = dict(model.named_modules())
103+
leaf_modules: List[torch.nn.Module] = []
104+
missing: List[str] = []
105+
106+
for name in module_names:
107+
module = modules_by_name.get(name)
108+
if module is None:
109+
missing.append(name)
110+
continue
111+
module._z3_leaf = flag
112+
leaf_modules.append(module)
113+
114+
if missing and raise_if_not_found:
115+
raise ValueError(f'No modules named {missing} found in model {model}')
116+
117+
return leaf_modules, missing
118+
119+
120+
def set_z3_leaf_modules_by_suffix(model: torch.nn.Module,
121+
module_name_suffixes: List[str],
122+
flag: bool = True,
123+
raise_if_not_found: bool = True) -> Tuple[List[torch.nn.Module], List[str]]:
124+
"""Sets a leaf flag for modules referenced by suffixes of ``model.named_modules()`` names."""
125+
modules_by_name = dict(model.named_modules())
126+
leaf_modules: List[torch.nn.Module] = []
127+
missing: List[str] = []
128+
seen_ids = set()
129+
130+
for suffix in module_name_suffixes:
131+
matched = False
132+
for name, module in modules_by_name.items():
133+
if name.endswith(suffix):
134+
module._z3_leaf = flag
135+
module_id = id(module)
136+
if module_id not in seen_ids:
137+
seen_ids.add(module_id)
138+
leaf_modules.append(module)
139+
matched = True
140+
if not matched:
141+
missing.append(suffix)
142+
143+
if missing and raise_if_not_found:
144+
raise ValueError(f'No modules matching suffixes {missing} found in model {model}')
145+
146+
return leaf_modules, missing
147+
148+
149+
def set_z3_leaf_modules(model: torch.nn.Module,
150+
leaf_module_classes: Union[List[Type], List[str]],
151+
raise_if_not_found: bool = True) -> List[torch.nn.Module]:
72152
"""Sets a flag within a module in `model` to instruct ZeRO3 to stop setting hooks recursively when it encounters a module class listed in `leaf_module_classes`.
73153
This is particularly useful in the context of Mixture of Experts (MoE) models. In MoE models, the computation order of experts varies across forward passes. This variability can disrupt ZeRO3's functionality, as ZeRO3 relies on tracking the computation order of modules to prefetch parameters efficiently. By designating a module as a 'leaf' node, ZeRO3 will prefetch parameters for all child modules upon entering the module.
74154
Another scenario where this functionality is beneficial is in models with excessively fine-grained nested modules, where it helps to avoid the overhead associated with hooks.
75155
Args:
76156
model (torch.nn.Module): The model to which the leaf module flag will be applied.
77157
leaf_module_classes (Union[List[Type], List[str]]): A list of module classes that should be flagged as 'leaf' modules.
158+
raise_if_not_found (bool): Whether to raise a ``ValueError`` when none of the provided classes
159+
match a module inside ``model``.
78160
Returns:
79161
List[torch.nn.Module]: A list of modules that match the module classes in `leaf_module_classes`.
80162
"""
81-
return _do_set_z3_leaf_modules(model, leaf_module_classes, True)
163+
return _do_set_z3_leaf_modules(model, leaf_module_classes, True, raise_if_not_found)
82164

83165

84-
def unset_z3_leaf_modules(model: torch.nn.Module, leaf_module_classes: List[Type]) -> List[torch.nn.Module]:
166+
def unset_z3_leaf_modules(model: torch.nn.Module,
167+
leaf_module_classes: List[Type],
168+
raise_if_not_found: bool = True) -> List[torch.nn.Module]:
85169
"""Unsets a flag within a module in `model` to instruct ZeRO3 to resume setting hooks recursively when it encounters a module class listed in `leaf_module_classes`.
86170
See `set_z3_leaf_modules` for more details.
87171
Args:
88172
model (torch.nn.Module): The model to which the leaf module flag will be applied.
89173
leaf_module_classes (Union[List[Type], List[str]]): A list of module classes that should be flagged as 'leaf' modules.
174+
raise_if_not_found (bool): Whether to raise a ``ValueError`` when none of the provided classes
175+
match a module inside ``model``.
90176
Returns:
91177
List[torch.nn.Module]: A list of modules that match the module classes in `leaf_module_classes`.
92178
"""
93-
return _do_set_z3_leaf_modules(model, leaf_module_classes, False)
179+
return _do_set_z3_leaf_modules(model, leaf_module_classes, False, raise_if_not_found)
180+
181+
182+
def apply_zero_leaf_module_config(model: torch.nn.Module,
183+
leaf_cfg: Optional["DeepSpeedZeroLeafModuleConfig"]) -> List[torch.nn.Module]:
184+
"""Apply ZeRO leaf module configuration to ``model``.
185+
186+
Args:
187+
model (torch.nn.Module): Root module to update.
188+
leaf_cfg (DeepSpeedZeroLeafModuleConfig | None): Parsed configuration. If ``None``
189+
no changes are applied.
190+
191+
Returns:
192+
List[torch.nn.Module]: Modules flagged as leaves.
193+
"""
194+
if leaf_cfg is None:
195+
return []
196+
197+
from deepspeed.runtime.zero.leaf_module_config import (
198+
DEFAULT_LEAF_MODULE_CLASSES,
199+
DEFAULT_LEAF_MODULE_NAMES,
200+
DEFAULT_LEAF_MODULE_NAME_SUFFIXES,
201+
)
202+
203+
matched_modules: List[torch.nn.Module] = []
204+
matched_ids = set()
205+
206+
customized_classes = leaf_cfg.classes != DEFAULT_LEAF_MODULE_CLASSES
207+
customized_names = leaf_cfg.names != DEFAULT_LEAF_MODULE_NAMES
208+
customized_suffixes = leaf_cfg.name_suffixes != DEFAULT_LEAF_MODULE_NAME_SUFFIXES
209+
210+
if leaf_cfg.classes:
211+
class_matched = set_z3_leaf_modules(model, leaf_cfg.classes, raise_if_not_found=False)
212+
for module in class_matched:
213+
module_id = id(module)
214+
if module_id not in matched_ids:
215+
matched_ids.add(module_id)
216+
matched_modules.append(module)
217+
218+
if leaf_cfg.names:
219+
name_matched, missing_names = set_z3_leaf_modules_by_name(model,
220+
leaf_cfg.names,
221+
flag=True,
222+
raise_if_not_found=False)
223+
for module in name_matched:
224+
module_id = id(module)
225+
if module_id not in matched_ids:
226+
matched_ids.add(module_id)
227+
matched_modules.append(module)
228+
229+
if missing_names and customized_names:
230+
logger.warning(f"ZeRO leaf module configuration contains unknown module names: {missing_names}")
231+
232+
if leaf_cfg.name_suffixes:
233+
suffix_matched, missing_suffixes = set_z3_leaf_modules_by_suffix(model,
234+
leaf_cfg.name_suffixes,
235+
flag=True,
236+
raise_if_not_found=False)
237+
for module in suffix_matched:
238+
module_id = id(module)
239+
if module_id not in matched_ids:
240+
matched_ids.add(module_id)
241+
matched_modules.append(module)
242+
243+
if missing_suffixes and customized_suffixes:
244+
logger.warning(f"ZeRO leaf module configuration contains unmatched module suffixes: {missing_suffixes}")
245+
246+
if not matched_modules and (customized_classes or customized_names or customized_suffixes):
247+
logger.warning("ZeRO leaf module configuration did not match any modules; hooks will be applied as usual")
248+
249+
return matched_modules

0 commit comments

Comments
 (0)