
Commit 4e1b81f

update gguf convert file and fix permute bug (#679)
1 parent 1ead9fe commit 4e1b81f

File tree

6 files changed: +1219 −382 lines


auto_round/autoround.py

Lines changed: 33 additions & 30 deletions
@@ -828,10 +828,10 @@ def get_act_max_hook(module, input, output):

            if not hasattr(module, "imatrix"):
                module.imatrix = squared
-                module.imatrix_cnt = 1
+                module.imatrix_cnt = input.shape[0]
            else:
                module.imatrix += squared
-                module.imatrix_cnt += 1
+                module.imatrix_cnt += input.shape[0]

        hook_handles = []
        for name, module in model.named_modules():
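
With this change, imatrix_cnt counts accumulated activation rows rather than hook invocations, so a per-row average of the squared activations stays correct when calibration batch sizes vary. A tiny illustrative computation (toy values, not part of the diff):

import torch

imatrix = torch.tensor([8.0, 2.0])    # running sum of squared activations over 4 rows
imatrix_cnt = 4                       # rows accumulated across hook calls of mixed batch size
mean_imatrix = imatrix / imatrix_cnt  # tensor([2.0, 0.5]), independent of how the rows were batched
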
@@ -885,33 +885,36 @@ def get_act_max_hook(module, input, output):
            cnt = 1
            cnt += 1
        except RuntimeError as e:
-            try:
-                if hasattr(model, "hf_device_map") and len(model.hf_device_map) > 1:
-                    import accelerate
-                    accelerate.hooks.remove_hook_from_submodules(model)
-                # Fallback: out-of-memory → try CPU blockwise quantization
-                logger.warning("Out of VRAM, falling back to blockwise quantization. Accuracy may degrade.")
-                model = model.to("cpu")
-                clear_memory()
-                self.quantize_via_rtn_blockwise(all_to_quantized_module_names)
-            except Exception:
-                # Final fallback: warn and use CPU-only quantization
-                logger.warning("Fallback to CPU. "
-                               "Consider enabling `low_gpu_mem_usage` or using more GPUs via `--device 0,1,2,3`.")
-                model = model.to("cpu")
-                clear_memory()
-                if hasattr(model, "hf_device_map") and len(model.hf_device_map) > 1:
-                    import accelerate
-                    accelerate.hooks.remove_hook_from_submodules(model)
-
-                orig_device = self.device
-                self.device = "cpu"
-                self.quantize_via_rtn_blockwise(all_to_quantized_module_names)
-                self.device = orig_device
-            finally:
-                # Always remove hooks
-                for hook in hooks:
-                    hook.remove()
+            if "CUDA out of memory" in str(e) or "MODULE:PT_DEVMEM" in str(e):
+                try:
+                    if hasattr(model, "hf_device_map") and len(model.hf_device_map) > 1:
+                        import accelerate
+                        accelerate.hooks.remove_hook_from_submodules(model)
+                    # Fallback: out-of-memory → try CPU blockwise quantization
+                    logger.warning("Out of VRAM, falling back to blockwise quantization. Accuracy may degrade.")
+                    model = model.to("cpu")
+                    clear_memory()
+                    self.quantize_via_rtn_blockwise(all_to_quantized_module_names)
+                except Exception:
+                    # Final fallback: warn and use CPU-only quantization
+                    logger.warning("Fallback to CPU. "
+                                   "Consider enabling `low_gpu_mem_usage` or using more GPUs via `--device 0,1,2,3`.")
+                    model = model.to("cpu")
+                    clear_memory()
+                    if hasattr(model, "hf_device_map") and len(model.hf_device_map) > 1:
+                        import accelerate
+                        accelerate.hooks.remove_hook_from_submodules(model)
+
+                    orig_device = self.device
+                    self.device = "cpu"
+                    self.quantize_via_rtn_blockwise(all_to_quantized_module_names)
+                    self.device = orig_device
+                finally:
+                    # Always remove hooks
+                    for hook in hooks:
+                        hook.remove()
+            else:
+                raise

        # Move back to CPU and free memory
        model.to("cpu")
@@ -1029,7 +1032,7 @@ def quantize_layer_via_rtn(self, name: str) -> None:

        # Step 2: Try quantization on GPU first, fall back to CPU if OOM
        # if only export gguf, using gguf-packing instead of rtn
-        if self.is_packing_immediate and self.iters == 0 and "gguf" in self.formats[0] and self.disable_opt_rtn:
+        if self.is_packing_immediate and self.iters == 0 and "gguf" in self.formats[0] and not self.disable_opt_rtn:
            m.scale = None
            m.zp = None
        else:

auto_round/data_type/gguf.py

Lines changed: 15 additions & 14 deletions
@@ -15,6 +15,7 @@
import torch
from auto_round.data_type.utils import round_ste, reshape_pad_tensor_by_group_size, revert_tensor_by_pad, logger
from auto_round.data_type.register import register_dtype
+from auto_round.utils import get_reciprocal


@register_dtype("int_sym_dq")
@@ -69,7 +70,7 @@ def quant_tensor_sym_dq(

    scale = scale.view(-1, 1)
    zp = torch.full_like(scale, maxq)  # pylint: disable=E1130
-    int_w = torch.where(scale != 0, round_ste(tensor / scale + v), 0)
+    int_w = round_ste(tensor * get_reciprocal(scale) + v)
    q = torch.clamp(int_w + zp, 0, 2 ** bits - 1)
    qdq_result = (scale * (q - zp)).to(tensor.dtype)
    qdq_result = revert_tensor_by_pad(qdq_result, orig_shape=orig_shape, pad_len=pad_len)
@@ -140,7 +141,7 @@ def double_quant_tensor(tensor, bits):
    scale = wmax / maxq
    scale = scale.view(-1, 1)
    # inverse_scale = torch.where(scale == 0, 0, 1 / scale)
-    inverse_scale = torch.where(wmax > 0, maxq / wmax, 0).view(-1, 1)
+    inverse_scale = (maxq * get_reciprocal(wmax)).clamp(min=0).view(-1, 1)
    qdq_tensor = torch.clamp(round_ste(tensor * inverse_scale), max=maxq) * scale
    return qdq_tensor, scale

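
The new expression preserves the old wmax > 0 guard: zeros map to zero through get_reciprocal, and negative values are zeroed by clamp(min=0). A quick illustrative check with toy values, expanding the assumed get_reciprocal behavior inline (not part of the diff):

import torch

wmax = torch.tensor([2.0, 0.0, -1.0])
maxq = 15
old = torch.where(wmax > 0, maxq / wmax, 0)
new = (maxq * torch.where(wmax == 0, torch.zeros_like(wmax), 1.0 / wmax)).clamp(min=0)
assert torch.equal(old, new)  # both give tensor([7.5, 0.0, 0.0])
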
@@ -151,7 +152,7 @@ def double_quant_tensor_sym(tensor, bits):
    imax = abs(tensor).argmax(axis=-1, keepdims=True)
    wmax = torch.take_along_dim(tensor, imax, dim=-1)
    scale = wmax / -maxq
-    inverse_scale = torch.where(scale == 0, 0, 1 / scale)  ##1e-40
+    inverse_scale = get_reciprocal(scale)
    qdq_tensor = torch.clip((round_ste(tensor * inverse_scale)), -maxq, maxq - 1) * scale
    return qdq_tensor, scale

@@ -161,7 +162,7 @@ def make_qp_quants(nmax, data, quant_weights):
    quant_weights = quant_weights.to(torch.float32)
    group_max = torch.max(data, dim=-1, keepdim=True)[0]
    scale = group_max / nmax
-    iscale = torch.where(scale == 0, 0, 1 / scale)
+    iscale = get_reciprocal(scale)

    L = torch.round(iscale * data)
    diffs = data - scale * L
@@ -171,7 +172,7 @@ def make_qp_quants(nmax, data, quant_weights):
        if _is == 0:
            continue
        scale_is = group_max / (0.1 * _is + nmax)
-        iscale_is = torch.where(scale_is == 0, 0, 1 / scale_is)
+        iscale_is = get_reciprocal(scale_is)

        tmp_L = torch.round(iscale_is * data).clip(max=nmax)
        diffs = data - scale_is * tmp_L
@@ -328,11 +329,11 @@ def quant_tensor_gguf_asym_dq(
            use_mad=params["use_mad"], weights=quant_weights
        )
        scale = scale.to(scale_dtype)
-        scale = torch.where(torch.abs(scale) < 1e-30, 0, scale)
+        scale = torch.where(torch.abs(scale) < 1e-30, torch.zeros_like(scale), scale)
        scale = scale.reshape(-1, super_group_size)
        wmin = wmin.reshape(-1, super_group_size)
        scale, d_scale = double_quant_tensor(scale, super_bits)
-        wmin = torch.where(torch.abs(wmin) < 1e-30, 0, wmin)
+        wmin = torch.where(torch.abs(wmin) < 1e-30, torch.zeros_like(wmin), wmin)
        wmin, d_wmin = double_quant_tensor(wmin, super_bits)
        wmin = wmin.view(-1, 1)
        scale = scale.view(-1, 1)
@@ -386,7 +387,7 @@ def quant_tensor_gguf_asym_dq(
            use_mad=params["use_mad"], weights=quant_weights
        )
        scale = scale.to(scale_dtype)
-        scale = torch.where(torch.abs(scale) < 1e-30, 0, scale)
+        scale = torch.where(torch.abs(scale) < 1e-30, torch.zeros_like(scale), scale)
        nmax = 2 ** super_bits - 1
        scale = scale.reshape(-1, super_group_size)
        wmin = wmin_0.reshape(-1, super_group_size)
@@ -399,7 +400,7 @@ def quant_tensor_gguf_asym_dq(
        d_wmin = d_wmin.unsqueeze(-1)
        scale = (d_scale * q_scale).view(-1, 1)
        wmin = (d_wmin * q_wmin).view(-1, 1)
-        inverse_scale = torch.where(scale == 0, 0, 1 / scale)
+        inverse_scale = get_reciprocal(scale)

        int_w = torch.clamp(round_ste((tensor + wmin) * inverse_scale + v), 0, maxq)
        qdq_result = (scale * int_w - wmin).to(orig_dtype)
@@ -436,7 +437,7 @@ def iterative_wls_quant_search(data, bits=4, rrmin=-1.0, rdelta=0.1, nstep=20, u

    # scale = 1 / ((maxq - minq) / (rmax - rmin + 1e-8))
    scale = (rmax - rmin) / (maxq - minq)
-    iscale = torch.where(scale == 0, 0, 1 / scale)
+    iscale = get_reciprocal(scale)
    # quant_data = torch.clamp(torch.round((maxq - minq) / (rmax - rmin + 1e-8) * (data - rmin)), minq, maxq)
    quant_data = torch.clamp(torch.round(iscale * (data - rmin)), minq, maxq)
    diff = scale * quant_data + rmin - data
@@ -447,7 +448,7 @@ def iterative_wls_quant_search(data, bits=4, rrmin=-1.0, rdelta=0.1, nstep=20, u
        factor = rrmin + rdelta * is_ + maxq - minq
        # iscale_new = factor / (rmax - rmin + 1e-8)
        scale_new = (rmax - rmin) / factor
-        iscale_new = torch.where(scale_new == 0, 0, 1 / scale_new)
+        iscale_new = get_reciprocal(scale_new)
        quant_data_new = torch.clamp(torch.round(iscale_new * (data - rmin)), minq, maxq)

        mul_weights_quant_data = weights * quant_data_new
@@ -460,7 +461,7 @@ def iterative_wls_quant_search(data, bits=4, rrmin=-1.0, rdelta=0.1, nstep=20, u
        this_min = (sum_l2 * sum_x - sum_l * sum_xl) / D
        this_min[this_min > 0] = 0
        this_scale[this_min > 0] = (sum_xl / sum_l2)[this_min > 0]
-        reverse_this_scale = torch.where(this_scale == 0, 0, 1 / this_scale)
+        reverse_this_scale = get_reciprocal(this_scale)

        quant_data = torch.clamp(torch.round(reverse_this_scale * (data - this_min)), minq, maxq)
        diff = this_scale * quant_data + this_min - data
@@ -569,13 +570,13 @@ def quant_tensor_gguf_sym_dq(
        quant_weights[mean_replace_index, :] = tmp_quant_weights[mean_replace_index, :]

    scale, int_w = make_qx_quants(tensor, bits=bits, rmse_type=1, qw=quant_weights)
-    scale = torch.where(torch.abs(scale) < 1e-30, 0, scale)
+    scale = torch.where(torch.abs(scale) < 1e-30, torch.zeros_like(scale), scale)
    # conduct double quant
    scale, d_scale = double_quant_tensor_sym(scale, super_bits)

    scale = scale.unsqueeze(-1)
    zp = torch.full_like(scale, maxq)  # pylint: disable=E1130
-    inverse_scale = torch.where(scale == 0, 0, 1.0 / scale)
+    inverse_scale = get_reciprocal(scale)
    int_w = torch.round(tensor * inverse_scale).clip(-maxq, maxq - 1) + maxq
    qdq_result = (scale * (int_w - zp)).to(orig_dtype)
    qdq_result = revert_tensor_by_pad(qdq_result, orig_shape=orig_shape, pad_len=pad_len)
