Skip to content

Commit 2bb9f9b

Browse files
mengniwang95XuehaoSun
authored andcommitted
Fix Flux tuning issue (#936)
Signed-off-by: Mengni Wang <[email protected]>
1 parent 9d140cc commit 2bb9f9b

File tree

2 files changed

+12
-5
lines changed

2 files changed

+12
-5
lines changed

auto_round/compressors/base.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2710,9 +2710,9 @@ def _get_current_num_elm(
27102710
def _quantize_block(
27112711
self,
27122712
block: torch.nn.Module,
2713-
input_ids: list[torch.Tensor],
2713+
input_ids: Union[list[torch.Tensor], dict],
27142714
input_others: dict,
2715-
q_input: Union[None, torch.Tensor] = None,
2715+
q_input: Union[torch.Tensor, dict, None] = None,
27162716
device: Union[str, torch.device] = "cpu",
27172717
):
27182718
"""Quantize the weights of a given block of the model.
@@ -2825,7 +2825,11 @@ def _quantize_block(
28252825
else:
28262826
lr_schedule = copy.deepcopy(self.lr_scheduler)
28272827

2828-
nsamples = len(input_ids)
2828+
if isinstance(input_ids, dict): # input_ids of Flux is dict
2829+
nsamples = len(input_ids["hidden_states"])
2830+
else:
2831+
nsamples = len(input_ids)
2832+
28292833
pick_samples = self.batch_size * self.gradient_accumulate_steps
28302834
pick_samples = min(nsamples, pick_samples)
28312835
if self.sampler != "rand":

auto_round/compressors/diffusion/compressor.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -210,7 +210,7 @@ def _get_current_q_output(
210210
def _get_block_outputs(
211211
self,
212212
block: torch.nn.Module,
213-
input_ids: torch.Tensor,
213+
input_ids: Union[torch.Tensor, dict],
214214
input_others: torch.Tensor,
215215
bs: int,
216216
device: Union[str, torch.device],
@@ -233,8 +233,11 @@ def _get_block_outputs(
233233
"""
234234

235235
output = defaultdict(list)
236-
nsamples = len(input_ids)
237236
output_config = output_configs.get(block.__class__.__name__, [])
237+
if isinstance(input_ids, dict):
238+
nsamples = len(input_ids["hidden_states"])
239+
else:
240+
nsamples = len(input_ids)
238241

239242
for i in range(0, nsamples, bs):
240243
end_index = min(nsamples, i + bs)

0 commit comments

Comments
 (0)