9 changes: 8 additions & 1 deletion deepspeed/inference/engine.py
@@ -24,7 +24,7 @@
 from ..module_inject import LinearAllreduce, LinearLayer, Normalize, ReplaceWithTensorSlicing
 from deepspeed.accelerator import get_accelerator
 from ..module_inject.policy import TransformerPolicy
-from ..module_inject.auto_tp import AutoTP
+from ..module_inject.auto_tp import AutoTP, Loading

 from ..module_inject.replace_policy import generic_policies
 from ..module_inject.auto_tp_model_utils import build_bloom_alibi_tensor, build_mpt_atten_bias_tensor, build_mpt_alibi_tensor, get_alibi_mask
@@ -363,7 +363,14 @@ def load_module_recursive(module, prefix='', level=0):
                         child = Normalize(dim=child.weight.ds_shape[-1], dtype=child.weight.dtype, eps=child.eps)
                         setattr(module, name, child)
                     load(child, self.sd, prefix + name + '.')
+                    # Load buffers for this module
+                    if len(child._buffers) != 0:
+                        Loading.load_buffer(child, self.sd, checking_key, self.mp_group)
                 else:
+                    checking_key = prefix + name + '.'
+                    # Load buffers for non-policy modules
+                    if len(child._buffers) != 0:
+                        Loading.load_buffer(child, self.sd, checking_key, self.mp_group)
                     load_module_recursive(child, prefix if level == 0 else prefix + name + '.', level + 1)

         load_module_recursive(r_module)
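Why buffers need this dedicated path: the `load()` helpers walk parameters, while tensors registered with `register_buffer()` (alibi caches, position ids, and similar non-trainable state) live in `module._buffers` and would otherwise be left uninitialized on each rank. A minimal standalone sketch of the idea, with a toy module and state dict invented for illustration (not the DeepSpeed API):

    import torch
    import torch.nn as nn

    class ToyBlock(nn.Module):
        def __init__(self):
            super().__init__()
            self.ln = nn.LayerNorm(4)
            # Non-trainable state: missed by parameter-only loaders.
            self.register_buffer("pos_ids", torch.arange(4))

    block = ToyBlock()
    sd = {"ln.weight": torch.ones(4), "ln.bias": torch.zeros(4),
          "pos_ids": torch.tensor([9, 8, 7, 6])}

    # The extra step the change above adds: copy buffer entries explicitly.
    for name, buf in block._buffers.items():
        if name in sd:  # prefix is '' for this toy root module
            buf.copy_(sd[name])
    assert block.pos_ids.tolist() == [9, 8, 7, 6]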
14 changes: 11 additions & 3 deletions deepspeed/module_inject/auto_tp.py
@@ -142,14 +142,22 @@ def is_load_module(module):
         ]
         return module.__class__ in load_layers or module._get_name() in load_layer_names

-    def load_buffer(module, state_dict, prefix):
+    def load_buffer(module, state_dict, prefix, mp_group=None):
         for name in module._buffers.keys():
             if module._buffers[name].data.is_meta:
                 module._buffers[name] = torch.nn.parameter.Parameter(
                     data=torch.empty_like(module._buffers[name].data, device="cpu"),
                     requires_grad=module._buffers[name].data.requires_grad)
             if prefix + name in state_dict.keys():
-                module._buffers[name].data.copy_(state_dict[prefix + name])
+                # Buffers are typically not sharded across devices, so we copy the full buffer
+                # to all devices. Ensure the buffer data is moved to the correct device.
+                buffer_data = state_dict[prefix + name]
+                if not buffer_data.is_meta:
+                    # Move buffer data to the same device as the module's buffer
+                    target_device = module._buffers[name].data.device
+                    if buffer_data.device != target_device:
+                        buffer_data = buffer_data.to(target_device)
+                    module._buffers[name].data.copy_(buffer_data)

     def load(module, state_dict, prefix, mp_group=None):
         mp_replace = ReplaceWithTensorSlicing(mp_group=mp_group)
@@ -461,7 +469,7 @@ def _replace_module(self, r_module, prev_name='', prev_class_name=''):
             else:
                 continue
             if len(child._buffers) != 0 and self.state_dict is not None:
-                Loading.load_buffer(child, self.state_dict, checking_key)
+                Loading.load_buffer(child, self.state_dict, checking_key, self.mp_group)
             if child.__class__ in self.linear_policies:
                 setattr(r_module, name, self.linear_policies[child.__class__](child, prev_name + '.' + name,
                                                                               self.conv_linear_layer))
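A CPU-only usage sketch of the device handling introduced above (the `buffers`/`state_dict` names and the `transformer.` prefix are invented for the example; in the engine the target is the rank's accelerator device): a meta buffer is first materialized with real storage, then the checkpoint value is moved to the buffer's device before the in-place copy.

    import torch

    buffers = {"alibi": torch.empty(8, device="meta")}     # placeholder with no storage
    state_dict = {"transformer.alibi": torch.randn(8)}     # checkpoint value (CPU here)
    prefix, name = "transformer.", "alibi"

    if buffers[name].is_meta:
        # Mirror the torch.empty_like(..., device="cpu") materialization step.
        buffers[name] = torch.empty_like(buffers[name], device="cpu")

    src = state_dict[prefix + name]
    if not src.is_meta:
        target_device = buffers[name].device
        if src.device != target_device:
            # Buffers are replicated whole, not sliced per rank.
            src = src.to(target_device)
        buffers[name].copy_(src)
    assert torch.equal(buffers[name], state_dict[prefix + name])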
2 changes: 1 addition & 1 deletion deepspeed/module_inject/replace_module.py
@@ -709,7 +709,7 @@ def _replace_module(model, policies, prefix='', layer_id=0, level_id=0, state_di
         else:
             continue
         if len(child._buffers) != 0 and state_dict is not None:
-            Loading.load_buffer(child, state_dict, checking_key)
+            Loading.load_buffer(child, state_dict, checking_key, mp_group=None)
         _, layer_id = _replace_module(child,
                                       policies,
                                       prefix if level_id == 0 and skip_level_0_prefix(model, state_dict) else \
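Why passing `mp_group=None` is sound at this call site: the group is only needed to slice weights across tensor-parallel ranks (via `ReplaceWithTensorSlicing`), whereas buffers are replicated whole on every rank. A rough contrast, with shapes invented for illustration:

    import torch

    world_size, rank = 2, 0

    # Weights: sliced per tensor-parallel rank (what mp_group is actually for).
    weight = torch.randn(8, 8)
    local_weight = torch.chunk(weight, world_size, dim=0)[rank]   # (4, 8) on this rank

    # Buffers: replicated unchanged, so no group is needed to load them.
    alibi = torch.randn(8)
    local_alibi = alibi.clone()                                   # full (8,) everywhere

    assert local_weight.shape == (4, 8) and local_alibi.shape == (8,)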