File tree Expand file tree Collapse file tree 2 files changed +12
-5
lines changed Expand file tree Collapse file tree 2 files changed +12
-5
lines changed Original file line number Diff line number Diff line change @@ -2710,9 +2710,9 @@ def _get_current_num_elm(
27102710 def _quantize_block (
27112711 self ,
27122712 block : torch .nn .Module ,
2713- input_ids : list [torch .Tensor ],
2713+ input_ids : Union [ list [torch .Tensor ], dict ],
27142714 input_others : dict ,
2715- q_input : Union [None , torch .Tensor ] = None ,
2715+ q_input : Union [torch .Tensor , dict , None ] = None ,
27162716 device : Union [str , torch .device ] = "cpu" ,
27172717 ):
27182718 """Quantize the weights of a given block of the model.
@@ -2825,7 +2825,11 @@ def _quantize_block(
28252825 else :
28262826 lr_schedule = copy .deepcopy (self .lr_scheduler )
28272827
2828- nsamples = len (input_ids )
2828+ if isinstance (input_ids , dict ): # input_ids of Flux is dict
2829+ nsamples = len (input_ids ["hidden_states" ])
2830+ else :
2831+ nsamples = len (input_ids )
2832+
28292833 pick_samples = self .batch_size * self .gradient_accumulate_steps
28302834 pick_samples = min (nsamples , pick_samples )
28312835 if self .sampler != "rand" :
Original file line number Diff line number Diff line change @@ -210,7 +210,7 @@ def _get_current_q_output(
210210 def _get_block_outputs (
211211 self ,
212212 block : torch .nn .Module ,
213- input_ids : torch .Tensor ,
213+ input_ids : Union [ torch .Tensor , dict ] ,
214214 input_others : torch .Tensor ,
215215 bs : int ,
216216 device : Union [str , torch .device ],
@@ -233,8 +233,11 @@ def _get_block_outputs(
233233 """
234234
235235 output = defaultdict (list )
236- nsamples = len (input_ids )
237236 output_config = output_configs .get (block .__class__ .__name__ , [])
237+ if isinstance (input_ids , dict ):
238+ nsamples = len (input_ids ["hidden_states" ])
239+ else :
240+ nsamples = len (input_ids )
238241
239242 for i in range (0 , nsamples , bs ):
240243 end_index = min (nsamples , i + bs )
You can’t perform that action at this time.
0 commit comments