def process_weight_transpose(layer, weight_name):
    """Replace ``layer.<weight_name>`` with a transposed copy of itself.

    A fresh parameter with the transposed shape is created, the original
    weight's data is copied into it (unless weights will be (re)loaded
    later), and the original tensor's storage is freed.

    Args:
        layer: the layer owning the weight; must provide ``create_parameter``
            and an ``fd_config`` attribute (load/model config flags).
        weight_name: attribute name of the weight on ``layer``.

    Raises:
        ValueError: if the weight is neither 2-D nor 3-D.
    """
    weight = getattr(layer, weight_name)
    if len(weight.shape) == 2:
        # Plain matrix: swap the two dims.
        weight_shape = weight.shape[::-1]
    elif len(weight.shape) == 3:
        # Batched weight (leading dim presumably expert/batch — keep it,
        # swap the trailing two dims).
        weight_shape = [weight.shape[0]] + list(weight.shape[1:][::-1])
    else:
        # Previously an unsupported rank fell through to a confusing
        # NameError on weight_shape; fail fast with a clear message.
        raise ValueError(f"process_weight_transpose expects a 2-D or 3-D weight, got shape {weight.shape}")
    weight_tmp = layer.create_parameter(
        shape=weight_shape,
        dtype=weight.dtype,
        default_initializer=paddle.nn.initializer.Constant(0),
        is_bias=False,
    )
    if layer.fd_config.load_config.dynamic_load_weight or layer.fd_config.model_config.enable_cache:
        # Weights will be loaded/refreshed later: install the empty
        # transposed placeholder and skip the (wasted) data copy. Memory is
        # not aggressively reclaimed here (clear_memory=False) since more
        # allocations follow immediately.
        free_tensor(weight, clear_memory=False)
        setattr(layer, weight_name, weight_tmp)
        return

    if len(weight.shape) == 2:
        weight_transpose = weight.transpose([1, 0])
    else:  # rank already validated above, so this is the 3-D case
        weight_transpose = weight.transpose([0, 2, 1])
    weight_tmp.copy_(weight_transpose, False)
    free_tensor(weight, clear_memory=False)
    setattr(layer, weight_name, weight_tmp)
@@ -163,9 +171,16 @@ def fn(model_sublayer_name: str, param=None):
163171 model_sublayer = sublayers_dict [model_sublayer_name ]
164172 if isinstance (model_sublayer , KVBatchLinear ):
165173 model_sublayer .process_weights_after_loading ()
174+ if fd_config .quant_config and not fd_config .quant_config .is_checkpoint_bf16 :
175+ # skip for offline quantization
176+ return
166177 if hasattr (model_sublayer , "quant_method" ):
167178 quant_method = getattr (model_sublayer , "quant_method" , None )
168- unquant_moe_cls = type (get_moe_method ())
179+ unquant_moe_layer = get_moe_method ()
180+ if unquant_moe_layer is None :
181+ unquant_moe_cls = object
182+ else :
183+ unquant_moe_cls = type (unquant_moe_layer )
169184 if type (quant_method ) is UnquantizedLinearMethod or type (quant_method ) is unquant_moe_cls :
170185 # skip unquantized linear
171186 return
def process_final_after_loading(model, fd_config: FDConfig):
    """Run per-sublayer finalization hooks after all weights are loaded.

    For each sublayer: quantized layers get their quant method's
    ``process_weights_after_loading`` (only when the layer is effectively
    unquantized or the checkpoint is offline-quantized); otherwise the
    sublayer's own ``process_weights_after_loading`` is invoked if present
    (e.g. lmhead-style layers).

    Args:
        model: the loaded model whose sublayers are finalized in place.
        fd_config: deployment config; ``quant_config.is_checkpoint_bf16``
            distinguishes bf16 checkpoints from offline-quantized ones.
    """
    from fastdeploy.model_executor.layers.moe.moe import get_moe_method

    # Hoisted loop invariants: the unquantized-MoE method class and the
    # checkpoint kind do not depend on the sublayer being visited, so
    # compute them once instead of per sublayer.
    unquant_moe_layer = get_moe_method()
    # get_moe_method() may legitimately return None (no MoE backend);
    # `object` then never matches any real quant_method type below.
    unquant_moe_cls = object if unquant_moe_layer is None else type(unquant_moe_layer)
    is_offline_quantized_ckpt = not (fd_config.quant_config and fd_config.quant_config.is_checkpoint_bf16)

    for name, sublayer in model.named_sublayers():
        if isinstance(sublayer, KVBatchLinear):
            # KVBatchLinear finalizes its weights through its own path; skip.
            continue
        quant_method = getattr(sublayer, "quant_method", None)
        if quant_method is not None:
            is_unquant_cls = type(quant_method) is UnquantizedLinearMethod or type(quant_method) is unquant_moe_cls
            if is_unquant_cls or is_offline_quantized_ckpt:
                if hasattr(quant_method, "process_weights_after_loading"):
                    quant_method.process_weights_after_loading(sublayer)
                # Done with this sublayer; do not also run its own hook.
                continue
            # NOTE: a quantized layer that matches neither condition falls
            # through to the sublayer-level hook below (original behavior).
        if not hasattr(sublayer, "process_weights_after_loading"):
            continue
        # Only for specific layers, such as lmhead.
        sublayer.process_weights_after_loading()
def free_tensor(tensor, clear_memory=True):
    """Immediately release the storage backing ``tensor``.

    Args:
        tensor: a paddle tensor/parameter whose underlying allocation is
            cleared in place.
        clear_memory: when True, also return freed blocks to the device by
            emptying the CUDA cache (callers doing many frees in a row pass
            False to avoid repeated expensive cache flushes).
    """
    # Detach the tracking hook first (presumably a weight-loading tracker —
    # TODO confirm) so nothing observes the tensor after its storage is gone.
    if hasattr(tensor, "tensor_track"):
        tensor.tensor_track = None
    # Clear the underlying dense storage without destroying the Python object.
    tensor.value().get_tensor()._clear()
    # Drop this function's local reference; the caller is expected to drop
    # or replace its own reference (see process_weight_transpose).
    del tensor
    if clear_memory:
        paddle.device.cuda.empty_cache()
248270
249271
250272def fd_cast (weight , param ):
0 commit comments