
Commit d55f736

adaptor func _allgather_params

Signed-off-by: aeeeeeep <[email protected]>
Parent: 77a51f7
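
In short: _allgather_params gains a second path. When self.allgather_single_param is set, each parameter in param_list is all-gathered on its own: its shard is copied into a per-parameter flat buffer, gathered with dist.all_gather_into_tensor, trimmed to ds_numel, and viewed as ds_shape; quantization scales, when present, are gathered and dequantized per parameter as well. Otherwise the function keeps the prior behavior of packing every parameter's shard into one flat buffer per rank and slicing the gathered result back into replicated tensors. Two short sketches after the diff illustrate both paths.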

1 file changed

deepspeed/runtime/zero/partition_parameters.py (102 additions, 54 deletions)
@@ -1994,78 +1994,126 @@ def _allgather_params(self, param_list, hierarchy=0):
         if len(param_list) == 0:
             return

-        partition_size = sum([param.ds_tensor.ds_numel for param in param_list])
+        if self.allgather_single_param:
+            for param in param_list:
+                partition_size = param.ds_tensor.ds_numel
+                tensor_size = partition_size * self.num_partitions

-        tensor_size = partition_size * self.num_partitions
-        flat_tensor = torch.empty(tensor_size, dtype=param_list[0].ds_tensor.dtype, device=self.local_device)
-        partitions = []
-        for i in range(self.num_partitions):
-            start = partition_size * i
+                flat_tensor = torch.empty(tensor_size, dtype=param.ds_tensor.dtype, device=self.local_device)

-            partitions.append(flat_tensor.narrow(0, start, partition_size))
+                flat_tensor.requires_grad = False

-            if i == self.get_partition_rank():
-                offset = 0
-                for param in param_list:
-                    param_numel = param.ds_tensor.ds_numel
+                partitions = []
+                for i in range(self.num_partitions):
+                    start = partition_size * i
+                    partitions.append(flat_tensor.narrow(0, start, partition_size))

-                    partitions[i].narrow(0, offset, param_numel).copy_(param.ds_tensor.data)
+                    if i == self.get_partition_rank():
+                        partitioned_tensor.copy_(param.ds_tensor.data)

-                    offset += param_numel
+                if hasattr(param, 'ds_quant_scale'):
+                    scale_size = param.ds_tensor.ds_quant_scale.numel()
+                    scale_tensor_size = scale_size * self.num_partitions
+                    flat_scale_tensor = torch.empty(scale_tensor_size,
+                                                    dtype=param.ds_tensor.ds_quant_scale.dtype,
+                                                    device=self.local_device)
+                    flat_scale_tensor.requires_grad = False

-        if hasattr(param_list[0], 'ds_quant_scale'):
-            scale_size = sum([param.ds_tensor.ds_quant_scale.numel() for param in param_list])
-            scale_tensor_size = scale_size * self.world_size
-            flat_scale_tensor = torch.empty(scale_tensor_size,
-                                            dtype=param_list[0].ds_tensor.ds_quant_scale.dtype,
-                                            device=self.local_device)
-            scale_partitions = []
-            for i in range(self.world_size):
-                start = scale_tensor_size * i
-                scale_partitions.append(flat_scale_tensor.narrow(0, start, scale_tensor_size))
-                if i == self.rank:
+                    scale_partitions = []
+                    for i in range(self.num_partitions):
+                        start = scale_size * i
+                        scale_partitions.append(flat_scale_tensor.narrow(0, start, scale_size))
+                        if i == self.get_partition_rank():
+                            scale_partitions[i].copy_(param.ds_tensor.ds_quant_scale.data)
+
+                dist.all_gather_into_tensor(flat_tensor,
+                                            partitions[self.get_partition_rank()],
+                                            group=self.get_partition_dp_group(param),
+                                            async_op=False)
+
+                if hasattr(param, 'ds_quant_scale'):
+                    dist.all_gather(flat_scale_tensor,
+                                    param.ds_tensor.ds_quant_scale,
+                                    group=self.get_partition_dp_group(param),
+                                    async_op=False)
+
+                param.data = flat_tensor.narrow(0, 0, param.ds_numel).view(param.ds_shape).data
+
+                if hasattr(param, 'ds_quant_scale'):
+                    param.data = self.quantizer_module.dequantize(param.data, flat_scale_tensor)
+        else:
+            partition_size = sum([param.ds_tensor.ds_numel for param in param_list])
+
+            tensor_size = partition_size * self.num_partitions
+            flat_tensor = torch.empty(tensor_size, dtype=param_list[0].ds_tensor.dtype, device=self.local_device)
+            partitions = []
+            for i in range(self.num_partitions):
+                start = partition_size * i
+
+                partitions.append(flat_tensor.narrow(0, start, partition_size))
+
+                if i == self.get_partition_rank():
                     offset = 0
                     for param in param_list:
-                        param_scale_numel = param.ds_tensor.ds_quant_scale.ds_numel
+                        param_numel = param.ds_tensor.ds_numel

-                        scale_partitions[i].narrow(0, offset,
-                                                   param_scale_numel).copy_(param.ds_tensor.ds_quant_scale.data)
+                        partitions[i].narrow(0, offset, param_numel).copy_(param.ds_tensor.data)

-                        offset += param_scale_numel
+                        offset += param_numel

-        dist.all_gather_into_tensor(flat_tensor,
-                                    partitions[self.get_partition_rank()],
-                                    group=self.get_partition_dp_group(param),
-                                    async_op=False)
-        if hasattr(param_list[0], 'ds_quant_scale'):
-            dist.all_gather(flat_scale_tensor,
-                            param_list[0].ds_quant_scale,
-                            group=self.get_partition_dp_group(param),
-                            async_op=False)
-        param_offset = 0
+            if hasattr(param_list[0], 'ds_quant_scale'):
+                scale_size = sum([param.ds_tensor.ds_quant_scale.numel() for param in param_list])
+                scale_tensor_size = scale_size * self.world_size
+                flat_scale_tensor = torch.empty(scale_tensor_size,
+                                                dtype=param_list[0].ds_tensor.ds_quant_scale.dtype,
+                                                device=self.local_device)
+                scale_partitions = []
+                for i in range(self.world_size):
+                    start = scale_tensor_size * i
+                    scale_partitions.append(flat_scale_tensor.narrow(0, start, scale_tensor_size))
+                    if i == self.rank:
+                        offset = 0
+                        for param in param_list:
+                            param_scale_numel = param.ds_tensor.ds_quant_scale.ds_numel
+
+                            scale_partitions[i].narrow(0, offset,
+                                                       param_scale_numel).copy_(param.ds_tensor.ds_quant_scale.data)
+
+                            offset += param_scale_numel
+
+            dist.all_gather_into_tensor(flat_tensor,
+                                        partitions[self.get_partition_rank()],
+                                        group=self.get_partition_dp_group(param),
+                                        async_op=False)
+            if hasattr(param_list[0], 'ds_quant_scale'):
+                dist.all_gather(flat_scale_tensor,
+                                param_list[0].ds_quant_scale,
+                                group=self.get_partition_dp_group(param),
+                                async_op=False)
+            param_offset = 0

-        for param in param_list:
-            param_partition_size = param.ds_tensor.ds_numel
-            param_size = param.ds_numel
-            replicated_tensor = torch.empty(param.ds_shape, dtype=param.ds_tensor.dtype, device=self.local_device)
+            for param in param_list:
+                param_partition_size = param.ds_tensor.ds_numel
+                param_size = param.ds_numel
+                replicated_tensor = torch.empty(param.ds_shape, dtype=param.ds_tensor.dtype, device=self.local_device)

-            for i in range(self.num_partitions):
+                for i in range(self.num_partitions):

-                start = i * partition_size
+                    start = i * partition_size

-                param_start = i * param_partition_size
+                    param_start = i * param_partition_size

-                if param_start < param_size:
-                    numel_to_copy = min(param_size - param_start, param_partition_size)
+                    if param_start < param_size:
+                        numel_to_copy = min(param_size - param_start, param_partition_size)

-                    part_to_copy = partitions[i].narrow(0, param_offset, numel_to_copy)
+                        part_to_copy = partitions[i].narrow(0, param_offset, numel_to_copy)

-                    replicated_tensor.view(-1).narrow(0, param_start, numel_to_copy).copy_(part_to_copy)
-            #param_offset += param.data.numel()
-            param_offset += param.ds_tensor.ds_numel
-            if hasattr(param_list[0], 'ds_quant_scale'):
-                replicated_tensor = self.quantizer_module.dequantize(replicated_tensor, flat_scale_tensor)
-            param.data = replicated_tensor.data
+                        replicated_tensor.view(-1).narrow(0, param_start, numel_to_copy).copy_(part_to_copy)
+                #param_offset += param.data.numel()
+                param_offset += param.ds_tensor.ds_numel
+                if hasattr(param_list[0], 'ds_quant_scale'):
+                    replicated_tensor = self.quantizer_module.dequantize(replicated_tensor, flat_scale_tensor)
+                param.data = replicated_tensor.data

         return None
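To make the new single-parameter path concrete, here is a minimal standalone sketch (not DeepSpeed code) of its gather-and-trim bookkeeping, assuming a two-rank gloo group; the helper name run, the port, and the tensor sizes are all illustrative. Each rank holds one equal-sized, padded shard; the shards are gathered into one flat buffer; and the buffer is trimmed to the true element count and reshaped, mirroring param.data = flat_tensor.narrow(0, 0, param.ds_numel).view(param.ds_shape).

    import os

    import torch
    import torch.distributed as dist
    import torch.multiprocessing as mp


    def run(rank, world_size):
        os.environ["MASTER_ADDR"] = "127.0.0.1"
        os.environ["MASTER_PORT"] = "29500"
        dist.init_process_group("gloo", rank=rank, world_size=world_size)

        ds_shape = (3, 3)  # the parameter's true shape (like ds_shape): 9 elements
        ds_numel = 9       # its true element count (like ds_numel)
        # Shards are padded so every rank holds the same number of elements.
        partition_size = (ds_numel + world_size - 1) // world_size  # 5 per rank
        shard = torch.full((partition_size,), float(rank))

        # Gather all shards into one flat buffer, as the single-param path does.
        # Assumes a PyTorch version whose gloo backend implements
        # all_gather_into_tensor; dist.all_gather over flat.chunk(world_size)
        # is the list-based equivalent on older versions.
        flat = torch.empty(partition_size * world_size)
        dist.all_gather_into_tensor(flat, shard)

        # Trim the padding and view as the original shape, mirroring
        # param.data = flat_tensor.narrow(0, 0, param.ds_numel).view(param.ds_shape)
        param = flat.narrow(0, 0, ds_numel).view(ds_shape)
        if rank == 0:
            print(param)
        dist.destroy_process_group()


    if __name__ == "__main__":
        mp.spawn(run, args=(2,), nprocs=2)

The trim step matters because shards are padded so every rank contributes the same number of elements; the gathered buffer can therefore be longer than the parameter itself.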
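The unchanged batched path (the else branch) packs every parameter's shard into one flat partition per rank at a running offset, then slices each parameter back out of the gathered partitions. Below is a pure-tensor sketch of that offset arithmetic, with no collectives and with all names and sizes invented for illustration.

    import torch

    # Two toy "parameters" on a two-rank group.
    world_size = 2
    shapes = [(3, 2), (5,)]  # each param's shape (like ds_shape)
    numels = [6, 5]          # each param's element count (like ds_numel)
    shard_numels = [(n + world_size - 1) // world_size for n in numels]  # padded shards: [3, 3]
    partition_size = sum(shard_numels)  # elements held per rank: 6

    params = [torch.arange(float(n)) for n in numels]

    # Pack: each rank's flat partition holds one shard of every parameter at a
    # running offset (the offset loop in the else branch).
    partitions = []
    for rank in range(world_size):
        part = torch.zeros(partition_size)
        offset = 0
        for p, shard_numel in zip(params, shard_numels):
            shard = p.view(-1)[rank * shard_numel:(rank + 1) * shard_numel]
            part.narrow(0, offset, shard.numel()).copy_(shard)  # a short final shard leaves zero padding
            offset += shard_numel
        partitions.append(part)

    # After the all-gather every rank holds all partitions; unpack each parameter
    # by walking the partitions at its own param_offset (the reconstruction loop).
    param_offset = 0
    for p, shape, numel, shard_numel in zip(params, shapes, numels, shard_numels):
        replicated = torch.empty(numel)
        for i in range(world_size):
            param_start = i * shard_numel
            if param_start < numel:
                numel_to_copy = min(numel - param_start, shard_numel)
                part_to_copy = partitions[i].narrow(0, param_offset, numel_to_copy)
                replicated.narrow(0, param_start, numel_to_copy).copy_(part_to_copy)
        param_offset += shard_numel  # advance by the shard size, as in the diff
        assert torch.equal(replicated.view(shape), p.view(shape))
    print("reconstruction matches")

Note that param_offset advances by the parameter's shard size, not its full size, because each rank's partition stores only a shard of every parameter; this is presumably why the param.data.numel() variant remains commented out in the source.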