@@ -1994,78 +1994,126 @@ def _allgather_params(self, param_list, hierarchy=0):
         if len(param_list) == 0:
             return

-        partition_size = sum([param.ds_tensor.ds_numel for param in param_list])
+        if self.allgather_single_param:
+            for param in param_list:
+                partition_size = param.ds_tensor.ds_numel
+                tensor_size = partition_size * self.num_partitions

-        tensor_size = partition_size * self.num_partitions
-        flat_tensor = torch.empty(tensor_size, dtype=param_list[0].ds_tensor.dtype, device=self.local_device)
-        partitions = []
-        for i in range(self.num_partitions):
-            start = partition_size * i
+                flat_tensor = torch.empty(tensor_size, dtype=param.ds_tensor.dtype, device=self.local_device)

-            partitions.append(flat_tensor.narrow(0, start, partition_size))
+                flat_tensor.requires_grad = False

-            if i == self.get_partition_rank():
-                offset = 0
-                for param in param_list:
-                    param_numel = param.ds_tensor.ds_numel
+                partitions = []
+                for i in range(self.num_partitions):
+                    start = partition_size * i
+                    partitions.append(flat_tensor.narrow(0, start, partition_size))

-                    partitions[i].narrow(0, offset, param_numel).copy_(param.ds_tensor.data)
+                    if i == self.get_partition_rank():
+                        partitions[i].copy_(param.ds_tensor.data)

-                    offset += param_numel
+                if hasattr(param, 'ds_quant_scale'):
+                    scale_size = param.ds_tensor.ds_quant_scale.numel()
+                    scale_tensor_size = scale_size * self.num_partitions
+                    flat_scale_tensor = torch.empty(scale_tensor_size,
+                                                    dtype=param.ds_tensor.ds_quant_scale.dtype,
+                                                    device=self.local_device)
+                    flat_scale_tensor.requires_grad = False

-        if hasattr(param_list[0], 'ds_quant_scale'):
-            scale_size = sum([param.ds_tensor.ds_quant_scale.numel() for param in param_list])
-            scale_tensor_size = scale_size * self.world_size
-            flat_scale_tensor = torch.empty(scale_tensor_size,
-                                            dtype=param_list[0].ds_tensor.ds_quant_scale.dtype,
-                                            device=self.local_device)
-            scale_partitions = []
-            for i in range(self.world_size):
-                start = scale_tensor_size * i
-                scale_partitions.append(flat_scale_tensor.narrow(0, start, scale_tensor_size))
-                if i == self.rank:
+                    scale_partitions = []
+                    for i in range(self.num_partitions):
+                        start = scale_size * i
+                        scale_partitions.append(flat_scale_tensor.narrow(0, start, scale_size))
+                        if i == self.get_partition_rank():
+                            scale_partitions[i].copy_(param.ds_tensor.ds_quant_scale.data)
+
+                dist.all_gather_into_tensor(flat_tensor,
+                                            partitions[self.get_partition_rank()],
+                                            group=self.get_partition_dp_group(param),
+                                            async_op=False)
+
+                if hasattr(param, 'ds_quant_scale'):
+                    dist.all_gather(flat_scale_tensor,
+                                    param.ds_tensor.ds_quant_scale,
+                                    group=self.get_partition_dp_group(param),
+                                    async_op=False)
+
+                param.data = flat_tensor.narrow(0, 0, param.ds_numel).view(param.ds_shape).data
+
+                if hasattr(param, 'ds_quant_scale'):
+                    param.data = self.quantizer_module.dequantize(param.data, flat_scale_tensor)
+        else:
+            partition_size = sum([param.ds_tensor.ds_numel for param in param_list])
+
+            tensor_size = partition_size * self.num_partitions
+            flat_tensor = torch.empty(tensor_size, dtype=param_list[0].ds_tensor.dtype, device=self.local_device)
+            partitions = []
+            for i in range(self.num_partitions):
+                start = partition_size * i
+
+                partitions.append(flat_tensor.narrow(0, start, partition_size))
+
+                if i == self.get_partition_rank():
                     offset = 0
                     for param in param_list:
-                        param_scale_numel = param.ds_tensor.ds_quant_scale.ds_numel
+                        param_numel = param.ds_tensor.ds_numel

-                        scale_partitions[i].narrow(0, offset,
-                                                   param_scale_numel).copy_(param.ds_tensor.ds_quant_scale.data)
+                        partitions[i].narrow(0, offset, param_numel).copy_(param.ds_tensor.data)

-                        offset += param_scale_numel
+                        offset += param_numel

-        dist.all_gather_into_tensor(flat_tensor,
-                                    partitions[self.get_partition_rank()],
-                                    group=self.get_partition_dp_group(param),
-                                    async_op=False)
-        if hasattr(param_list[0], 'ds_quant_scale'):
-            dist.all_gather(flat_scale_tensor,
-                            param_list[0].ds_quant_scale,
-                            group=self.get_partition_dp_group(param),
-                            async_op=False)
-        param_offset = 0
+            if hasattr(param_list[0], 'ds_quant_scale'):
+                scale_size = sum([param.ds_tensor.ds_quant_scale.numel() for param in param_list])
+                scale_tensor_size = scale_size * self.world_size
+                flat_scale_tensor = torch.empty(scale_tensor_size,
+                                                dtype=param_list[0].ds_tensor.ds_quant_scale.dtype,
+                                                device=self.local_device)
+                scale_partitions = []
+                for i in range(self.world_size):
+                    start = scale_tensor_size * i
+                    scale_partitions.append(flat_scale_tensor.narrow(0, start, scale_tensor_size))
+                    if i == self.rank:
+                        offset = 0
+                        for param in param_list:
+                            param_scale_numel = param.ds_tensor.ds_quant_scale.ds_numel
+
+                            scale_partitions[i].narrow(0, offset,
+                                                       param_scale_numel).copy_(param.ds_tensor.ds_quant_scale.data)
+
+                            offset += param_scale_numel
+
+            dist.all_gather_into_tensor(flat_tensor,
+                                        partitions[self.get_partition_rank()],
+                                        group=self.get_partition_dp_group(param),
+                                        async_op=False)
+            if hasattr(param_list[0], 'ds_quant_scale'):
+                dist.all_gather(flat_scale_tensor,
+                                param_list[0].ds_quant_scale,
+                                group=self.get_partition_dp_group(param),
+                                async_op=False)
+            param_offset = 0

-        for param in param_list:
-            param_partition_size = param.ds_tensor.ds_numel
-            param_size = param.ds_numel
-            replicated_tensor = torch.empty(param.ds_shape, dtype=param.ds_tensor.dtype, device=self.local_device)
+            for param in param_list:
+                param_partition_size = param.ds_tensor.ds_numel
+                param_size = param.ds_numel
+                replicated_tensor = torch.empty(param.ds_shape, dtype=param.ds_tensor.dtype, device=self.local_device)

-            for i in range(self.num_partitions):
+                for i in range(self.num_partitions):

-                start = i * partition_size
+                    start = i * partition_size

-                param_start = i * param_partition_size
+                    param_start = i * param_partition_size

-                if param_start < param_size:
-                    numel_to_copy = min(param_size - param_start, param_partition_size)
+                    if param_start < param_size:
+                        numel_to_copy = min(param_size - param_start, param_partition_size)

-                    part_to_copy = partitions[i].narrow(0, param_offset, numel_to_copy)
+                        part_to_copy = partitions[i].narrow(0, param_offset, numel_to_copy)

-                    replicated_tensor.view(-1).narrow(0, param_start, numel_to_copy).copy_(part_to_copy)
-            #param_offset += param.data.numel()
-            param_offset += param.ds_tensor.ds_numel
-            if hasattr(param_list[0], 'ds_quant_scale'):
-                replicated_tensor = self.quantizer_module.dequantize(replicated_tensor, flat_scale_tensor)
-            param.data = replicated_tensor.data
+                        replicated_tensor.view(-1).narrow(0, param_start, numel_to_copy).copy_(part_to_copy)
+                #param_offset += param.data.numel()
+                param_offset += param.ds_tensor.ds_numel
+                if hasattr(param_list[0], 'ds_quant_scale'):
+                    replicated_tensor = self.quantizer_module.dequantize(replicated_tensor, flat_scale_tensor)
+                param.data = replicated_tensor.data

         return None

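For reference, below is a minimal single-process sketch (not DeepSpeed code; the gloo group, shapes, and variable names are illustrative assumptions) of the gather pattern both branches above rely on: each rank copies its shard into its own slice of a flat buffer, an all-gather fills the remaining slices from the owning ranks, and the full parameter is recovered by viewing the first ds_numel elements as ds_shape. The patch itself calls dist.all_gather_into_tensor on the partition group; the list-based dist.all_gather used here is the equivalent collective that also runs on the gloo backend.

# Illustrative sketch only (not DeepSpeed code): a single-process "gloo" group
# stands in for self.get_partition_dp_group(param).
import os
import torch
import torch.distributed as dist

os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29501")
dist.init_process_group("gloo", rank=0, world_size=1)

num_partitions = dist.get_world_size()  # stands in for self.num_partitions
rank = dist.get_rank()                  # stands in for self.get_partition_rank()

param_shape = (3, 2)                    # hypothetical param.ds_shape
param_numel = 6                         # hypothetical param.ds_numel
partition_size = (param_numel + num_partitions - 1) // num_partitions  # per-rank shard size (ds_tensor.ds_numel)

# Flat buffer holding one partition per rank, carved into per-rank views.
flat_tensor = torch.empty(partition_size * num_partitions, dtype=torch.float32)
partitions = [flat_tensor.narrow(0, i * partition_size, partition_size) for i in range(num_partitions)]

# Each rank owns one shard; copy it into this rank's slice of the flat buffer.
local_shard = torch.arange(partition_size, dtype=torch.float32) + rank * partition_size
partitions[rank].copy_(local_shard)

# Fill every slice from its owning rank (the patch uses the fused
# dist.all_gather_into_tensor; the list form is its gloo-friendly equivalent).
dist.all_gather(partitions, local_shard)

# Rebuild the full parameter: first ds_numel elements viewed as ds_shape.
full_param = flat_tensor.narrow(0, 0, param_numel).view(param_shape)
print(full_param)

dist.destroy_process_group()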