From a34493d9d9c27631c13d926fd92ed362f5678439 Mon Sep 17 00:00:00 2001 From: rnwang04 Date: Wed, 11 Dec 2024 23:51:24 +0800 Subject: [PATCH 1/6] temp save --- .../transformers/npu_models/convert.py | 15 +++- .../transformers/npu_models/quantize.py | 86 +++++++++++++++++++ 2 files changed, 99 insertions(+), 2 deletions(-) create mode 100644 python/llm/src/ipex_llm/transformers/npu_models/quantize.py diff --git a/python/llm/src/ipex_llm/transformers/npu_models/convert.py b/python/llm/src/ipex_llm/transformers/npu_models/convert.py index f6589efa411..065b48bd349 100644 --- a/python/llm/src/ipex_llm/transformers/npu_models/convert.py +++ b/python/llm/src/ipex_llm/transformers/npu_models/convert.py @@ -97,12 +97,23 @@ def replace_with_QuantizedLinear(layer, qtype, device, modules_to_not_convert, if (layer.in_features == 18944 and layer.out_features == 3584): qtype = "sym_int8_rtn" iqtype = ggml_tensor_qtype[qtype] - enable_scale_search = (os.environ.get("IPEX_LLM_NPU_QUANTIZATION_OPT", "0") != "0" or - os.environ.get("IPEX_LLM_NPU_QUANTIZATION_HQQ", "0") != "0") + if qtype == "asym_int4_rtn": + enable_scale_search = (os.environ.get("IPEX_LLM_NPU_QUANTIZATION_OPT", "0") != "0" or + os.environ.get("IPEX_LLM_NPU_QUANTIZATION_HQQ", "0") != "0") + elif qtype == "sym_int4_rtn": + enable_scale_search = os.environ.get("IPEX_LLM_NPU_QUANTIZATION_OPT", "0") != "0" + else: + enable_scale_search = False qweights, scale = ggml_convert_qtype(layer.weight.data.to(torch.float32), iqtype, device=device, enable_scale_search=enable_scale_search, imatrix=imatrix) + if qtype == "sym_int4_rtn" and os.environ.get("IPEX_LLM_NPU_QUANTIZATION_HQQ", "0") != "0": + from .quantize import scale_grid_search + # scale grid search + qweights, scale = scale_grid_search(layer.weight.data.to(torch.float32), + scale.to(torch.float32), + qweights) zero = None # split scale to scale & zero if qtype == "asym_int4_rtn": diff --git a/python/llm/src/ipex_llm/transformers/npu_models/quantize.py b/python/llm/src/ipex_llm/transformers/npu_models/quantize.py new file mode 100644 index 00000000000..303bdcc0eba --- /dev/null +++ b/python/llm/src/ipex_llm/transformers/npu_models/quantize.py @@ -0,0 +1,86 @@ +# +# Copyright 2016 The BigDL Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# Some parts of this file is adapted from +# https://github.com/mobiusml/hqq/blob/master/hqq/core/optimize.py +# which is licensed under Apache License 2.0: +# +# Copyright 2021 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +import torch +import numpy as np +from torch import float32, float16, Tensor +from functools import partial +from typing import Union + + +def update_scale_grid_search(x: Tensor, scale: Tensor, min_max: list, N: int = 128 + 1): + print(x.shape) + print(scale.shape) + + assert N % 2 == 1, "Please check whether N: odd number" + rng_dump = 0.05 # 0.05 / 1. + z_val = 2e-4 + + device = scale.device + dtype = scale.dtype + ############################### + print("init scale shape is : ", scale.shape) + W_q = (x / scale).clamp(min_max[0], min_max[1]) + n_clusters = W_q.shape[0] + rng = torch.abs(scale).mean() * rng_dump if (rng_dump < 1.0) else rng_dump + print("rng is : ", rng) + + scale_shifted = ( + torch.linspace(-rng, rng, N)[None, :] + .to(dtype=dtype, device=device) + .repeat(n_clusters, 1) + ) + + scale_shifted += scale + + # Safe inverse + scale_shifted[ + torch.logical_and(scale_shifted >= 0, torch.abs(scale_shifted) <= z_val) + ] = z_val + scale_shifted[ + torch.logical_and(scale_shifted < 0, torch.abs(scale_shifted) <= z_val) + ] = -z_val + + err = torch.empty([n_clusters, N], dtype=dtype, device=device) + for i in range(N): + W_r = W_q * scale_shifted[:, i][:, None] + err[:, i] = torch.abs(x - W_r).mean(axis=1, keepdim=True).squeeze() + print(f"err [{i}] shape is ", err[i].shape) + + ind_r = torch.argmin(err, axis=1).to(torch.int32) + ind_c = torch.arange(len(ind_r), dtype=torch.int32, device=device) + scale_b = scale_shifted[ind_c, ind_r] + + # obtain qwights based on scale_b + + return scale_b, qweights From 2deb44523634dc916c9a80316dd5b14c9e41d70e Mon Sep 17 00:00:00 2001 From: rnwang04 Date: Thu, 12 Dec 2024 10:55:30 +0800 Subject: [PATCH 2/6] update --- .../transformers/npu_models/convert.py | 10 ++-- .../transformers/npu_models/linear.py | 2 +- .../transformers/npu_models/quantize.py | 51 +++++++++++-------- 3 files changed, 38 insertions(+), 25 deletions(-) diff --git a/python/llm/src/ipex_llm/transformers/npu_models/convert.py b/python/llm/src/ipex_llm/transformers/npu_models/convert.py index 065b48bd349..acee0d860fd 100644 --- a/python/llm/src/ipex_llm/transformers/npu_models/convert.py +++ b/python/llm/src/ipex_llm/transformers/npu_models/convert.py @@ -109,11 +109,13 @@ def replace_with_QuantizedLinear(layer, qtype, device, modules_to_not_convert, enable_scale_search=enable_scale_search, imatrix=imatrix) if qtype == "sym_int4_rtn" and os.environ.get("IPEX_LLM_NPU_QUANTIZATION_HQQ", "0") != "0": - from .quantize import scale_grid_search + from .quantize import update_scale_grid_search # scale grid search - qweights, scale = scale_grid_search(layer.weight.data.to(torch.float32), - scale.to(torch.float32), - qweights) + print("=====original: ", qweights.shape, scale.shape) + qweights, scale = update_scale_grid_search(layer.weight.data.to(torch.float32), + (1.0 / scale.to(torch.float32)), + [-8, 7]) + print("=====update: ", qweights.shape, scale.shape) zero = None # split scale to scale & zero if qtype == "asym_int4_rtn": diff --git a/python/llm/src/ipex_llm/transformers/npu_models/linear.py b/python/llm/src/ipex_llm/transformers/npu_models/linear.py index c8a5dd467ae..a626d9ec042 100644 --- a/python/llm/src/ipex_llm/transformers/npu_models/linear.py +++ b/python/llm/src/ipex_llm/transformers/npu_models/linear.py @@ -155,7 +155,7 @@ def __init__( False, ( f"Quantized weight must be in torch.(u)int8" - " dtype instead of {self.weight.dtype}" + f" dtype instead of 
{self.weight.dtype}" ) ) self.outC, self.inC = self.weight.shape diff --git a/python/llm/src/ipex_llm/transformers/npu_models/quantize.py b/python/llm/src/ipex_llm/transformers/npu_models/quantize.py index 303bdcc0eba..6fee413c3b9 100644 --- a/python/llm/src/ipex_llm/transformers/npu_models/quantize.py +++ b/python/llm/src/ipex_llm/transformers/npu_models/quantize.py @@ -38,49 +38,60 @@ from typing import Union -def update_scale_grid_search(x: Tensor, scale: Tensor, min_max: list, N: int = 128 + 1): +def update_scale_grid_search(x: Tensor, iscale: Tensor, min_max: list, N: int = 128 + 1): + iscale = iscale.unsqueeze(1) print(x.shape) - print(scale.shape) + print(iscale.shape) assert N % 2 == 1, "Please check whether N: odd number" rng_dump = 0.05 # 0.05 / 1. z_val = 2e-4 - device = scale.device - dtype = scale.dtype + device = iscale.device + dtype = iscale.dtype ############################### - print("init scale shape is : ", scale.shape) - W_q = (x / scale).clamp(min_max[0], min_max[1]) + print("init scale shape is : ", iscale.shape) + W_q = (x * iscale).clamp(min_max[0], min_max[1]) n_clusters = W_q.shape[0] - rng = torch.abs(scale).mean() * rng_dump if (rng_dump < 1.0) else rng_dump + rng = torch.abs(iscale).mean() * rng_dump if (rng_dump < 1.0) else rng_dump print("rng is : ", rng) - scale_shifted = ( + iscale_shifted = ( torch.linspace(-rng, rng, N)[None, :] .to(dtype=dtype, device=device) .repeat(n_clusters, 1) - ) + ) + iscale - scale_shifted += scale + print(iscale_shifted.shape) # Safe inverse - scale_shifted[ - torch.logical_and(scale_shifted >= 0, torch.abs(scale_shifted) <= z_val) + iscale_shifted[ + torch.logical_and(iscale_shifted >= 0, torch.abs(iscale_shifted) <= z_val) ] = z_val - scale_shifted[ - torch.logical_and(scale_shifted < 0, torch.abs(scale_shifted) <= z_val) + iscale_shifted[ + torch.logical_and(iscale_shifted < 0, torch.abs(iscale_shifted) <= z_val) ] = -z_val err = torch.empty([n_clusters, N], dtype=dtype, device=device) for i in range(N): - W_r = W_q * scale_shifted[:, i][:, None] + W_r = W_q * iscale_shifted[:, i][:, None] err[:, i] = torch.abs(x - W_r).mean(axis=1, keepdim=True).squeeze() - print(f"err [{i}] shape is ", err[i].shape) - + ind_r = torch.argmin(err, axis=1).to(torch.int32) ind_c = torch.arange(len(ind_r), dtype=torch.int32, device=device) - scale_b = scale_shifted[ind_c, ind_r] - + iscale_b = iscale_shifted[ind_c, ind_r] + scale_b = 1.0 / iscale_b + iscale_b = iscale_b.unsqueeze(1) + print(iscale_b.shape) # obtain qwights based on scale_b + qweights = (x * iscale_b).to(torch.int8) # m * n + qweights = qweights.reshape(x.shape[0], -1 , 2) # m * n/2 * 2 + print(qweights.split(1, dim=-1)) + high_bit, low_bit = qweights.split(1, dim=-1) + print(high_bit.shape) + high_bit = high_bit.squeeze().view(torch.int8) + low_bit = low_bit.squeeze().view(torch.int8) + high_bit = high_bit << 4 + qweights = high_bit | low_bit - return scale_b, qweights + return qweights, scale_b.to(torch.float16) From b7d7268091bfc584e4bacff2b489058a64c8d3d6 Mon Sep 17 00:00:00 2001 From: rnwang04 Date: Thu, 12 Dec 2024 15:17:14 +0800 Subject: [PATCH 3/6] support hqq scale search of q4_0 --- .../transformers/npu_models/convert.py | 2 -- .../transformers/npu_models/quantize.py | 27 ++++++++----------- 2 files changed, 11 insertions(+), 18 deletions(-) diff --git a/python/llm/src/ipex_llm/transformers/npu_models/convert.py b/python/llm/src/ipex_llm/transformers/npu_models/convert.py index acee0d860fd..4fcd3f901ad 100644 --- 
a/python/llm/src/ipex_llm/transformers/npu_models/convert.py +++ b/python/llm/src/ipex_llm/transformers/npu_models/convert.py @@ -111,11 +111,9 @@ def replace_with_QuantizedLinear(layer, qtype, device, modules_to_not_convert, if qtype == "sym_int4_rtn" and os.environ.get("IPEX_LLM_NPU_QUANTIZATION_HQQ", "0") != "0": from .quantize import update_scale_grid_search # scale grid search - print("=====original: ", qweights.shape, scale.shape) qweights, scale = update_scale_grid_search(layer.weight.data.to(torch.float32), (1.0 / scale.to(torch.float32)), [-8, 7]) - print("=====update: ", qweights.shape, scale.shape) zero = None # split scale to scale & zero if qtype == "asym_int4_rtn": diff --git a/python/llm/src/ipex_llm/transformers/npu_models/quantize.py b/python/llm/src/ipex_llm/transformers/npu_models/quantize.py index 6fee413c3b9..f9a127e3c9f 100644 --- a/python/llm/src/ipex_llm/transformers/npu_models/quantize.py +++ b/python/llm/src/ipex_llm/transformers/npu_models/quantize.py @@ -32,16 +32,11 @@ # limitations under the License. import torch -import numpy as np -from torch import float32, float16, Tensor -from functools import partial -from typing import Union +from torch import Tensor def update_scale_grid_search(x: Tensor, iscale: Tensor, min_max: list, N: int = 128 + 1): iscale = iscale.unsqueeze(1) - print(x.shape) - print(iscale.shape) assert N % 2 == 1, "Please check whether N: odd number" rng_dump = 0.05 # 0.05 / 1. @@ -50,11 +45,9 @@ def update_scale_grid_search(x: Tensor, iscale: Tensor, min_max: list, N: int = device = iscale.device dtype = iscale.dtype ############################### - print("init scale shape is : ", iscale.shape) - W_q = (x * iscale).clamp(min_max[0], min_max[1]) + W_q = torch.round(x * iscale).clamp(min_max[0], min_max[1]) n_clusters = W_q.shape[0] rng = torch.abs(iscale).mean() * rng_dump if (rng_dump < 1.0) else rng_dump - print("rng is : ", rng) iscale_shifted = ( torch.linspace(-rng, rng, N)[None, :] @@ -74,7 +67,7 @@ def update_scale_grid_search(x: Tensor, iscale: Tensor, min_max: list, N: int = err = torch.empty([n_clusters, N], dtype=dtype, device=device) for i in range(N): - W_r = W_q * iscale_shifted[:, i][:, None] + W_r = W_q * iscale_shifted[:, i][:, None] err[:, i] = torch.abs(x - W_r).mean(axis=1, keepdim=True).squeeze() ind_r = torch.argmin(err, axis=1).to(torch.int32) @@ -82,16 +75,18 @@ def update_scale_grid_search(x: Tensor, iscale: Tensor, min_max: list, N: int = iscale_b = iscale_shifted[ind_c, ind_r] scale_b = 1.0 / iscale_b iscale_b = iscale_b.unsqueeze(1) - print(iscale_b.shape) + # obtain qwights based on scale_b - qweights = (x * iscale_b).to(torch.int8) # m * n + qweights = (torch.round(x * iscale_b)).clamp(min_max[0], min_max[1]).to(torch.int8) # m * n + # test with original + # scale_b = (1.0 / iscale).squeeze() + # qweights = (torch.round(x * iscale)).clamp(min_max[0], min_max[1]).to(torch.int8) # m * n qweights = qweights.reshape(x.shape[0], -1 , 2) # m * n/2 * 2 - print(qweights.split(1, dim=-1)) - high_bit, low_bit = qweights.split(1, dim=-1) - print(high_bit.shape) + low_bit, high_bit = qweights.split(1, dim=-1) high_bit = high_bit.squeeze().view(torch.int8) low_bit = low_bit.squeeze().view(torch.int8) high_bit = high_bit << 4 + low_bit = low_bit & 0x0f qweights = high_bit | low_bit - return qweights, scale_b.to(torch.float16) + return qweights.view(torch.uint8), scale_b.to(torch.float16) From b3025a02504834f9653e41a7acdea9940410c083 Mon Sep 17 00:00:00 2001 From: rnwang04 Date: Thu, 12 Dec 2024 17:18:55 +0800 Subject: [PATCH 
4/6] update, v1 scale search --- .../ipex_llm/transformers/npu_models/quantize.py | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/python/llm/src/ipex_llm/transformers/npu_models/quantize.py b/python/llm/src/ipex_llm/transformers/npu_models/quantize.py index f9a127e3c9f..c47df92d340 100644 --- a/python/llm/src/ipex_llm/transformers/npu_models/quantize.py +++ b/python/llm/src/ipex_llm/transformers/npu_models/quantize.py @@ -35,6 +35,9 @@ from torch import Tensor +def c_round(x: Tensor): + return torch.sign(x) * torch.floor(torch.abs(x) + 0.5) + def update_scale_grid_search(x: Tensor, iscale: Tensor, min_max: list, N: int = 128 + 1): iscale = iscale.unsqueeze(1) @@ -45,7 +48,7 @@ def update_scale_grid_search(x: Tensor, iscale: Tensor, min_max: list, N: int = device = iscale.device dtype = iscale.dtype ############################### - W_q = torch.round(x * iscale).clamp(min_max[0], min_max[1]) + W_q = c_round(x * iscale).clamp(min_max[0], min_max[1]) n_clusters = W_q.shape[0] rng = torch.abs(iscale).mean() * rng_dump if (rng_dump < 1.0) else rng_dump @@ -55,8 +58,6 @@ def update_scale_grid_search(x: Tensor, iscale: Tensor, min_max: list, N: int = .repeat(n_clusters, 1) ) + iscale - print(iscale_shifted.shape) - # Safe inverse iscale_shifted[ torch.logical_and(iscale_shifted >= 0, torch.abs(iscale_shifted) <= z_val) @@ -76,11 +77,12 @@ def update_scale_grid_search(x: Tensor, iscale: Tensor, min_max: list, N: int = scale_b = 1.0 / iscale_b iscale_b = iscale_b.unsqueeze(1) - # obtain qwights based on scale_b - qweights = (torch.round(x * iscale_b)).clamp(min_max[0], min_max[1]).to(torch.int8) # m * n # test with original # scale_b = (1.0 / iscale).squeeze() - # qweights = (torch.round(x * iscale)).clamp(min_max[0], min_max[1]).to(torch.int8) # m * n + # qweights = (c_round(x * iscale)).clamp(-8.0, 7.0).to(torch.int8) # m * n + + # obtain qwights based on scale_b + qweights = (c_round(x * iscale_b)).clamp(min_max[0], min_max[1]).to(torch.int8) # m * n qweights = qweights.reshape(x.shape[0], -1 , 2) # m * n/2 * 2 low_bit, high_bit = qweights.split(1, dim=-1) high_bit = high_bit.squeeze().view(torch.int8) From eff700935b3f739c741053d3a7cfb0f050ba5db2 Mon Sep 17 00:00:00 2001 From: rnwang04 Date: Fri, 13 Dec 2024 10:56:55 +0800 Subject: [PATCH 5/6] further exp of hqq q4_0 --- .../transformers/npu_models/convert.py | 10 +- .../transformers/npu_models/quantize.py | 110 ++++++++++++++++-- 2 files changed, 106 insertions(+), 14 deletions(-) diff --git a/python/llm/src/ipex_llm/transformers/npu_models/convert.py b/python/llm/src/ipex_llm/transformers/npu_models/convert.py index 4fcd3f901ad..7cedf31b4ba 100644 --- a/python/llm/src/ipex_llm/transformers/npu_models/convert.py +++ b/python/llm/src/ipex_llm/transformers/npu_models/convert.py @@ -109,11 +109,11 @@ def replace_with_QuantizedLinear(layer, qtype, device, modules_to_not_convert, enable_scale_search=enable_scale_search, imatrix=imatrix) if qtype == "sym_int4_rtn" and os.environ.get("IPEX_LLM_NPU_QUANTIZATION_HQQ", "0") != "0": - from .quantize import update_scale_grid_search - # scale grid search - qweights, scale = update_scale_grid_search(layer.weight.data.to(torch.float32), - (1.0 / scale.to(torch.float32)), - [-8, 7]) + from .quantize import update_scale_inverse_median + # scale search by hqq + qweights, scale = update_scale_inverse_median(layer.weight.data.to(torch.float32), + (1.0 / scale.to(torch.float32)), + [-8, 7]) zero = None # split scale to scale & zero if qtype == "asym_int4_rtn": diff --git 
a/python/llm/src/ipex_llm/transformers/npu_models/quantize.py b/python/llm/src/ipex_llm/transformers/npu_models/quantize.py index c47df92d340..68430dbef38 100644 --- a/python/llm/src/ipex_llm/transformers/npu_models/quantize.py +++ b/python/llm/src/ipex_llm/transformers/npu_models/quantize.py @@ -33,17 +33,18 @@ import torch from torch import Tensor +import numpy as np def c_round(x: Tensor): return torch.sign(x) * torch.floor(torch.abs(x) + 0.5) + def update_scale_grid_search(x: Tensor, iscale: Tensor, min_max: list, N: int = 128 + 1): iscale = iscale.unsqueeze(1) assert N % 2 == 1, "Please check whether N: odd number" rng_dump = 0.05 # 0.05 / 1. - z_val = 2e-4 device = iscale.device dtype = iscale.dtype @@ -58,14 +59,6 @@ def update_scale_grid_search(x: Tensor, iscale: Tensor, min_max: list, N: int = .repeat(n_clusters, 1) ) + iscale - # Safe inverse - iscale_shifted[ - torch.logical_and(iscale_shifted >= 0, torch.abs(iscale_shifted) <= z_val) - ] = z_val - iscale_shifted[ - torch.logical_and(iscale_shifted < 0, torch.abs(iscale_shifted) <= z_val) - ] = -z_val - err = torch.empty([n_clusters, N], dtype=dtype, device=device) for i in range(N): W_r = W_q * iscale_shifted[:, i][:, None] @@ -92,3 +85,102 @@ def update_scale_grid_search(x: Tensor, iscale: Tensor, min_max: list, N: int = qweights = high_bit | low_bit return qweights.view(torch.uint8), scale_b.to(torch.float16) + + +# Shrinking operator +def shrink_lp_op(x: Tensor, beta: float, lp_norm: float) -> Tensor: + if lp_norm == 1: + return torch.sign(x) * torch.nn.functional.relu(torch.abs(x) - 1.0 / beta) + else: + return torch.sign(x) * torch.nn.functional.relu( + torch.abs(x) - (1.0 / beta) * torch.pow(torch.abs(x), lp_norm - 1) + ) + + +def update_scale_hqq(x: Tensor, iscale: Tensor, min_max: list): + iscale = iscale.unsqueeze(1) + opt_params: dict = {"lp_norm": 0.7, "beta": 1e1, "kappa": 1.01, "iters": 20} + lp_norm, beta, kappa, iters = ( + opt_params["lp_norm"], + opt_params["beta"], + opt_params["kappa"], + opt_params["iters"], + ) + z_val = 1e-4 + delta = 1e-4 + + best_error = 1e4 + for i in range(iters): + W_q = c_round(x * iscale).clamp(min_max[0], min_max[1]) + W_q_mask = W_q == 0 + W_q[W_q_mask] = delta + W_r = W_q / iscale + W_e = shrink_lp_op(x - W_r, beta, lp_norm) + W_ = (x - W_e).clone() + W_mask = torch.abs(W_) < z_val + W_[W_mask] = z_val + iscale, _ = torch.median(W_q / W_q, axis=1, keepdim=True) + beta *= kappa + + current_error = float(torch.abs(x - W_r).mean()) + if current_error < best_error: + best_error = current_error + else: + break + + scale_b = 1.0 / iscale + qweights = (c_round(x * iscale)).clamp(min_max[0], min_max[1]).to(torch.int8) # m * n + qweights = qweights.reshape(x.shape[0], -1 , 2) # m * n/2 * 2 + low_bit, high_bit = qweights.split(1, dim=-1) + high_bit = high_bit.squeeze().view(torch.int8) + low_bit = low_bit.squeeze().view(torch.int8) + high_bit = high_bit << 4 + low_bit = low_bit & 0x0f + qweights = high_bit | low_bit + + return qweights.view(torch.uint8), scale_b.to(torch.float16) + + + +# re-estimate the scale based on the inverse median: Only tested with axis==0 +def update_scale_inverse_median( + W_f: Tensor, iscale: Tensor, min_max: list +) -> tuple: + iscale = iscale.unsqueeze(1) + scale_rng = 2e4 + z_val = 1e-4 + + W_q = c_round(W_f * iscale).clamp(min_max[0], min_max[1]) + + # Build scale tensor + W_f_c = W_f.clone() + W_f_c_mask = torch.abs(W_f_c) < z_val + W_f_c[W_f_c_mask] = z_val + + scale_tensor = (W_q).float() / W_f_c.float() + + # Normalize scale_tensor + scale_b = 
torch.median(scale_tensor, axis=1, keepdim=True)[0] + scale_b = scale_b.clamp(min=-scale_rng, max=scale_rng) + + # Mix with older scale + W_r = (W_q) / scale_b + err_b = torch.abs(W_f - W_r).mean(axis=1, keepdim=True) + + W_r = (W_q) / iscale + err_a = torch.abs(W_f - W_r).mean(axis=1, keepdim=True) + + mask = (err_b < err_a) + iscale_b = mask * scale_b + (~mask) * iscale + + scale_b = 1.0 / iscale_b + qweights = (c_round(W_f * iscale_b)).clamp(min_max[0], min_max[1]).to(torch.int8) # m * n + qweights = qweights.reshape(W_f.shape[0], -1 , 2) # m * n/2 * 2 + low_bit, high_bit = qweights.split(1, dim=-1) + high_bit = high_bit.squeeze().view(torch.int8) + low_bit = low_bit.squeeze().view(torch.int8) + high_bit = high_bit << 4 + low_bit = low_bit & 0x0f + qweights = high_bit | low_bit + + return qweights.view(torch.uint8), scale_b.to(torch.float16) From 77293243ecf97cd3d466d8e967681770356fe9f5 Mon Sep 17 00:00:00 2001 From: rnwang04 Date: Mon, 16 Dec 2024 18:15:05 +0800 Subject: [PATCH 6/6] update hqq q4_0 --- .../transformers/npu_models/convert.py | 10 ++-- .../transformers/npu_models/quantize.py | 47 ++++++++++++++++++- 2 files changed, 52 insertions(+), 5 deletions(-) diff --git a/python/llm/src/ipex_llm/transformers/npu_models/convert.py b/python/llm/src/ipex_llm/transformers/npu_models/convert.py index 7cedf31b4ba..7ebc0386ff5 100644 --- a/python/llm/src/ipex_llm/transformers/npu_models/convert.py +++ b/python/llm/src/ipex_llm/transformers/npu_models/convert.py @@ -109,11 +109,13 @@ def replace_with_QuantizedLinear(layer, qtype, device, modules_to_not_convert, enable_scale_search=enable_scale_search, imatrix=imatrix) if qtype == "sym_int4_rtn" and os.environ.get("IPEX_LLM_NPU_QUANTIZATION_HQQ", "0") != "0": - from .quantize import update_scale_inverse_median + from .quantize import update_scale_hqq_v2 # scale search by hqq - qweights, scale = update_scale_inverse_median(layer.weight.data.to(torch.float32), - (1.0 / scale.to(torch.float32)), - [-8, 7]) + print("====original scale is :", scale) + qweights, scale = update_scale_hqq_v2(layer.weight.data.to(torch.float32), + scale.to(torch.float32), + [-8, 7]) + print("====updated scale is :", scale) zero = None # split scale to scale & zero if qtype == "asym_int4_rtn": diff --git a/python/llm/src/ipex_llm/transformers/npu_models/quantize.py b/python/llm/src/ipex_llm/transformers/npu_models/quantize.py index 68430dbef38..910df55f9d8 100644 --- a/python/llm/src/ipex_llm/transformers/npu_models/quantize.py +++ b/python/llm/src/ipex_llm/transformers/npu_models/quantize.py @@ -119,10 +119,12 @@ def update_scale_hqq(x: Tensor, iscale: Tensor, min_max: list): W_ = (x - W_e).clone() W_mask = torch.abs(W_) < z_val W_[W_mask] = z_val - iscale, _ = torch.median(W_q / W_q, axis=1, keepdim=True) + iscale, _ = torch.median(W_q / W_, axis=1, keepdim=True) beta *= kappa current_error = float(torch.abs(x - W_r).mean()) + print(i, current_error) + print(iscale, torch.isinf(iscale).any(), torch.isnan(iscale).any()) if current_error < best_error: best_error = current_error else: @@ -141,6 +143,49 @@ def update_scale_hqq(x: Tensor, iscale: Tensor, min_max: list): return qweights.view(torch.uint8), scale_b.to(torch.float16) +def update_scale_hqq_v2(x: Tensor, scale: Tensor, min_max: list): + scale = scale.unsqueeze(1) + opt_params: dict = {"lp_norm": 0.7, "beta": 1e1, "kappa": 1.01, "iters": 20} + lp_norm, beta, kappa, iters = ( + opt_params["lp_norm"], + opt_params["beta"], + opt_params["kappa"], + opt_params["iters"], + ) + + best_error = 1e4 + for i in 
range(iters): + W_q = c_round(x / scale).clamp(min_max[0], min_max[1]) + W_q_mask = W_q != 0 # m, n + sum_row = torch.sum(W_q_mask.int(), axis=1, keepdim=True) # m, 1 + W_r = W_q * scale + W_e = shrink_lp_op(x - W_r, beta, lp_norm) + W_ = (x - W_e).clone() + tmp = W_ / W_q + tmp[W_q == 0] = 0 + tmp = torch.sum(tmp, axis=1, keepdim=True) # m, 1 + scale = tmp / sum_row # m, 1 + beta *= kappa + + current_error = float(torch.abs(x - W_r).mean()) + print(i, current_error) + if current_error < best_error: + best_error = current_error + else: + break + + scale_b = scale + qweights = (c_round(x / scale)).clamp(min_max[0], min_max[1]).to(torch.int8) # m * n + qweights = qweights.reshape(x.shape[0], -1 , 2) # m * n/2 * 2 + low_bit, high_bit = qweights.split(1, dim=-1) + high_bit = high_bit.squeeze().view(torch.int8) + low_bit = low_bit.squeeze().view(torch.int8) + high_bit = high_bit << 4 + low_bit = low_bit & 0x0f + qweights = high_bit | low_bit + + return qweights.view(torch.uint8), scale_b.to(torch.float16) + # re-estimate the scale based on the inverse median: Only tested with axis==0 def update_scale_inverse_median(
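
For readers following the series, the pieces it converges on in PATCH 6/6 (the c_round rounding helper, the HQQ shrink_lp_op shrinking operator, and the update_scale_hqq_v2 per-row scale refinement with int4 packing) are easier to read as one self-contained sketch than spread across six diffs. The version below mirrors the patched quantize.py but is only an illustration: the random weight, the absmax/7 starting scale, and the shapes in the __main__ harness are assumptions made here for demonstration, since in convert.py the initial scale actually comes from ggml_convert_qtype, and the patch's remaining debug prints are omitted.

import torch
from torch import Tensor


def c_round(x: Tensor) -> Tensor:
    # Round half away from zero (same behaviour as the patch's c_round helper).
    return torch.sign(x) * torch.floor(torch.abs(x) + 0.5)


def shrink_lp_op(x: Tensor, beta: float, lp_norm: float) -> Tensor:
    # HQQ shrinking operator: soft-threshold the reconstruction error under an l_p prior.
    if lp_norm == 1:
        return torch.sign(x) * torch.nn.functional.relu(torch.abs(x) - 1.0 / beta)
    return torch.sign(x) * torch.nn.functional.relu(
        torch.abs(x) - (1.0 / beta) * torch.pow(torch.abs(x), lp_norm - 1)
    )


def update_scale_hqq_v2(x: Tensor, scale: Tensor, min_max: list):
    # Refine a per-output-channel scale for symmetric int4, then pack two 4-bit
    # values per byte, following the last commit of the series.
    scale = scale.unsqueeze(1)                                   # [out, 1]
    lp_norm, beta, kappa, iters = 0.7, 1e1, 1.01, 20
    best_error = 1e4
    for _ in range(iters):
        W_q = c_round(x / scale).clamp(min_max[0], min_max[1])   # quantize with current scale
        sum_row = torch.sum((W_q != 0).int(), axis=1, keepdim=True)
        W_r = W_q * scale                                        # dequantized reconstruction
        W_e = shrink_lp_op(x - W_r, beta, lp_norm)               # sparse error estimate
        tmp = (x - W_e) / W_q
        tmp[W_q == 0] = 0
        scale = torch.sum(tmp, axis=1, keepdim=True) / sum_row   # per-row scale update
        beta *= kappa
        current_error = float(torch.abs(x - W_r).mean())
        if current_error < best_error:
            best_error = current_error
        else:
            break
    qweights = c_round(x / scale).clamp(min_max[0], min_max[1]).to(torch.int8)
    qweights = qweights.reshape(x.shape[0], -1, 2)               # [out, in/2, 2]
    low_bit, high_bit = qweights.split(1, dim=-1)
    packed = (high_bit.squeeze(-1) << 4) | (low_bit.squeeze(-1) & 0x0F)
    return packed.view(torch.uint8), scale.to(torch.float16)


if __name__ == "__main__":
    torch.manual_seed(0)
    w = torch.randn(1024, 512, dtype=torch.float32)       # stand-in for a Linear weight
    init_scale = w.abs().max(dim=1).values / 7.0           # assumed RTN-style starting scale
    packed, new_scale = update_scale_hqq_v2(w, init_scale, [-8, 7])
    print(packed.shape, packed.dtype, new_scale.shape, new_scale.dtype)

As wired up in convert.py by the final patch, this refinement only runs for sym_int4_rtn weights when the IPEX_LLM_NPU_QUANTIZATION_HQQ environment variable is set to a non-zero value; otherwise the plain RTN scale returned by ggml_convert_qtype is kept.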