diff --git a/docs/en/llm/README.md b/docs/en/llm/README.md
index 4ea245f644e3..5f1d5ca5cc21 100644
--- a/docs/en/llm/README.md
+++ b/docs/en/llm/README.md
@@ -314,6 +314,7 @@ Large model quantization reduces 16-bit and 32-bit floating-point model paramete
 - **PTQ**. The self-developed adaptive LLM.PTQ quantization algorithm by the PaddleSlim team builds upon [SmoothQuant](https://arxiv.org/abs/2211.10438) and [Outlier Suppression+](https://arxiv.org/abs/2304.09145), adding the PieceWiseSearch parameter search algorithm to adjust model weight and activation distributions, reducing subsequent A8W8 PTQ quantization loss.
 - **GPTQ**. [GPTQ](https://arxiv.org/abs/2210.17323) is a mainstream weight quantization algorithm that enables lossless 4-bit integer quantization of large model weights to improve inference speed.
+- **FOEM**. [FOEM](https://arxiv.org/abs/2507.11017) is a PTQ method that explicitly incorporates first-order gradient terms to improve quantization error compensation.
 llm
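The FOEM entry above is the only prose description of the method in this patch; the mechanics surface further down in `llm/utils/quant.py`, where `fasterquant` gains `do_foem` and `foem_beta` arguments. As a rough orientation only (not the paper's formulation), first-order error compensation can be pictured as a gradient-weighted correction added to the usual GPTQ-style, Hessian-based residual that is propagated into the not-yet-quantized columns. The `grad` input and the way `foem_beta` enters below are illustrative assumptions:

```python
import numpy as np


def quantize_columns(W, H_inv, grad=None, foem_beta=0.0, n_bits=4):
    """Toy GPTQ-style column-by-column quantization with an optional first-order term."""
    W = W.copy()
    scale = np.abs(W).max() / (2 ** (n_bits - 1) - 1)  # one symmetric scale, for brevity
    Q = np.zeros_like(W)
    for j in range(W.shape[1]):
        Q[:, j] = np.round(W[:, j] / scale) * scale      # quantize column j
        err = (W[:, j] - Q[:, j]) / H_inv[j, j]          # second-order (GPTQ) residual
        if grad is not None and foem_beta > 0.0:
            err = err + foem_beta * grad[:, j]           # hypothetical first-order correction
        W[:, j + 1:] -= np.outer(err, H_inv[j, j + 1:])  # compensate the remaining columns
    return Q


rng = np.random.default_rng(0)
W = rng.standard_normal((8, 6)).astype(np.float32)
Q = quantize_columns(W, np.eye(6, dtype=np.float32), grad=rng.standard_normal((8, 6)), foem_beta=0.1)
print(Q.shape)
```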
@@ -339,6 +340,9 @@
 python run_quantization.py ./config/llama/ptq_argument.json
 
 # GPTQ Quantization Command Reference
 python run_quantization.py ./config/llama/gptq_argument.json
 
+# FOEM Quantization Command Reference
+python run_quantization.py ./config/llama/foem_argument.json
+
 # W8A8C8(INT) Quantization Command Reference
 python run_quantization.py ./config/llama/ptq_c8_argument.json
 
diff --git a/llm/README.md b/llm/README.md
index 7bf792830a71..9a72a8430079 100644
--- a/llm/README.md
+++ b/llm/README.md
@@ -295,6 +295,7 @@ python -u -m paddle.distributed.launch --devices "0,1,2,3,4,5,6,7" ./llm/tools/
 - **PTQ**. The PaddleSlim team's self-developed adaptive LLM.PTQ quantization algorithm builds on [SmoothQuant](https://arxiv.org/abs/2211.10438) and [Outlier Suppression+](https://arxiv.org/abs/2304.09145), adding the PieceWiseSearch parameter search algorithm to adjust model weight and activation distributions and reduce subsequent A8W8 PTQ quantization loss.
 - **GPTQ**. [GPTQ](https://arxiv.org/abs/2210.17323) is a mainstream weight quantization algorithm that performs lossless 4-bit integer quantization of large model weights to speed up inference.
+- **FOEM**. [FOEM](https://arxiv.org/abs/2507.11017) is a quantization algorithm that takes first-order error compensation into account, further reducing quantization error compared with GPTQ.
 llm
@@ -320,6 +321,9 @@
 python run_quantization.py ./config/llama/ptq_argument.json
 
 # GPTQ Quantization Command Reference
 python run_quantization.py ./config/llama/gptq_argument.json
 
+# FOEM Quantization Command Reference
+python run_quantization.py ./config/llama/foem_argument.json
+
 # W8A8C8(INT) Quantization Command Reference
 python run_quantization.py ./config/llama/ptq_c8_argument.json
 
diff --git a/llm/config/llama/foem_argument.json b/llm/config/llama/foem_argument.json
new file mode 100644
index 000000000000..16b83e43266c
--- /dev/null
+++ b/llm/config/llama/foem_argument.json
@@ -0,0 +1,18 @@
+{
+    "model_name_or_path": "meta-llama/Meta-Llama-3-8B",
+    "per_device_train_batch_size": 8,
+    "per_device_eval_batch_size": 8,
+    "eval_accumulation_steps": 16,
+    "src_length": 1024,
+    "max_length": 2048,
+    "bf16": false,
+    "fp16_opt_level": "O0",
+    "dataset_name_or_path": "./data",
+    "output_dir": "./checkpoints/gptq_ckpts",
+    "do_eval": true,
+    "eval_with_do_generation": false,
+    "do_gptq": true,
+    "do_foem": true,
+    "unified_checkpoint": true,
+    "gptq_step": 8
+ }
\ No newline at end of file
diff --git a/llm/config/llama/irqlora_argument.json b/llm/config/llama/irqlora_argument.json
new file mode 100644
index 000000000000..63526eb450af
--- /dev/null
+++ b/llm/config/llama/irqlora_argument.json
@@ -0,0 +1,35 @@
+{
+    "model_name_or_path": "facebook/llama-7b",
+    "dataset_name_or_path": "./data",
+    "output_dir": "./checkpoints/ir_qlora_ckpts",
+    "per_device_train_batch_size": 4,
+    "gradient_accumulation_steps": 4,
+    "per_device_eval_batch_size": 8,
+    "eval_accumulation_steps": 16,
+    "num_train_epochs": 3,
+    "learning_rate": 3e-04,
+    "warmup_steps": 30,
+    "logging_steps": 1,
+    "evaluation_strategy": "epoch",
+    "save_strategy": "epoch",
+    "src_length": 1024,
+    "max_length": 2048,
+    "bf16": true,
+    "fp16_opt_level": "O2",
+    "do_train": true,
+    "do_eval": true,
+    "disable_tqdm": true,
+    "load_best_model_at_end": true,
+    "eval_with_do_generation": false,
+    "metric_for_best_model": "accuracy",
+    "recompute": true,
+    "save_total_limit": 1,
+    "tensor_parallel_degree": 1,
+    "pipeline_parallel_degree": 1,
+    "lora": true,
+    "zero_padding": false,
+    "use_flash_attention": false,
+    "unified_checkpoint": true,
+    "weight_quantize_algo": "nf4",
+    "ir_qlora": true
+}
\ No newline at end of file
diff --git a/llm/docs/finetune.md b/llm/docs/finetune.md
index 5cc07822f1b7..f64fcf4bba1f 100644
--- a/llm/docs/finetune.md
+++ b/llm/docs/finetune.md
@@ -29,6 +29,8 @@
 - QLoRA: Quantized Low-Rank Adaptation. Compared with standard LoRA it cuts memory usage by up to an additional 33%, which makes it especially useful when GPU memory is constrained. QLoRA typically takes about 20% longer than plain LoRA, but the substantial memory savings make it the only viable option when GPU memory is limited.
 
+- IR-QLoRA: Information-Retention QLoRA (Accurate LoRA-Finetuning Quantization of LLMs via Information Retention, ICML 2024 Oral). Compared with standard QLoRA, it further improves accuracy by retaining more information from the original weights during quantization.
+
 ## 3. Quick Start
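The IR-QLoRA entry above stays at a high level; the mechanism lands later in this patch in `paddlenlp/quantization/irqlora_utils.py`, where each weight block is shifted by a calibration constant `tau` chosen so that the resulting 4-bit codes have maximal entropy, i.e. retain as much information as possible. A stripped-down, single-block illustration of that entropy criterion, using a plain uniform 4-bit grid instead of the NF4 codebook the real code uses, could look like this:

```python
import numpy as np


def code_entropy(codes, n_levels=16):
    """Shannon entropy (bits) of the histogram of quantization codes."""
    freqs = np.bincount(codes, minlength=n_levels) / codes.size
    nonzero = freqs[freqs > 0]
    return float(-(nonzero * np.log2(nonzero)).sum())


def search_tau(block, taus, n_levels=16):
    """Pick the shift tau whose shifted-then-quantized block has maximal code entropy."""
    best_tau, best_entropy = None, -1.0
    for tau in taus:
        shifted = block - tau
        scale = np.abs(shifted).max() / (n_levels // 2 - 1) + 1e-12
        codes = np.clip(np.round(shifted / scale) + n_levels // 2, 0, n_levels - 1).astype(int)
        entropy = code_entropy(codes, n_levels)
        if entropy > best_entropy:
            best_tau, best_entropy = tau, entropy
    return best_tau, best_entropy


rng = np.random.default_rng(0)
block = rng.standard_normal(64) + 0.3      # a skewed weight block
candidates = np.linspace(-0.5, 0.5, 25)    # grid of candidate shifts, like the tau_n search below
print(search_tau(block, candidates))
```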
@@ -81,6 +83,9 @@
 python run_finetune.py ./config/llama/lora_argument.json
 
 # Single-GPU QLoRA
 python run_finetune.py ./config/llama/qlora_argument.json
 
+# Single-GPU IR-QLoRA
+python run_finetune.py ./config/llama/irqlora_argument.json
+
 # Multi-GPU LoRA
 python -u -m paddle.distributed.launch --gpus "0,1,2,3,4,5,6,7" run_finetune.py ./config/llama/lora_argument.json
 
diff --git a/llm/docs/quantization.md b/llm/docs/quantization.md
index 322e4d9b582e..b0028e338f85 100644
--- a/llm/docs/quantization.md
+++ b/llm/docs/quantization.md
@@ -8,6 +8,7 @@
 adds the PieceWiseSearch parameter search algorithm and extends it to **all linear layers**, adjusting model weight and activation distributions to reduce subsequent A8W8 PTQ quantization loss.
 - **GPTQ**. [GPTQ](https://arxiv.org/abs/2210.17323) is a mainstream weight quantization algorithm that performs lossless 4-bit integer quantization of large model weights to speed up inference.
 - **AWQ**. [AWQ](https://arxiv.org/abs/2306.00978) is a mainstream weight quantization algorithm that performs lossless 4-bit integer quantization of large model weights to speed up inference.
+- **FOEM**. [FOEM](https://arxiv.org/abs/2507.11017) is a quantization algorithm that takes first-order error compensation into account, further reducing quantization error compared with GPTQ.
 llm
@@ -94,7 +95,13 @@
 python run_quantization.py ./config/llama/ptq_c8_argument.json
 python run_quantization.py ./config/llama/fp8_ptq_argument.json
 ```
 
-### 2.8 Quantization Parameters
+### 2.8 FOEM Quantization
+
+```shell
+python run_quantization.py ./config/llama/foem_argument.json
+```
+
+### 2.9 Quantization Parameters
 
 Quantization parameters (QuantArgument)
@@ -130,6 +137,7 @@
 - `load_quant_model`: Whether to load a quantized model; defaults to False. Used to verify the quantized model: if set to True, weights are loaded from output_dir. To start this process, `do_ptq` must be set to False. If smooth or shift was used during quantization, the same configuration must be kept when loading (shift_step/search_step can be set to 8). Note that loading currently only supports the pdparams format; to use this feature, set `"unified_checkpoint": false`.
 - `skip_list_names`: List of layer names to skip during quantization; defaults to an empty list. Partial layer-name strings can be used for matching, e.g. ['down_proj'] skips all ffn2 layers.
 - `do_gptq`: Whether to run GPTQ quantization. GPTQ performs WINT4 quantization of the model, with higher accuracy than plain PTQ but a longer quantization time. Defaults to False.
+- `do_foem`: Whether to run FOEM quantization. FOEM performs WINT4 quantization of the model, with higher accuracy than plain GPTQ but a longer quantization time. Defaults to False.
 - `gptq_step`: Number of GPTQ quantization steps, i.e. the number of model forward passes; defaults to 8.
 - `do_awq`: Whether to run AWQ quantization. AWQ performs WINT4 quantization of the model with higher accuracy than plain PTQ. Defaults to False.
 - `auto_clip`: Whether AWQ automatically searches for clipping thresholds and clips the model weights; clipping benefits quantized-model accuracy but the search is slow. Defaults to False.
diff --git a/llm/run_finetune.py b/llm/run_finetune.py
index 31427a516f2d..3e340f3c7012 100644
--- a/llm/run_finetune.py
+++ b/llm/run_finetune.py
@@ -591,6 +591,12 @@ def create_peft_model(model_args, reft_args, training_args, dtype, model_config,
         else:
             model = LoRAModel.from_pretrained(model=model, lora_path=model_args.lora_path)
 
+    if model_args.ir_qlora:
+        from paddlenlp.quantization.irqlora_utils import get_my_model
+        model2 = AutoModelForCausalLM.from_pretrained(model_args.model_name_or_path)
+        model = get_my_model(model, model2)
+        del model2
+
     model.print_trainable_parameters()
 
     if model_args.lokr:
diff --git a/llm/utils/quant.py b/llm/utils/quant.py
index c2459994ffb3..da26370a5d3e 100644
--- a/llm/utils/quant.py
+++ b/llm/utils/quant.py
@@ -481,7 +481,7 @@ def apply_gptq(quant_args, trainer, ptq_dataloader):
             description="GPTQ",
             max_eval_iters=quant_args.gptq_step,
         )
-        cur_quant_layer.fasterquant(percdamp=0.1, groupsize=-1, actorder=True)
+        cur_quant_layer.fasterquant(percdamp=0.1, groupsize=-1, actorder=True, do_foem=quant_args.do_foem, foem_beta=quant_args.foem_beta)
         del cur_quant_layer
         setattr(parent_layer, sub_name, cur_layer)
     logger.info("***** GPTQ done *****")
diff --git a/paddlenlp/quantization/irqlora_utils.py b/paddlenlp/quantization/irqlora_utils.py
new file mode 100644
index 000000000000..e507987692c7
--- /dev/null
+++ b/paddlenlp/quantization/irqlora_utils.py
@@ -0,0 +1,166 @@
+"""IR-QLoRA utilities: information-retention calibration for NF4-quantized LoRA layers."""
+
+import paddle
+import paddle.nn as nn
+import paddle.nn.functional as F
+import operator
+import numpy as np
+from paddlenlp.peft.lora.lora_quantization_layers import QuantizationLoRALinear, QuantizationLoRABaseLinear
+from functools import reduce  # Required in Python 3
+from scipy.stats import norm
+from paddleslim.lc.quantizers.quant_func import create_dynamic_map
+from paddleslim.lc.layers.linear import Linear4bit
+from paddlenlp.quantization.qlora import qlora_weight_dequantize, qlora_weight_quantize
+
+cache_folder_path = ""
+module_num = 0
+sigma = 1 / norm.ppf(paddle.linspace(0.9677083, 0.5, 9)[:-1]).tolist()[0]
+
+
+def get_my_model(model, model_fp, blocksize2=256, tau_range=0.1, tau_n=25):
+    """Swap every QuantizationLoRALinear in `model` for an IR-QLoRA variant calibrated against `model_fp`."""
+    model.model = _replace_with_ours_lora_4bit_linear(
+        model.model, model_fp=model_fp, blocksize2=blocksize2, tau_range=tau_range, tau_n=tau_n
+    )
+    return model
+
+
+def prod(iterable):
+    return reduce(operator.mul, iterable, 1)
+
+
+normal_map_fp8 = create_dynamic_map()
+normal_map_fp8 = paddle.to_tensor(normal_map_fp8)
+
+
+def quantize_tensor(X, L, idx=False):
+    # Map every element of X to the index of its nearest codebook entry in L.
+    X_shape = X.shape
+    X_expanded = X.reshape([-1, 1])
+    L_reshaped = L.reshape([1, -1])
+    abs_diff = paddle.abs(X_expanded - L_reshaped)
+    min_index = paddle.argmin(abs_diff, axis=-1)
+    min_index = min_index.cast("uint8").reshape(X_shape)
+    return min_index
+
+
+def dequantize_tensor(X, L):
+    # Look the codes up in the codebook L and restore the original shape.
+    return paddle.index_select(L, axis=0, index=paddle.to_tensor(X, dtype=paddle.int32).reshape([-1])).reshape(X.shape)
+
+
+@paddle.no_grad()
+def nf4_quant(weight, weight_shape, tau, state, quant_algo):
+    # Subtract the per-block shift tau, NF4-quantize the shifted weight, and quantize tau itself
+    # (FP8 dynamic-map codebook for the normalized value, per-group absmax as the scale).
+    weight = weight.reshape([-1, 256, 64])
+    tau = tau.reshape([-1, 256, 1])
+    _weight = (weight - tau).reshape(weight_shape)
+    nf4_weight = qlora_weight_quantize(_weight.cuda(), quant_algo, double_quant=True)
+    tau2 = tau.abs().max(axis=1, keepdim=True)
+    tau1 = quantize_tensor(tau / tau2, normal_map_fp8)
+    return nf4_weight, tau1.reshape([-1, 256]), tau2.reshape([-1, 1])
+
+
+@paddle.no_grad()
+def evaluate_entropy(weight_int8, blocksize):
+    # Entropy of the 4-bit code histogram within each block (two NF4 codes are packed per uint8).
+    _weight_int8 = weight_int8.reshape([-1, 1])
+    weight_nf4 = paddle.concat((_weight_int8 // 16, _weight_int8 & paddle.to_tensor(15).cast("uint8")), 1).reshape(
+        [1, -1, blocksize]
+    )
+    weight_nf4_repeat = weight_nf4.cast("int32").tile([16, 1, 1])
+    values = paddle.to_tensor(list(range(16))).reshape([16, 1, 1])
+    freqs = (weight_nf4_repeat == values.cast("int32")).sum(axis=-1, keepdim=True) / blocksize
+    entropy = -freqs * paddle.log2(freqs)
+    entropy = paddle.where(paddle.isnan(entropy), 0.0, entropy)
+    entropy = entropy.sum(axis=0)
+    return entropy
+
+
+@paddle.no_grad()
+def search(
+    fp_weight: paddle.Tensor, fp_weight_shape, state, quant_algo, tau_range=0.1, tau_n=25, blocksize=64, blocksize2=256
+):
+    # Grid-search the per-block shift tau that maximizes the entropy of the quantized codes
+    # (the information-retention calibration), then quantize with the best tau.
+    fp_weight = fp_weight.reshape([-1, blocksize2, blocksize])
+    tau0 = fp_weight.median(axis=2, keepdim=True)  # [-1, 256, 1]
+    absmax = (fp_weight - tau0).abs().max(axis=2, keepdim=True)
+
+    entropy_max, factor_best = None, None
+    for factor in np.linspace(-tau_range * sigma, tau_range * sigma, tau_n * 2 + 1):
+        tau = factor * absmax + tau0
+        nf4_weight, _, _ = nf4_quant(fp_weight, fp_weight_shape, tau, state, quant_algo)
+        entropy = evaluate_entropy(nf4_weight["quant_weight"], blocksize)
+
+        if entropy_max is None:
+            entropy_max = entropy
+            factor_best = paddle.full_like(entropy, factor)
+        else:
+            factor_best = paddle.where(entropy > entropy_max, factor, factor_best)
+            entropy_max = paddle.maximum(entropy_max, entropy)
+
+    tau = factor_best.reshape([-1, 256, 1]).cast("float32") * absmax + tau0
+    nf4_weight, tau1, tau2 = nf4_quant(fp_weight, fp_weight_shape, tau, state, quant_algo)
+    return nf4_weight, tau1, tau2
+
+
+class IRQuantizationLoRALinear(QuantizationLoRALinear):
+    def __init__(self, old_model, model_fp=None, blocksize2=256, tau_range=0.1, tau_n=51):
+        # Transplant every attribute (including registered parameters and buffers) from the
+        # original QuantizationLoRALinear instead of re-running its __init__.
+        for key, value in old_model.__dict__.items():
+            setattr(self, key, value)
+
+        fp_weight = model_fp.weight.data.contiguous().cpu()
+        fp_weight_shape = fp_weight.shape
+
+        quant_weight, quant_dtype, quantization_config, weight_quantize_algo, dtype, quant_scale, quant_state = (
+            self.quant_weight,
+            self.quant_dtype,
+            self.quantization_config,
+            self.weight_quantize_algo,
+            self._dtype,
+            self.quant_scale,
+            None,
+        )
+
+        del model_fp
+
+        nf4_weight, tau1, tau2 = search(
+            fp_weight=fp_weight,
+            fp_weight_shape=fp_weight_shape,
+            state=quant_state,
+            quant_algo=weight_quantize_algo,
+            tau_range=tau_range,
+            tau_n=tau_n,
+            blocksize2=blocksize2,
+        )
+
+        self.quant_weight.data = nf4_weight["quant_weight"]
+        self.qquant_scale = nf4_weight["qquant_scale"]
+        self.double_quant_scale = nf4_weight["double_quant_scale"]
+        self.quant_sacle_offset = nf4_weight["quant_sacle_offset"]
+        self.tau_quant, self.tau_absmax = tau1, tau2
+        self.fp_weight_shape = fp_weight_shape
+
+        del fp_weight, nf4_weight
+
+        a = paddle.zeros(1)
+        self.lora_default_A_scale = paddle.create_parameter(
+            shape=a.shape, dtype=a.dtype, default_initializer=nn.initializer.Assign(a)
+        )
+        self.lora_default_B_scale = paddle.create_parameter(
+            shape=a.shape, dtype=a.dtype, default_initializer=nn.initializer.Assign(a)
+        )
+
+    def forward(self, x: paddle.Tensor):
+        with paddle.no_grad():
+            # Dequantize the NF4 weight and add the per-block shift tau back before the dense matmul.
+            fp_B = qlora_weight_dequantize(
+                self.quant_weight,
+                self.weight_quantize_algo,
+                (self.qquant_scale, self.double_quant_scale.cast("float32"), self.quant_sacle_offset.cast("float32")),
+                double_quant=True,
+            )
+            tau = dequantize_tensor(self.tau_quant, normal_map_fp8).reshape([-1, 256, 1]) * self.tau_absmax.reshape([-1, 1, 1])
+            blocksize = paddle.prod(paddle.to_tensor(fp_B.shape)) // paddle.prod(paddle.to_tensor(tau.shape))
+            fp_B = (fp_B.reshape([-1, blocksize.item()]) + tau.reshape([-1, 1])).reshape(self.fp_weight_shape)
+
+        result = paddle.nn.functional.linear(x, fp_B, self.bias)
+
+        if not self.disable_lora:
+            # LoRA path with additional parameter-free connections scaled by the learned
+            # lora_default_A_scale / lora_default_B_scale factors.
+            x1 = self.lora_dropout(x)
+            x2 = x1 @ self.lora_A + self.lora_default_A_scale * x.reshape(
+                [_ for _ in x.shape[:-1]] + [self.lora_A.shape[-1]] + [-1]
+            ).mean(axis=-1)
+            x3 = ((x2 @ self.lora_B).reshape([_ for _ in x2.shape] + [-1]) + (self.lora_default_B_scale * x2.unsqueeze(-1))).reshape(
+                [_ for _ in x2.shape[:-1]] + [-1]
+            )
+            result += x3 * self.scaling
+
+        return result
+
+
+def _replace_with_ours_lora_4bit_linear(
+    model: nn.Layer, current_key_name=None, model_fp=None, blocksize2=256, tau_range=0.5, tau_n=51
+):
+    # Walk the LoRA model and its full-precision twin in lockstep and swap every
+    # QuantizationLoRALinear for an IRQuantizationLoRALinear calibrated on the fp weights.
+    assert isinstance(model_fp, nn.Layer)
+    for name, module in model.named_children():
+        if current_key_name is None:
+            current_key_name = []
+        current_key_name.append(name)
+
+        if isinstance(module, QuantizationLoRALinear):
+            _modules = dict(model_fp.named_sublayers())
+            new_layer = IRQuantizationLoRALinear(
+                dict(model.named_sublayers())[name],
+                model_fp=_modules[name],
+                blocksize2=blocksize2,
+                tau_range=tau_range,
+                tau_n=tau_n,
+            )
+            setattr(model, name, new_layer)
+
+        if len(list(module.children())) > 0:
+            _modules = dict(model_fp.named_sublayers())
+            if name in _modules.keys():
+                _ = _replace_with_ours_lora_4bit_linear(module, current_key_name, _modules[name], blocksize2, tau_range, tau_n)
+            else:
+                _ = _replace_with_ours_lora_4bit_linear(module, current_key_name, None, blocksize2, tau_range, tau_n)
+        current_key_name.pop(-1)
+    return model
diff --git a/paddlenlp/trl/model_config.py b/paddlenlp/trl/model_config.py
index 2e244d211158..3a165c900637 100644
--- a/paddlenlp/trl/model_config.py
+++ b/paddlenlp/trl/model_config.py
@@ -136,6 +136,7 @@ class ModelConfig:
             "help": "Block size for quant_scale of weight quant_scale(Only available for nf4 or fp4 quant_scale.)"
         },
     )
+    ir_qlora: bool = field(default=False, metadata={"help": "Whether to use IR-QLoRA"})
     apply_hadamard: bool = field(default=False, metadata={"help": "Whether to apply hadamard"})
     hadamard_block_size: int = field(default=32, metadata={"help": "hadamard block size"})
     quant_input_grad: bool = field(default=False, metadata={"help": "Whether to quantize input grad"})
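The recursive swap above pairs each quantized LoRA layer with its full-precision counterpart by name. As a toy illustration of that lockstep traversal pattern (plain `nn.Linear` layers and made-up `Block`/`Twin` modules stand in for `QuantizationLoRALinear` and the real model; the actual replacement builds `IRQuantizationLoRALinear` rather than copying weights):

```python
import paddle.nn as nn


class Block(nn.Layer):
    def __init__(self):
        super().__init__()
        self.up = nn.Linear(4, 4)
        self.act = nn.ReLU()


class Twin(nn.Layer):
    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(4, 4)
        self.block = Block()


def swap_linears(model: nn.Layer, model_fp: nn.Layer):
    """Walk `model` and its full-precision twin together, replacing matching Linear layers."""
    fp_children = dict(model_fp.named_children())
    for name, child in model.named_children():
        if isinstance(child, nn.Linear):
            # The real code would build IRQuantizationLoRALinear(child, model_fp=fp_children[name], ...)
            new_layer = nn.Linear(child.weight.shape[0], child.weight.shape[1])
            new_layer.weight.set_value(fp_children[name].weight)
            setattr(model, name, new_layer)
        elif name in fp_children:
            swap_linears(child, fp_children[name])
    return model


print(swap_linears(Twin(), Twin()))
```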
diff --git a/paddlenlp/trl/quant_config.py b/paddlenlp/trl/quant_config.py
index a21481f70107..a66e0b405c0f 100644
--- a/paddlenlp/trl/quant_config.py
+++ b/paddlenlp/trl/quant_config.py
@@ -104,6 +104,9 @@ class QuantConfig:
     # GPTQ related parameters
     do_gptq: bool = field(default=False, metadata={"help": "Whether to use GPTQ"})
     gptq_step: int = field(default=8, metadata={"help": "Step for GPTQ"})
+    # FOEM related parameters
+    do_foem: bool = field(default=False, metadata={"help": "Whether to use FOEM"})
+    foem_beta: float = field(default=0.1, metadata={"help": "Beta for FOEM"})
 
     # AWQ related parameters, default for WINT4
     do_awq: bool = field(default=False, metadata={"help": "Whether to use AWQ Search"})
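A quick way to sanity-check that the two new fields reach `fasterquant` is to instantiate the dataclass directly (importing from the module path shown above; whether `QuantConfig` is also re-exported from `paddlenlp.trl` is not shown in this patch):

```python
from paddlenlp.trl.quant_config import QuantConfig

# foem_argument.json enables both flags; FOEM piggybacks on the GPTQ pass.
quant_args = QuantConfig(do_gptq=True, do_foem=True, foem_beta=0.1)

# apply_gptq() in llm/utils/quant.py forwards them unchanged:
#     cur_quant_layer.fasterquant(..., do_foem=quant_args.do_foem, foem_beta=quant_args.foem_beta)
print(quant_args.do_foem, quant_args.foem_beta)
```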